def _engine_parallel_size(config: InternalModelConfig) -> int:
    """Return how many GPUs the inference engine spans for ``config``.

    The result is ``tensor_parallel_size * pipeline_parallel_size`` as read
    from ``config["engine_args"]``. A missing key, or a key explicitly set to
    ``None``, counts as 1.

    Args:
        config: Internal model config; only its ``engine_args`` entry is read.

    Returns:
        The product of the two parallel sizes.

    Raises:
        ValueError: If either size is not a whole number or is smaller than 1.
    """
    # `or {}` also covers an explicit `engine_args=None`.
    engine_args = config.get("engine_args") or {}
    parallel_size = 1
    for key in ("tensor_parallel_size", "pipeline_parallel_size"):
        raw = engine_args.get(key)
        if raw is None:
            continue
        # int() would silently truncate 1.5 to 1 and hide a misconfiguration.
        if isinstance(raw, float) and not raw.is_integer():
            raise ValueError(
                f"engine_args {key} must be a whole number, got {raw!r}"
            )
        try:
            size = int(raw)
        except (TypeError, ValueError, OverflowError):
            # Keep every validation failure a ValueError, whatever int() raised.
            raise ValueError(
                f"engine_args {key} must be an integer, got {raw!r}"
            ) from None
        if size < 1:
            raise ValueError(
                "engine_args tensor_parallel_size and pipeline_parallel_size must be positive"
            )
        parallel_size *= size
    return parallel_size
