Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions packages/prime/src/prime_cli/api/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ def deploy_adapter(self, adapter_id: str) -> Adapter:
raise APIError(f"Failed to deploy adapter: {e.response.text}")
raise APIError(f"Failed to deploy adapter: {str(e)}")

def deploy_checkpoint(self, checkpoint_id: str) -> Adapter:
    """Deploy a checkpoint by preparing it as an adapter for inference.

    Args:
        checkpoint_id: Checkpoint ID interpolated into the
            ``/rft/checkpoints/{checkpoint_id}/deploy`` endpoint.

    Returns:
        The ``Adapter`` validated from the ``"adapter"`` entry of the response.

    Raises:
        APIError: If the POST or the response validation fails. The
            triggering exception is chained as ``__cause__``.
    """
    try:
        response = self.client.post(f"/rft/checkpoints/{checkpoint_id}/deploy")
        return Adapter.model_validate(response.get("adapter"))
    except Exception as e:
        # Prefer the response body when the exception carries one; it is
        # usually more informative than the exception's own message.
        if hasattr(e, "response") and hasattr(e.response, "text"):
            raise APIError(f"Failed to deploy checkpoint: {e.response.text}") from e
        raise APIError(f"Failed to deploy checkpoint: {e}") from e

def unload_adapter(self, adapter_id: str) -> Adapter:
"""Unload an adapter from inference."""
try:
Expand Down
73 changes: 68 additions & 5 deletions packages/prime/src/prime_cli/commands/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,24 @@ def _print_inference_usage(base_model: str, adapter_id: str) -> None:
)


def _print_deployment_followup(deployment_status: str) -> None:
    """Print a dimmed follow-up note for a deployment status.

    Output is a blank line, one status-specific line (a dedicated message for
    "DEPLOYED" and "DEPLOYING", otherwise a generic line echoing the status),
    and a hint pointing at 'prime deployments list'.
    """
    # NOTE(review): the source this was reconstructed from had lost its
    # indentation; the closing hint is assumed to be printed for every status
    # rather than only in the fallback branch -- confirm.
    if deployment_status == "DEPLOYED":
        detail = "[dim]The model is deployed and ready for inference.[/dim]"
    elif deployment_status == "DEPLOYING":
        detail = "[dim]The model is being deployed. This may take a few minutes.[/dim]"
    else:
        detail = f"[dim]Deployment status: {deployment_status}[/dim]"

    console.print()
    console.print(detail)
    console.print("[dim]Use 'prime deployments list' to check deployment status.[/dim]")


def _print_deployment_success(deployment_status: str) -> None:
    """Print the green success headline for a deployment request.

    A status of "DEPLOYED" is reported as ready; any other status is reported
    as a successfully initiated deployment.
    """
    headline = (
        "[green]Deployment is ready![/green]"
        if deployment_status == "DEPLOYED"
        else "[green]Deployment initiated successfully![/green]"
    )
    console.print(headline)


