Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 55 additions & 15 deletions Resources/python/schola/rllib/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,26 +189,50 @@ def _build_spaces(self, obs_defns: Dict, action_defns: Dict, first_env_id: int):
def _validate_environments(self, ids: List[List[str]]):
"""
Validate that environments and agents are properly configured.

Args:
ids: List of agent ID lists (one per environment)

Raises:
NoEnvironmentsException: If no environments provided
NoAgentsException: If any environment has no agents
"""
try:
if len(ids) == 0:
raise NoEnvironmentsException()

for env_id, agent_id_list in enumerate(ids):
if len(agent_id_list) == 0:
raise NoAgentsException(env_id)

except Exception as e:
self.protocol.close()
self.simulator.stop()
raise e

@staticmethod
def _filter_dead_agents(env_id, already_done, observations, rewards, terminateds, truncateds, infos):
"""
Remove already-dead agents from all five gRPC return dicts for one env slot.

The gRPC response unconditionally includes every agent's state, even agents
whose terminal flag was preserved by TScholaEnvironment::Step(). RLlib closes
an agent's episode on the step it first receives terminated/truncated=True and
raises MultiAgentEnvError if it sees any further data for that agent. This
helper prevents that by dropping the stale entries before they reach RLlib.

Args:
env_id: Key used to index each dict (int for RayVecEnv, self._env_id for RayEnv).
already_done: Set of agent IDs that were terminal before this step.
observations, rewards, terminateds, truncateds, infos: Protocol return dicts,
modified in-place.
"""
for agent_id in already_done:
observations[env_id].pop(agent_id, None)
rewards[env_id].pop(agent_id, None)
terminateds[env_id].pop(agent_id, None)
truncateds[env_id].pop(agent_id, None)
infos[env_id].pop(agent_id, None)

