-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
346 lines (281 loc) · 13.7 KB
/
inference.py
File metadata and controls
346 lines (281 loc) · 13.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""
inference.py — GovScheme-Env Baseline Inference Script
=======================================================
MANDATORY — placed at root of project as required by hackathon rules.
Reads from environment variables:
API_BASE_URL — LLM API endpoint (default: HuggingFace router)
MODEL_NAME — model to use (default: Qwen/Qwen2.5-72B-Instruct)
HF_TOKEN — your HF API key
ENV_URL — where the env server is running (default: localhost:7860)
Stdout format (exactly as required by hackathon evaluator):
[START] task=<task_name> env=govscheme-env model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...>
Run:
python inference.py
"""
import json
import os
import re
import sys
import time
import textwrap
from typing import Any, Dict, List, Optional
import requests
from openai import OpenAI
# ── Config ────────────────────────────────────────────────────────────────
# Everything is read from environment variables so the evaluator can swap
# endpoints/models without editing code; defaults match the module docstring.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# HF_TOKEN takes precedence; API_KEY is a generic fallback (may be empty).
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
MAX_STEPS = 3 # max steps per task (matches env setting)
TEMPERATURE = 0.2 # low = more deterministic = more reproducible
MAX_TOKENS = 1500
SUCCESS_THRESHOLD = 0.5 # score >= 0.5 counts as success
# Fixed citizen + seed for reproducible baseline scores
TASK_CONFIG = {
    "scheme_identification": {"citizen_id": "CIT_001", "seed": 42}, # easy: rural farmer UP
    "scheme_ranking": {"citizen_id": "CIT_006", "seed": 42}, # medium: weaver WB
    "form_filling": {"citizen_id": "CIT_009", "seed": 42}, # hard: SC farmer Odisha
}
# Task execution order — iterated by main().
TASKS = list(TASK_CONFIG.keys())
# ── Mandatory stdout loggers ───────────────────────────────────────────────
def log_start(task: str, env: str, model: str) -> None:
    """Print the evaluator's mandatory [START] marker for one episode."""
    banner = f"[START] task={task} env={env} model={model}"
    print(banner, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print the evaluator's mandatory [STEP] marker for one env step.

    Args:
        step: 1-based step index within the episode.
        action: serialized action string (truncated to 200 chars).
        reward: step reward, rendered with 2 decimals.
        done: whether the env reported the episode finished.
        error: env/transport error message, or None (rendered as "null").
    """
    # Newlines in either field would split the single-line record the
    # evaluator parses, so collapse them to spaces. (The original only
    # sanitized `action`; a multi-line error message broke the format.)
    safe_action = action.replace("\n", " ")[:200]
    error_val = error.replace("\n", " ") if error else "null"
    done_val = str(done).lower()
    print(f"[STEP] step={step} action={safe_action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the evaluator's mandatory [END] marker for one episode."""
    reward_list = ",".join(format(r, ".2f") for r in rewards)
    outcome = str(success).lower()
    print(f"[END] success={outcome} steps={steps} score={score:.3f} rewards={reward_list}", flush=True)
# ── System prompts per task ────────────────────────────────────────────────
# One system prompt per key of TASK_CONFIG. Each string is sent verbatim as
# the system message in call_llm(); all three demand a JSON-only reply whose
# action_type matches what the env's /step endpoint expects for that task.
SYSTEM_PROMPTS = {
    # Task 1: list every scheme_id the citizen qualifies for.
    "scheme_identification": textwrap.dedent("""
You are a government welfare officer helping Indian citizens access welfare schemes.
Given a citizen profile and a list of 18 Indian government schemes, identify ALL schemes
this citizen is eligible for.
Check these eligibility factors carefully:
- Age limits (age_min, age_max)
- Income limits (annual_income_inr, annual_family_income_inr)
- Caste (SC, ST, OBC, General)
- Occupation (farmer, student, weaver, etc.)
- Gender
- Land ownership (acres)
- Has Aadhaar, has bank account
- Rural vs urban
- House type (kachha qualifies for housing schemes)
- NOT a government employee, NOT an income taxpayer
Respond ONLY with valid JSON, no markdown, no explanation outside the JSON:
{
"action_type": "identify_schemes",
"scheme_ids": ["SCHEME_ID_1", "SCHEME_ID_2"],
"reasoning": "Brief explanation of why each scheme was included"
}
Use only scheme IDs from the available_schemes list.
When in doubt, include the scheme — missing one hurts more than including an extra.
""").strip(),
    # Task 2: order the already-identified schemes by annual benefit value.
    "scheme_ranking": textwrap.dedent("""
You are a welfare advisor ranking government schemes by benefit value for a citizen.
Rank ALL eligible schemes from most to least beneficial based on annual_benefit_inr.
Schemes with benefit_inr=0 (credit/loan access schemes) go at the bottom.
Respond ONLY with valid JSON, no markdown, no explanation outside the JSON:
{
"action_type": "rank_schemes",
"ranked_schemes": [
{
"scheme_id": "AYUSHMAN_BHARAT",
"rank": 1,
"reason": "Highest value: ₹5,00,000 health insurance. Citizen income ₹72,000 qualifies.",
"benefit_inr": 500000
},
...
],
"reasoning": "Overall ranking rationale"
}
Your "reason" for each scheme MUST mention:
- The actual INR benefit amount
- Why this citizen qualifies (caste, income, occupation, etc.)
""").strip(),
    # Task 3: fill the target scheme's application form from profile data only.
    "form_filling": textwrap.dedent("""
You are a government application assistant filling forms for citizens.
Fill the application form for the target scheme using ONLY data from the citizen profile.
DO NOT invent, guess, or hallucinate any values.
STRICT format rules — get these exactly right:
- Aadhaar number: 12 digits, must start with 2-9 (e.g. "234567890123")
- IFSC code: 4 uppercase letters + 0 + 6 alphanumeric (e.g. "SBIN0001234")
- Mobile number: 10 digits, must start with 6, 7, 8, or 9 (e.g. "9876543210")
- Date of birth: DD/MM/YYYY format (e.g. "15/03/1986")
- category: use the caste field value (SC / ST / OBC / General)
- gender: use exactly "Male", "Female", or "Other"
If a field value is not in the citizen profile, leave it out entirely — do not guess.
Respond ONLY with valid JSON, no markdown:
{
"action_type": "fill_form",
"form_data": {
"applicant_name": "...",
"aadhaar_number": "...",
...
},
"reasoning": "Explanation of how you mapped each field"
}
""").strip(),
}
# ── LLM call ──────────────────────────────────────────────────────────────
def call_llm(client: OpenAI, task: str, obs: Dict) -> Dict:
"""Call the LLM with the current observation and return parsed JSON action."""
user_prompt = _build_prompt(task, obs)
try:
completion = client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": SYSTEM_PROMPTS[task]},
{"role": "user", "content": user_prompt},
],
temperature=TEMPERATURE,
max_tokens=MAX_TOKENS,
)
raw = (completion.choices[0].message.content or "").strip()
# Strip markdown fences if model wraps response in ```json ... ```
raw = re.sub(r"```(?:json)?\s*", "", raw).strip().rstrip("`").strip()
return json.loads(raw)
except json.JSONDecodeError as e:
print(f"[DEBUG] JSON parse error: {e}", flush=True)
return {"action_type": "identify_schemes", "scheme_ids": [], "reasoning": "parse_error"}
except Exception as e:
print(f"[DEBUG] LLM error: {e}", flush=True)
return {"action_type": "identify_schemes", "scheme_ids": [], "reasoning": "llm_error"}
def _build_prompt(task: str, obs: Dict) -> str:
"""Build a clear user prompt from the current observation."""
citizen = obs.get("citizen_profile", {})
available = obs.get("available_schemes", [])
identified = obs.get("identified_schemes")
form_template = obs.get("form_template")
target_scheme = obs.get("target_scheme_id")
lines = [
f"TASK: {obs.get('task_description', '')}",
"",
"=== CITIZEN PROFILE ===",
json.dumps(citizen, indent=2),
"",
]
if task == "scheme_identification":
lines += [
"=== AVAILABLE SCHEMES (use only these scheme IDs) ===",
json.dumps(
[{"scheme_id": s["scheme_id"], "name": s["name"],
"benefit_type": s["benefit_type"], "annual_benefit_inr": s["annual_benefit_inr"]}
for s in available],
indent=2
),
]
elif task == "scheme_ranking":
eligible = [s for s in available if identified and s["scheme_id"] in identified]
lines += [
"=== ELIGIBLE SCHEMES TO RANK (rank all of these) ===",
json.dumps(
[{"scheme_id": s["scheme_id"], "name": s["name"],
"annual_benefit_inr": s["annual_benefit_inr"],
"benefit_description": s["benefit_description"]}
for s in eligible],
indent=2
),
]
elif task == "form_filling":
lines += [
f"=== TARGET SCHEME: {target_scheme} ===",
"",
"=== FORM FIELDS TO FILL (field name → type and constraints) ===",
json.dumps(form_template, indent=2),
]
return "\n".join(lines)
# ── Env HTTP calls ─────────────────────────────────────────────────────────
def env_reset(task_name: str, citizen_id: str, seed: int = 42) -> Dict:
    """POST /reset to start a fresh episode; return the server's JSON body."""
    payload = {"task_name": task_name, "citizen_id": citizen_id, "seed": seed}
    response = requests.post(f"{ENV_URL}/reset", json=payload, timeout=30)
    response.raise_for_status()
    return response.json()