@@ -57,10 +70,14 @@ def validate_dedicated_config(config: InternalModelConfig) -> None: if set(trainer_gpu_ids) & set(inference_gpu_ids): raise ValueError("trainer_gpu_ids and inference_gpu_ids must not overlap") - if len(inference_gpu_ids) > 1: - raise ValueError( - "Multi-GPU inference not yet supported; inference_gpu_ids must have exactly one GPU" - ) + engine_args = config.get("engine_args", {}) + if "tensor_parallel_size" in engine_args or "pipeline_parallel_size" in engine_args: + inference_parallel_size = _engine_parallel_size(config) + if inference_parallel_size != len(inference_gpu_ids): + raise ValueError( + "Dedicated inference GPU count must match engine_args " + "tensor_parallel_size * pipeline_parallel_size" + ) if trainer_gpu_ids[0] != 0: raise ValueError( diff --git a/src/art/pipeline_trainer/yes_no_maybe_pipeline_megatron.py b/src/art/pipeline_trainer/yes_no_maybe_pipeline_megatron.py new file mode 100644 index 000000000..899f9ba87 --- /dev/null +++ b/src/art/pipeline_trainer/yes_no_maybe_pipeline_megatron.py @@ -0,0 +1,165 @@ +"""Minimal yes/no/maybe RL training example using PipelineTrainer.""" + +from __future__ import annotations + +import asyncio +from datetime import datetime +from functools import partial +from itertools import cycle, permutations +import re + +from dotenv import load_dotenv + +import art +from art.megatron import MegatronBackend +from art.pipeline_trainer import PipelineTrainer + +# Training config +BASE_MODEL = "Qwen/Qwen3.5-4B" # or Qwen/Qwen3-4B-Instruct-2507 +MODEL_NAME = "pipeline-yes-no-maybe" +PROJECT = "yes-no-maybe-pipeline" +ROLLOUTS_PER_SCENARIO = 32 +MAX_TOKENS = 5 +MAX_STEPS = 20 +EVAL_TRAJECTORY_COUNT = 30 +EVAL_EVERY_N_STEPS = 2 +PACKED_SEQUENCE_LENGTH = 1024 + + +def build_scenarios() -> list[dict]: + """Generate all scenario variations.""" + scenarios: list[dict] = [] + for prefix in ["respond", "just respond"]: + for use_quotes in [True, False]: + for n in [3, 2]: + for words in permutations(["yes", "no", 
"maybe"], n): + quoted = [f"'{w}'" if use_quotes else w for w in words] + if len(words) == 3: + body = ", ".join(quoted) + else: + body = " or ".join(quoted) + scenarios.append({"prompt": f"{prefix} with {body}"}) + return scenarios + + +def reward_for_answer(text: str) -> float: + """Score: maybe=1.0, no=0.75, yes=0.5, other=0.0.""" + if not text: + return 0.0 + first_word = re.split(r"\s+", text.strip().lower())[0].strip(".,!?:;\"'()[]{}") + return {"maybe": 1.0, "no": 0.75, "yes": 0.5}.get(first_word, 0.0) + + +async def eval_fn( + model: art.TrainableModel, + step: int, + _config: None, + *, + scenarios: list[dict], +) -> list[art.Trajectory]: + trajectories: list[art.Trajectory] = [] + openai_client = model.openai_client() + model_name = model.get_inference_name(step) + for scenario in scenarios: + messages: art.Messages = [{"role": "user", "content": scenario["prompt"]}] + response = await openai_client.chat.completions.create( + messages=messages, + model=model_name, + max_tokens=MAX_TOKENS, + n=1, + ) + choice = response.choices[0] + trajectories.append( + art.Trajectory( + messages_and_choices=[*messages, choice], + reward=reward_for_answer(choice.message.content or ""), + ) + ) + return trajectories + + +async def rollout_fn(model, scenario, _config) -> art.TrajectoryGroup: + """Single inference call returns N completions for the group.""" + messages: art.Messages = [{"role": "user", "content": scenario["prompt"]}] + response = await model.openai_client().chat.completions.create( + messages=messages, + model=model.get_inference_name(), + max_tokens=MAX_TOKENS, + n=ROLLOUTS_PER_SCENARIO, + ) + return art.TrajectoryGroup( + [ + art.Trajectory( + messages_and_choices=[*messages, choice], + reward=reward_for_answer(choice.message.content or ""), + ) + for choice in response.choices + ] + ) + + +async def main() -> None: + load_dotenv() + + model_name = f"{MODEL_NAME}-{datetime.now().strftime('%Y%m%d-%H%M%S')}" + + print("Initializing MegatronBackend") + 
backend = MegatronBackend() + + print(f"Initializing TrainableModel: {model_name}") + model = art.TrainableModel( + name=model_name, + project=PROJECT, + base_model=BASE_MODEL, + _internal_config=art.dev.InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1, 2], + rollout_weights_mode="merged", + engine_args=art.dev.EngineArgs( + max_model_len=PACKED_SEQUENCE_LENGTH, + enforce_eager=True, + ), + init_args=art.dev.InitArgs( + max_seq_length=PACKED_SEQUENCE_LENGTH, + load_in_4bit=False, + load_in_16bit=True, + ), + ), + ) + + print("Registering model with backend") + await model.register(backend) + print("Model registered") + + base_scenarios = build_scenarios() + scenarios = cycle(base_scenarios) + eval_scenarios = base_scenarios[:EVAL_TRAJECTORY_COUNT] + + eval_callback = partial(eval_fn, scenarios=eval_scenarios) + + trainer = PipelineTrainer( + model=model, + backend=backend, + rollout_fn=rollout_fn, + scenarios=scenarios, + config=None, + learning_rate=5e-5, + loss_fn="cispo", + eval_fn=eval_callback, + packed_sequence_length=PACKED_SEQUENCE_LENGTH, + max_steps=MAX_STEPS, + eval_every_n_steps=EVAL_EVERY_N_STEPS, + eval_at_start=False, + total_scenarios=None, + ) + + print( + f"Training {model_name}: {MAX_STEPS} steps, " + f"{len(base_scenarios)} unique scenarios (cycling)" + ) + await trainer.train() + await backend.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/unit/test_dedicated_config.py b/tests/unit/test_dedicated_config.py index adb3cbe72..fa1da603b 100644 --- a/tests/unit/test_dedicated_config.py +++ b/tests/unit/test_dedicated_config.py @@ -78,10 +78,36 @@ def test_overlapping_gpu_ids(): def test_multi_gpu_inference(): - with pytest.raises(ValueError, match="Multi-GPU inference not yet supported"): + validate_dedicated_config( + InternalModelConfig(trainer_gpu_ids=[0], inference_gpu_ids=[1, 2]) + ) + + +def test_three_gpu_inference(): + validate_dedicated_config( + InternalModelConfig(trainer_gpu_ids=[0], 
def test_dedicated_inference_parallel_size_must_match_gpu_count():
    """An explicit parallel size that differs from the inference GPU count is rejected."""
    # Two inference GPUs, but the engine is explicitly set to a parallel size of 1.
    mismatched_config = InternalModelConfig(
        trainer_gpu_ids=[0],
        inference_gpu_ids=[1, 2],
        engine_args={"tensor_parallel_size": 1},  # type: ignore[typeddict-item]
    )
    with pytest.raises(ValueError, match="GPU count must match"):
        validate_dedicated_config(mismatched_config)