@app.command(name="list", epilog=LIST_DEPLOYMENTS_JSON_HELP)
def list_deployments(
team: Optional[str] = typer.Option(None, "--team", "-t", help="Filter by team ID"),
Expand Down Expand Up @@ -172,27 +190,73 @@ def list_deployments(
def create_deployment(
ctx: typer.Context,
model_id: Optional[str] = typer.Argument(None, help="Model ID to deploy"),
checkpoint_id: Optional[str] = typer.Option(
None,
"--checkpoint-id",
help="Deploy a Hosted Training checkpoint by checkpoint ID",
),
yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
) -> None:
"""Deploy a model for inference.

Makes the trained model available for inference requests.
Model must be in READY status.
Model must be in READY status. To deploy a checkpoint, pass --checkpoint-id.

Example:

prime deployments create <model_id>

prime deployments create <model_id> --yes

prime deployments create --checkpoint-id <checkpoint_id>
"""
if model_id is None:
if model_id is not None:
model_id = model_id.strip()
if not model_id:
console.print("[red]Error:[/red] MODEL_ID cannot be empty.")
raise typer.Exit(1)

if checkpoint_id is not None:
checkpoint_id = checkpoint_id.strip()
if not checkpoint_id:
console.print("[red]Error:[/red] --checkpoint-id cannot be empty.")
raise typer.Exit(1)

if model_id is not None and checkpoint_id is not None:
console.print("[red]Error:[/red] Use either MODEL_ID or --checkpoint-id, not both.")
raise typer.Exit(1)

if model_id is None and checkpoint_id is None:
console.print(ctx.get_help())
raise typer.Exit(0)

try:
api_client = APIClient()
deployments_client = DeploymentsClient(api_client)

if checkpoint_id is not None:
console.print("[bold]Deploying checkpoint:[/bold]")
console.print(f" Checkpoint ID: {checkpoint_id}")
console.print()

if not yes:
confirm = typer.confirm("Are you sure you want to deploy this checkpoint?")
if not confirm:
console.print("Cancelled.")
raise typer.Exit(0)

adapter = deployments_client.deploy_checkpoint(checkpoint_id)

_print_deployment_success(adapter.deployment_status)
console.print(f"Adapter ID: [cyan]{adapter.id}[/cyan]")
console.print(f"Status: [yellow]{adapter.deployment_status}[/yellow]")
_print_deployment_followup(adapter.deployment_status)

_print_inference_usage(adapter.base_model, adapter.id)
return

assert model_id is not None
Comment thread
kevinjosethomas marked this conversation as resolved.

# Get model to validate status
model = deployments_client.get_adapter(model_id)

Expand Down Expand Up @@ -242,10 +306,9 @@ def create_deployment(
# Deploy the model
updated_model = deployments_client.deploy_adapter(model_id)

console.print("[green]Deployment initiated successfully![/green]")
_print_deployment_success(updated_model.deployment_status)
console.print(f"Status: [yellow]{updated_model.deployment_status}[/yellow]")
console.print("\n[dim]The model is being deployed. This may take a few minutes.[/dim]")
console.print("[dim]Use 'prime deployments list' to check deployment status.[/dim]")
_print_deployment_followup(updated_model.deployment_status)

_print_inference_usage(model.base_model, model.id)

Expand Down
169 changes: 168 additions & 1 deletion packages/prime/tests/test_deployments.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,42 @@
from types import SimpleNamespace
from typing import Any
from typing import Any, cast

from prime_cli.api.deployments import DeploymentsClient
from prime_cli.client import APIError
from prime_cli.main import app
from prime_cli.utils import strip_ansi
from typer.testing import CliRunner

runner = CliRunner()

TEST_ENV = {"PRIME_API_KEY": "dummy", "PRIME_DISABLE_VERSION_CHECK": "1", "COLUMNS": "200"}


def _adapter_response(
*,
adapter_id: str = "adapter-123",
base_model: str = "meta-llama/Llama-3.1-8B-Instruct",
deployment_status: str = "DEPLOYING",
) -> dict[str, Any]:
return {
"adapter": {
"id": adapter_id,
"displayName": "Checkpoint Adapter",
"userId": "user-123",
"teamId": None,
"rftRunId": "run-123",
"baseModel": base_model,
"step": 20,
"status": "READY",
"deploymentStatus": deployment_status,
"deployedAt": None,
"deploymentError": None,
"createdAt": "2026-01-01T00:00:00Z",
"updatedAt": "2026-01-01T00:00:00Z",
},
"message": "Checkpoint adapter deployment started",
}


def test_deployments_create_prints_chat_and_api_key_commands(monkeypatch) -> None:
monkeypatch.setenv("PRIME_API_KEY", "dummy")
Expand Down Expand Up @@ -57,3 +87,140 @@ def deploy_adapter(self, model_id: str) -> Any:
assert "export PRIME_API_KEY=<insert_key_here>" in output
assert "PRIME_API_KEY" in output
assert "curl -X POST" in output


def test_deployments_client_deploy_checkpoint_posts_endpoint() -> None:
    """deploy_checkpoint POSTs to the checkpoint deploy endpoint without a JSON body."""
    calls: list[tuple[str, dict[str, Any] | None]] = []

    class RecordingAPIClient:
        def post(self, endpoint: str, json: dict[str, Any] | None = None) -> dict:
            # Remember what the client under test sent.
            calls.append((endpoint, json))
            return _adapter_response()

    deployments_client = DeploymentsClient(cast(Any, RecordingAPIClient()))
    adapter = deployments_client.deploy_checkpoint("ckpt-123")

    endpoint, payload = calls[-1]
    assert endpoint == "/rft/checkpoints/ckpt-123/deploy"
    assert payload is None
    assert adapter.id == "adapter-123"


def test_deployments_create_checkpoint_prints_adapter_result(monkeypatch) -> None:
    """`deployments create --checkpoint-id` prints adapter details for a DEPLOYING result."""
    monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None))

    deploying_adapter = SimpleNamespace(
        id="adapter-456",
        base_model="Qwen/Qwen3.5-0.8B",
        deployment_status="DEPLOYING",
    )

    class StubDeploymentsClient:
        def __init__(self, api_client: Any) -> None:
            self.api_client = api_client

        def deploy_checkpoint(self, checkpoint_id: str) -> Any:
            assert checkpoint_id == "ckpt-456"
            return deploying_adapter

    # Swap the real API plumbing for in-memory stand-ins.
    patches = (
        ("prime_cli.commands.deployments.APIClient", lambda: object()),
        ("prime_cli.commands.deployments.DeploymentsClient", StubDeploymentsClient),
    )
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    cli_args = ["deployments", "create", "--checkpoint-id", "ckpt-456", "--yes"]
    result = runner.invoke(app, cli_args, env=TEST_ENV)
    output = strip_ansi(result.output)

    assert result.exit_code == 0, result.output
    expected_fragments = (
        "Deploying checkpoint:",
        "Checkpoint ID: ckpt-456",
        "Deployment initiated successfully!",
        "Adapter ID: adapter-456",
        "Status: DEPLOYING",
        '"Qwen/Qwen3.5-0.8B:adapter-456"',
        "prime deployments list",
    )
    for fragment in expected_fragments:
        assert fragment in output


def test_deployments_create_checkpoint_reports_already_deployed_status(monkeypatch) -> None:
    """A DEPLOYED result is reported as ready, not as a newly initiated deployment."""
    monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None))

    deployed_adapter = SimpleNamespace(
        id="adapter-deployed",
        base_model="Qwen/Qwen3.5-4B",
        deployment_status="DEPLOYED",
    )

    class StubDeploymentsClient:
        def __init__(self, api_client: Any) -> None:
            self.api_client = api_client

        def deploy_checkpoint(self, checkpoint_id: str) -> Any:
            assert checkpoint_id == "ckpt-deployed"
            return deployed_adapter

    # Swap the real API plumbing for in-memory stand-ins.
    patches = (
        ("prime_cli.commands.deployments.APIClient", lambda: object()),
        ("prime_cli.commands.deployments.DeploymentsClient", StubDeploymentsClient),
    )
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    cli_args = ["deployments", "create", "--checkpoint-id", "ckpt-deployed", "--yes"]
    result = runner.invoke(app, cli_args, env=TEST_ENV)
    output = strip_ansi(result.output)

    assert result.exit_code == 0, result.output
    # Each pair is (fragment, whether it must appear in the output).
    expectations = (
        ("Deployment is ready!", True),
        ("Deployment initiated successfully!", False),
        ("Status: DEPLOYED", True),
        ("The model is deployed and ready for inference.", True),
        ("The model is being deployed.", False),
    )
    for fragment, should_appear in expectations:
        assert (fragment in output) is should_appear