def env_step(action: Dict) -> Dict:
    """POST the agent's action dict to /step; return the server's JSON body."""
    response = requests.post(f"{ENV_URL}/step", json=action, timeout=30)
    response.raise_for_status()
    return response.json()
# ── Task runner ────────────────────────────────────────────────────────────
def run_task(llm: OpenAI, task_name: str) -> Dict[str, Any]:
    """Run one complete task episode and return its summary.

    Resets the env for the task's fixed citizen/seed, then loops up to
    MAX_STEPS times: ask the LLM for an action, send it to the env, log
    the step in the evaluator's [STEP] format, and stop early when the
    env reports done. The episode score is the best single-step reward,
    clamped to [0, 1] and rounded to 3 decimals.

    Returns a dict with keys: task, score, rewards, steps, success.
    """
    cfg = TASK_CONFIG[task_name]
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    error_msg = None
    log_start(task=task_name, env="govscheme-env", model=MODEL_NAME)
    try:
        data = env_reset(task_name, cfg["citizen_id"], cfg["seed"])
        # Server may return {"observation": ...} or the observation itself.
        obs = data.get("observation", data)
        for step in range(1, MAX_STEPS + 1):
            # Get action from LLM
            action_dict = call_llm(llm, task_name, obs)
            # Send action to environment
            try:
                result = env_step(action_dict)
            except Exception as e:
                # Transport-level failure: log the aborted step and end the episode.
                error_msg = str(e)
                log_step(step, str(action_dict)[:100], 0.0, True, error_msg)
                break
            reward = float(result.get("reward", 0.0))
            done = bool(result.get("done", False))
            error_msg = result.get("info", {}).get("error")
            rewards.append(reward)
            steps_taken = step
            log_step(
                step=step,
                action=json.dumps(action_dict)[:150],
                reward=reward,
                done=done,
                error=error_msg,
            )
            # Keep the previous observation if the server omits a new one.
            obs = result.get("observation", obs)
            if done:
                break
        # Score = best single-step score (agent gets 3 attempts)
        score = max(rewards) if rewards else 0.0
        score = round(min(max(score, 0.0), 1.0), 3)
        success = score >= SUCCESS_THRESHOLD
    except Exception as e:
        # Reset/LLM failure: report a zero-reward episode so [END] still prints.
        print(f"[DEBUG] Task {task_name} failed: {e}", flush=True)
        if not rewards:
            rewards = [0.0]
        steps_taken = max(steps_taken, 1)
    finally:
        # [END] must be emitted exactly once per task for the evaluator.
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return {"task": task_name, "score": score, "rewards": rewards, "steps": steps_taken, "success": success}
# ── Main ───────────────────────────────────────────────────────────────────
def main():
    """Entry point: health-check the env, run every task, print a summary."""
    for label, value in (("ENV_URL", ENV_URL), ("MODEL", MODEL_NAME), ("API_BASE", API_BASE_URL)):
        print(f"[DEBUG] {label} = {value}", flush=True)
    # Health check
    try:
        health = requests.get(f"{ENV_URL}/health", timeout=10)
        print(f"[DEBUG] Env health: {health.json()}", flush=True)
    except Exception as exc:
        print(f"[DEBUG] WARNING — env health check failed: {exc}", flush=True)
        print("[DEBUG] Make sure the server is running before inference.py", flush=True)
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    results = []
    for task_name in TASKS:
        print(f"\n[DEBUG] ═══ Running task: {task_name} ═══", flush=True)
        results.append(run_task(client, task_name))
        time.sleep(1)  # rate-limit politeness
    # Final summary
    print("\n[DEBUG] ═══ FINAL SCORES ═══", flush=True)
    for outcome in results:
        marker = "✓" if outcome["success"] else "✗"
        print(f"[DEBUG] {marker} {outcome['task']}: score={outcome['score']:.3f}", flush=True)
    avg_score = sum(outcome["score"] for outcome in results) / len(results)
    print(f"[DEBUG] Average score: {avg_score:.3f}", flush=True)
# Entry-point guard: lets the module be imported without running inference.
if __name__ == "__main__":
    main()