|
2 | 2 | import os |
3 | 3 | import re |
4 | 4 | import sys |
5 | | -from typing import Any, Dict |
| 5 | +from typing import Any, Dict, List, Optional |
6 | 6 |
|
7 | 7 | import requests |
8 | 8 | from openai import OpenAI |
9 | 9 |
|
# Model endpoint configuration; all values are overridable via environment.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")  # secret; deliberately has no default
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")

# Benchmark identifiers for the [START] log line, and the hard cap on the
# number of /step calls in one episode.
TASK_NAME = os.getenv("TASK_NAME", "shopops")
BENCHMARK = os.getenv("BENCHMARK", "shopops")
MAX_STEPS = int(os.getenv("MAX_STEPS", "20"))

15 | 19 | REQUIRED_VARS = { |
16 | 20 | "API_BASE_URL": API_BASE_URL, |
17 | 21 | "MODEL_NAME": MODEL_NAME, |
@@ -45,70 +49,111 @@ def _safe_action() -> Dict[str, Any]: |
45 | 49 | } |
46 | 50 |
|
47 | 51 |
|
| 52 | +def _log_start(task: str, env: str, model: str) -> None: |
| 53 | + print(f"[START] task={task} env={env} model={model}", flush=True) |
| 54 | + |
| 55 | + |
| 56 | +def _log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None: |
| 57 | + done_val = str(done).lower() |
| 58 | + error_val = error if error else "null" |
| 59 | + print( |
| 60 | + f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", |
| 61 | + flush=True, |
| 62 | + ) |
| 63 | + |
| 64 | + |
| 65 | +def _log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None: |
| 66 | + rewards_str = ",".join(f"{r:.2f}" for r in rewards) |
| 67 | + print( |
| 68 | + f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", |
| 69 | + flush=True, |
| 70 | + ) |
| 71 | + |
| 72 | + |
def main() -> None:
    """Run one episode against the environment and emit [START]/[STEP]/[END] logs.

    The [END] record is emitted from a ``finally`` block so downstream log
    parsers always see a terminal line, even when the episode aborts with an
    exception (the exception still propagates).
    """
    _require_env()
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

    seed = int(os.getenv("SEED", "42"))
    _log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)

    rewards: List[float] = []
    steps_taken = 0
    success = False
    score = 0.0

    try:
        # Explicit timeout so a hung environment server cannot stall the run forever.
        reset_resp = requests.post(f"{ENV_URL}/reset", json={"seed": seed}, timeout=60)
        reset_resp.raise_for_status()
        payload = reset_resp.json()
        obs = payload["observation"]
        episode_id = obs.get("episode_id", "unknown")

        step = 1
        done = payload.get("done", False)

        while not done and step <= MAX_STEPS:
            prompt = (
                "You are an e-commerce ops agent. Return ONLY JSON with keys: "
                "action_type, refund_amount_usd, replacement_expedite, escalation_reason. "
                f"Observation: {json.dumps(obs)}"
            )

            try:
                response = client.responses.create(
                    model=MODEL_NAME,
                    input=prompt,
                    text={"format": {"type": "json_object"}},
                )
                action = _parse_action(response.output_text)
            except Exception:
                # Deliberate best-effort: any model or parse failure falls back
                # to a known-safe action instead of aborting the episode.
                action = _safe_action()

            step_resp = requests.post(
                f"{ENV_URL}/step",
                json={"action": action, "episode_id": episode_id},
                timeout=60,
            )
            step_payload: Dict[str, Any] = {}
            if step_resp.status_code == 200:
                step_payload = step_resp.json()
                reward = float(step_payload.get("reward") or 0.0)
                done = bool(step_payload.get("done", False))
                # "or {}" on both levels guards against an explicit JSON null
                # observation/metadata (a .get default does not fire for null).
                error = (
                    ((step_payload.get("observation") or {}).get("metadata") or {})
                    .get("validation_error")
                )
            else:
                # Non-200: surface the server's error detail if it sent JSON,
                # otherwise fall back to the raw body or a synthetic code.
                try:
                    err_payload = step_resp.json()
                    error = err_payload.get("detail") or str(err_payload)
                except Exception:
                    error = step_resp.text or f"http_{step_resp.status_code}"
                reward = 0.0
                done = True  # a failed /step call ends the episode

            rewards.append(reward)
            steps_taken = step

            _log_step(
                step=step,
                action=json.dumps(action, separators=(",", ":")),
                reward=reward,
                done=done,
                error=error,
            )

            if step_payload:
                obs = step_payload["observation"]
            step += 1

        # HTTP API does not include episode_summary, so compute a normalized score.
        # This keeps score within [0, 1] for logging.
        score = sum(rewards) / float(MAX_STEPS) if MAX_STEPS > 0 else 0.0
        score = max(0.0, min(1.0, score))
        success = score > 0.0
    finally:
        # Always emit a terminal [END] record, even on exception.
        _log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
112 | 157 |
|
113 | 158 |
|
114 | 159 | if __name__ == "__main__": |
|
0 commit comments