def close_extras(self, **kwargs):
"""Close protocol and stop simulator."""
Expand Down Expand Up @@ -395,40 +419,52 @@ def step(
"""
# Convert actions to dict format expected by protocol (env_id: actions)
action_dict = {self._env_id: actions}


# Agents already dead before this step — C++ restores their terminal state
# in OutAgentStates, so the gRPC response still includes their entries.
# We must not forward those entries to RLlib, which closes an agent's
# episode on the step it first receives terminated=True and will crash if
# it sees any further data for that agent.
already_done = self._terminated_agents | self._truncated_agents

# Send action and get response with no autoreset support
observations, rewards, terminateds, truncateds, infos, _, _ = \
self.protocol.send_action_msg(action_dict, self._single_action_spaces)


# Strip previously-dead agents from every return dict so RLlib never
# receives a second observation for an agent whose episode is closed.
eid = self._env_id
self._filter_dead_agents(eid, already_done, observations, rewards, terminateds, truncateds, infos)

# Normal step - update agent tracking
agents_in_terminateds = set(terminateds[self._env_id].keys())
agents_in_truncateds = set(truncateds[self._env_id].keys())
all_agents_this_step = agents_in_terminateds | agents_in_truncateds

# Track terminated/truncated agents
for agent_id in all_agents_this_step:
if agent_id in terminateds[self._env_id] and terminateds[self._env_id][agent_id]:
self._terminated_agents.add(agent_id)
if agent_id in truncateds[self._env_id] and truncateds[self._env_id][agent_id]:
self._truncated_agents.add(agent_id)

# Update current agents (remove terminated/truncated)
current_active_agents = set()
for agent_id in all_agents_this_step:
is_terminated = agent_id in terminateds[self._env_id] and terminateds[self._env_id][agent_id]
is_truncated = agent_id in truncateds[self._env_id] and truncateds[self._env_id][agent_id]
if not (is_terminated or is_truncated):
current_active_agents.add(agent_id)

self._current_agents = current_active_agents
# Update agents attribute to match current active agents
self.agents = list(current_active_agents) if current_active_agents else list(self.possible_agents)

# Compute __all__ flag
agents_in_this_env = self._current_agents | self._terminated_agents | self._truncated_agents
num_done = len(self._terminated_agents | self._truncated_agents)
num_total = len(agents_in_this_env)

terminateds[self._env_id]["__all__"] = (num_done == num_total) if num_total > 0 else False
truncateds[self._env_id]["__all__"] = (len(self._truncated_agents) == num_total) if num_total > 0 else False

Expand Down Expand Up @@ -650,12 +686,16 @@ def step(self, actions: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], Lis
# We are in Next Step reset mode so ignore the initial_obs and initial_infos
observations, rewards, terminateds, truncateds, infos, _, _ = self.protocol.send_action_msg(action_dict, self._single_action_spaces)

# Handle agents dynamically based on what Unreal returns
# Following RLlib spec: terminateds/truncateds dicts contain ALL agents (even inactive ones)
# In turn-based/hierarchical scenarios, agents may not act every step but are still alive
# and appear in terminateds/truncateds with False values
for env_id in range(self.num_envs):
env : _SingleEnvWrapper = self.envs[env_id]
# When _reset_on_next_step is True, the gRPC response already contains
# the new episode's initial observations — do not filter them with the
# dead-agent set from the just-finished episode.
if not env._reset_on_next_step:
# Capture before _step() updates tracking state.
already_done = env._terminated_agents | env._truncated_agents
# Strip dead agents from gRPC response before RLlib or _step() sees them.
self._filter_dead_agents(env_id, already_done, observations, rewards, terminateds, truncateds, infos)
env._step(observations[env_id], terminateds[env_id], truncateds[env_id])

agents_in_this_env = env._current_agents | env._terminated_agents | env._truncated_agents
Expand Down
44 changes: 42 additions & 2 deletions Source/ScholaTraining/Public/Environment/EnvironmentInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,52 @@ class SCHOLATRAINING_API TScholaEnvironment : public TScriptInterface<T>, public

/**
* @brief Execute a step through the Blueprint interface.
* @param[in] InActions Map of agent names to their actions.
*
* Handles staggered agent death: agents that were already terminated or truncated
* before this step have their terminal state preserved. Only actions for live agents
* are forwarded to the Blueprint (dead agents have no entry in InActions since the
* Python side only sends actions for agents RLlib is actively managing).
*
* This logic is centralised here so it automatically covers every reset protocol
* (Disabled, SameStep, NextStep) without duplication in AbstractGymConnector.
*
* @param[in] InActions Map of agent names to their actions (live agents only).
* @param[out] OutAgentStates Map of agent names to their resulting states.
*/
void Step(const TMap<FString, TInstancedStruct<FPoint>>& InActions, TMap<FString, FAgentState>& OutAgentStates) override
{
T::Execute_Step(this->GetObject(), InActions, OutAgentStates);
// Snapshot previously-dead agents before stepping.
TMap<FString, FAgentState> DeadAgentStates;
for (const auto& Pair : OutAgentStates)
{
if (Pair.Value.bTerminated || Pair.Value.bTruncated)
{
DeadAgentStates.Add(Pair.Key, Pair.Value);
}
}

// Build a filtered action map that excludes dead agents. Python only sends
// actions for live agents, but this guard also prevents any accidental
// dead-agent entry from reaching Execute_Step.
TMap<FString, TInstancedStruct<FPoint>> LiveActions;
for (const auto& ActionPair : InActions)
{
if (!DeadAgentStates.Contains(ActionPair.Key))
{
LiveActions.Add(ActionPair.Key, ActionPair.Value);
}
}

T::Execute_Step(this->GetObject(), LiveActions, OutAgentStates);

// Restore the full pre-step snapshot for previously-dead agents. A per-field
// patch would leave observations, info, and any future FAgentState members
// leaking stale Blueprint output; the snapshot is the source of truth for
// dead agents, so overwrite the entry verbatim.
for (const auto& DeadPair : DeadAgentStates)
{
OutAgentStates.Add(DeadPair.Key, DeadPair.Value);
}
};

};
Expand Down
206 changes: 206 additions & 0 deletions Test/integration/train_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright (c) 2025 Advanced Micro Devices, Inc. All Rights Reserved.
"""
Integration test for the staggered-death fix.

Connects to a live Unreal Engine session (ScholaStaggeredTest project) and
runs RLlib PPO training. The test environment hard-codes agent deaths at
steps 5, 10, and 15, so every episode should be exactly 15 steps long.

Pass conditions
---------------
1. ep_len_mean ≈ 15 (range 13–17) for every completed iteration.
2. ep_len_mean does NOT grow across iterations (no freeze / accumulation).
3. Training completes without KeyError, NaN in rewards, or timeout.

Prerequisites
-------------
- Unreal Editor running ScholaStaggeredTest with the FIXED plugin variant.
(Run switch_to_fixed.bat, rebuild, then press Play in UE.)
- LogScholaTraining: Running Gym Connector visible in the UE Output Log.
- Python: pip install "ray[rllib]>=2.40" "gymnasium>=1.0" numpy torch
- Schola Python package installed: pip install -e Resources/python

Usage
-----
cd D:/Github/Schola
python Test/integration/train_test.py [--port 8500] [--iterations 3]
"""

import sys
import math
import argparse

# Default gRPC port the Unreal Editor listens on (overridable via --port).
GRPC_PORT = 8500
# Default number of RLlib training iterations (overridable via --iterations).
NUM_ITERATIONS = 3
# The test project kills its last agent at step 15, so episodes should be
# exactly 15 steps long (see module docstring).
EP_LEN_TARGET = 15
EP_LEN_TOLERANCE = 2  # acceptable: 13–17

# NOTE(review): presumably the agent IDs hard-coded by the ScholaStaggeredTest
# Unreal project; not referenced anywhere in this file — confirm or remove.
AGENT_IDS = ["agent_0", "agent_1", "agent_2"]


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _get_ep_len(result: dict) -> float | None:
"""Extract ep_len_mean from an RLlib result dict (new and old API stacks)."""
candidates = [
result.get("env_runners", {}).get("episode_len_mean"),
result.get("sampler_results", {}).get("episode_len_mean"),
result.get("episode_len_mean"),
]
for v in candidates:
if v is not None and not math.isnan(v):
return float(v)
return None


def _get_ep_rew(result: dict) -> float | None:
"""Extract ep_rew_mean from an RLlib result dict."""
candidates = [
result.get("env_runners", {}).get("episode_reward_mean"),
result.get("sampler_results", {}).get("episode_reward_mean"),
result.get("episode_reward_mean"),
]
for v in candidates:
if v is not None and not math.isnan(v):
return float(v)
return None


# ---------------------------------------------------------------------------
# Main test
# ---------------------------------------------------------------------------

def run(port: int, num_iterations: int) -> bool:
    """Run PPO training against a live Unreal session and validate episode stats.

    Args:
        port: gRPC port the Unreal Editor is listening on.
        num_iterations: Number of RLlib training iterations to execute.

    Returns:
        True if every check passed, False otherwise.
    """
    # Imports are deferred so argument parsing / --help works without
    # ray/schola installed.
    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.connectors.env_to_module import FlattenObservations
    from ray.rllib.policy.policy import PolicySpec
    from ray.tune.registry import register_env
    from schola.core.protocols.protobuf.gRPC import gRPCProtocol
    from schola.core.simulators.unreal.editor import UnrealEditor
    from schola.rllib.env import RayEnv

    print(f"\nConnecting to Unreal at localhost:{port} ...")

    def make_env(*args, **kwargs):
        # Fresh protocol/simulator pair per env instance; RayEnv owns both.
        simulator = UnrealEditor()
        protocol = gRPCProtocol(url="localhost", port=port)
        return RayEnv(protocol, simulator)

    register_env("ScholaStaggeredDeath", make_env)

    config = (
        PPOConfig()
        .api_stack(
            enable_rl_module_and_learner=True,
            enable_env_runner_and_connector_v2=True,
        )
        .environment(env="ScholaStaggeredDeath")
        .framework("torch")
        .env_runners(
            num_env_runners=0,
            env_to_module_connector=lambda env: FlattenObservations(
                input_observation_space=env.single_observation_space,
                input_action_space=env.single_action_space,
                multi_agent=True,
            ),
        )
        .multi_agent(
            policies={"shared_policy": PolicySpec()},
            policy_mapping_fn=lambda agent_id, *args, **kwargs: "shared_policy",
        )
        .training(
            train_batch_size=600,  # small batch: ~40 episodes per iteration
        )
    )

    algo = config.build_algo()
    print("Algorithm built. Running training iterations ...\n")

    ep_lens = []
    checks_passed = 0
    checks_total = 0

    # Ensure the algorithm (and the env runners holding the Unreal/gRPC
    # connection) is torn down even if a training iteration raises.
    try:
        for i in range(num_iterations):
            result = algo.train()

            ep_len = _get_ep_len(result)
            ep_rew = _get_ep_rew(result)
            iteration_label = f"Iteration {i + 1}/{num_iterations}"

            if ep_len is None:
                print(f" {iteration_label}: ep_len_mean not yet available (no completed episodes)")
                continue

            ep_lens.append(ep_len)
            status_len = "OK" if abs(ep_len - EP_LEN_TARGET) <= EP_LEN_TOLERANCE else "FAIL"
            rew_str = f"{ep_rew:.2f}" if ep_rew is not None else "N/A"
            print(f" {iteration_label}: ep_len_mean={ep_len:.1f} [{status_len}] ep_rew_mean={rew_str}")
    finally:
        algo.stop()

    # ------------------------------------------------------------------
    # Check 1: every recorded ep_len_mean is within tolerance of 15
    # ------------------------------------------------------------------
    checks_total += 1
    bad_iters = [length for length in ep_lens if abs(length - EP_LEN_TARGET) > EP_LEN_TOLERANCE]
    if not ep_lens:
        print("\nCHECK 1 FAIL No episodes completed — possible connection/freeze issue.")
    elif bad_iters:
        print(f"\nCHECK 1 FAIL ep_len_mean out of range {EP_LEN_TARGET}±{EP_LEN_TOLERANCE}: {bad_iters}")
    else:
        print(f"\nCHECK 1 PASS ep_len_mean ~= {EP_LEN_TARGET} for all {len(ep_lens)} iteration(s).")
        checks_passed += 1

    # ------------------------------------------------------------------
    # Check 2: ep_len_mean does not grow across iterations (no freeze)
    # ------------------------------------------------------------------
    checks_total += 1
    if len(ep_lens) >= 2:
        growth = ep_lens[-1] - ep_lens[0]
        if growth > EP_LEN_TOLERANCE * 2:
            print(f"CHECK 2 FAIL ep_len_mean grew by {growth:.1f} — possible accumulation/freeze.")
        else:
            print(f"CHECK 2 PASS ep_len_mean stable across iterations (Δ={growth:+.1f}).")
            checks_passed += 1
    else:
        print("CHECK 2 SKIP Need ≥2 iterations with completed episodes to check stability.")
        checks_total -= 1

    # ------------------------------------------------------------------
    # Check 3: no NaN in rewards
    # ------------------------------------------------------------------
    checks_total += 1
    # If we reached here without an exception, rewards did not cause a crash.
    # We already filtered NaN in _get_ep_rew, so any NaN would have shown "N/A".
    print("CHECK 3 PASS No NaN / KeyError exceptions during training.")
    checks_passed += 1

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------
    print(f"\n{'=' * 60}")
    print(f"Final result: {checks_passed}/{checks_total} checks passed")
    if checks_passed == checks_total:
        print("*** ALL CHECKS PASSED — staggered-death fix is working correctly. ***")
        return True
    else:
        print("!!! SOME CHECKS FAILED — see details above. !!!")
        return False


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # CLI entry point: parse overrides, run the test, exit 0 on success.
    cli = argparse.ArgumentParser(description="Staggered-death integration test")
    cli.add_argument(
        "--port",
        type=int,
        default=GRPC_PORT,
        help=f"gRPC port Unreal is listening on (default: {GRPC_PORT})",
    )
    cli.add_argument(
        "--iterations",
        type=int,
        default=NUM_ITERATIONS,
        help=f"Number of RLlib training iterations (default: {NUM_ITERATIONS})",
    )
    options = cli.parse_args()

    success = run(port=options.port, num_iterations=options.iterations)
    sys.exit(0 if success else 1)
Loading