def test_deployments_create_checkpoint_rejects_empty_checkpoint_id(monkeypatch) -> None:
    """An empty --checkpoint-id value exits with code 1 and an explanatory error."""
    monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None))

    cli_args = ["deployments", "create", "--checkpoint-id", "", "--yes"]
    result = runner.invoke(app, cli_args, env=TEST_ENV)
    plain_output = strip_ansi(result.output)

    assert result.exit_code == 1
    assert "Error: --checkpoint-id cannot be empty." in plain_output


def test_deployments_create_checkpoint_surfaces_conflict_errors(monkeypatch) -> None:
    """An APIError raised while deploying is printed and yields exit code 1."""
    monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None))

    class FailingDeploymentsClient:
        def __init__(self, api_client: Any) -> None:
            self.api_client = api_client

        def deploy_checkpoint(self, checkpoint_id: str) -> Any:
            assert checkpoint_id == "ckpt-busy"
            raise APIError("HTTP 409: Checkpoint adapter preparation is already in progress")

    # Swap the real API plumbing for in-memory stand-ins.
    patches = (
        ("prime_cli.commands.deployments.APIClient", lambda: object()),
        ("prime_cli.commands.deployments.DeploymentsClient", FailingDeploymentsClient),
    )
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    cli_args = ["deployments", "create", "--checkpoint-id", "ckpt-busy", "--yes"]
    result = runner.invoke(app, cli_args, env=TEST_ENV)
    output = strip_ansi(result.output)

    assert result.exit_code == 1
    for fragment in (
        "Error: HTTP 409",
        "Checkpoint adapter preparation is already in progress",
    ):
        assert fragment in output
Loading