-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_environment.py
More file actions
80 lines (69 loc) · 3.34 KB
/
test_environment.py
File metadata and controls
80 lines (69 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from env.environment import SQLDebuggerEnvironment
from env.models import Action, ActionType
# Single environment instance driven by every check in this script.
env = SQLDebuggerEnvironment()

# Test 1: state() is called before any reset(); the call must not raise.
pre_reset_state = env.state()
print(f"State before reset: initialized={pre_reset_state.initialized}")

# Test 2: reset() opens an easy episode and returns the first observation.
obs = env.reset(difficulty="easy")
print(f"Reset OK: task_id={obs.task_id}, difficulty={obs.difficulty}")
context_keys = list(obs.current_context.keys())
print(f"Context keys: {context_keys}")
# Reports whether the 'fixed_query' key is absent from the context handed
# to the agent (the printed label treats that key as the ground truth).
ground_truth_hidden = "fixed_query" not in obs.current_context
print(f"Ground truth NOT in context: {ground_truth_hidden}")
# Test 3: step() with an IDENTIFY_ERROR action on the episode opened by the
# preceding reset().
identify_payload = {
    "error_location": "SELECT clause",
    "error_type": "syntax",
    "explanation": "Missing commas",
}
identify_result = env.step(
    Action(action_type=ActionType.IDENTIFY_ERROR, payload=identify_payload)
)
print(
    f"Step 1: reward={identify_result.reward.score}, "
    f"done={identify_result.done}, "
    f"step={identify_result.observation.step_count}"
)

# Test 4: step() with a REQUEST_HINT action; report the hint counter and
# whether a 'last_hint' entry is present in the returned context.
hint_result = env.step(
    Action(action_type=ActionType.REQUEST_HINT, payload={"hint_type": "location"})
)
print(
    f"Step 2 hint: reward={hint_result.reward.score}, "
    f"hints_used={hint_result.observation.hints_used}"
)
hint_in_context = "last_hint" in hint_result.observation.current_context
print(f"Hint in context: {hint_in_context}")
# Test 5: step() with a SUBMIT_ANSWER action, on a fresh episode pinned to
# task "easy_001".
obs = env.reset(difficulty="easy", task_id="easy_001")
submit_payload = {
    "fixed_query": "SELECT id, name, email FROM users WHERE active = 1",
    "explanation": "Added missing commas between column names in SELECT clause",
    "error_type": "syntax",
    "error_location": "SELECT clause",
    "confidence": 0.95,
}
submit_action = Action(action_type=ActionType.SUBMIT_ANSWER, payload=submit_payload)
submit_result = env.step(submit_action)
print(f"Submit answer: reward={submit_result.reward.score}, done={submit_result.done}")

# Test 6: send the very same action again after the submit step; the call
# must not raise. (Presumably the episode is already finished at this
# point -- confirm against the environment's done semantics.)
after_done_result = env.step(submit_action)
print(
    f"Step after done: done={after_done_result.done}, "
    f"feedback='{after_done_result.reward.feedback}'"
)
# Test 7: pass None instead of an Action to step() on a fresh easy episode
# and report the reward/done flags that come back.
env.reset(difficulty="easy")
null_result = env.step(None)
print(f"Null action: reward={null_result.reward.score}, done={null_result.done}")

# Test 8: call reset() again while the episode above is still the current
# one; the new observation's task id and step counter are printed so a
# stale counter would be visible.
obs = env.reset(difficulty="medium")
print(f"Mid-episode reset: new task={obs.task_id}, step_count={obs.step_count}")
# Test 9: hard episode driven by three scripted actions (identify the error,
# explain the issue, submit an optimized query). Each step's reward is
# printed and accumulated into a running total.
obs = env.reset(difficulty="hard")
print(f"Hard episode started: {obs.task_id}")
actions = [
    Action(
        action_type=ActionType.IDENTIFY_ERROR,
        payload={"error_location": "SELECT clause", "error_type": "performance"},
    ),
    Action(
        action_type=ActionType.EXPLAIN_ISSUE,
        payload={
            "explanation": "N+1 correlated subqueries cause multiple DB hits per row",
            "impact": "O(n) queries",
            "root_cause": "Subquery per user",
        },
    ),
    Action(
        action_type=ActionType.OPTIMIZE_QUERY,
        payload={
            "optimized_query": "SELECT u.id, u.name, COUNT(o.id) as order_count, COALESCE(SUM(o.total), 0) as total_spent FROM users u LEFT JOIN orders o ON u.id = o.user_id GROUP BY u.id, u.name",
            "optimization_type": "Replace N+1 correlated subqueries with LEFT JOIN aggregation",
            "explanation": "Single query replaces N+1 pattern",
            "root_cause": "Correlated subqueries in SELECT",
            "expected_improvement": "99% reduction in DB round trips",
            "confidence": 0.9,
        },
    ),
]
total = 0.0
# The three statements below form the loop body; they must be indented under
# the `for`, otherwise the script fails to parse.
for step_no, action in enumerate(actions, start=1):
    result = env.step(action)
    total += result.reward.score
    print(f" Hard step {step_no}: reward={result.reward.score}, done={result.done}")
# Printed once, after the loop, with the sum rounded to four decimals.
print(f"Hard episode total reward: {round(total, 4)}")
# Printed unconditionally: reaching this line only shows that none of the
# calls above raised, not that the printed values were the expected ones.
print("environment.py OK")