Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/openpi/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,10 @@ def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseM

def load_pytorch(self, train_config, weight_path: str):
    """Build the PyTorch PI0 model and load its weights from a safetensors file.

    Args:
        train_config: Training configuration; its ``pytorch_training_precision``
            is used as the dtype of the config handed to the PyTorch model.
        weight_path: Path to the safetensors checkpoint to load.

    Returns:
        The ``PI0Pytorch`` instance after ``safetensors.torch.load_model`` has
        been applied to it.

    Raises:
        ValueError: If this config's ``model_type`` is neither PI0 nor PI05.
    """
    logger.info(f"train_config: {train_config}")

    # Validate before constructing anything: building the model is expensive and
    # must not happen for model types without a PyTorch implementation.
    if self.model_type not in (ModelType.PI0, ModelType.PI05):
        raise ValueError(f"PyTorch checkpoints are only supported for PI0/PI05 models, got {self.model_type}")

    # Work on a copy so the dtype override never leaks back into this config.
    model_config = dataclasses.replace(self, dtype=train_config.pytorch_training_precision)

    # Construct the model exactly once, from the precision-adjusted config.
    model = pi0_pytorch.PI0Pytorch(config=model_config)
    safetensors.torch.load_model(model, weight_path)
    return model

Expand Down
56 changes: 56 additions & 0 deletions src/openpi/models/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from openpi.models import pi0_fast
from openpi.shared import download
from openpi.shared import nnx_utils
from openpi.training import config as _config


def test_pi0_model():
Expand Down Expand Up @@ -75,6 +76,61 @@ def test_pi0_fast_lora_model():
assert len(lora_state_elems) > 0


@pytest.mark.parametrize(
    ("model_dtype", "training_precision"),
    [
        ("bfloat16", "float32"),
        ("float32", "bfloat16"),
    ],
)
def test_load_pytorch_uses_training_precision(monkeypatch, model_dtype, training_precision):
    """The PyTorch model config takes its dtype from ``pytorch_training_precision``.

    ``PI0Pytorch`` and ``safetensors.torch.load_model`` are replaced with
    recording stubs so no real model is built and no file is read.
    """
    seen_configs = []
    load_calls = []

    class _StubModel:
        pass

    def _record_construction(config):
        # Remember which config the loader handed to the model constructor.
        seen_configs.append(config)
        return _StubModel()

    def _record_load(model, weight_path):
        load_calls.append((model, weight_path))

    monkeypatch.setattr(_model.pi0_pytorch, "PI0Pytorch", _record_construction)
    monkeypatch.setattr(_model.safetensors.torch, "load_model", _record_load)

    model_cfg = pi0_config.Pi0Config(dtype=model_dtype)
    train_config = _config.TrainConfig(
        name="test_config",
        exp_name="test_run",
        model=model_cfg,
        pytorch_training_precision=training_precision,
    )

    weights = "dummy.safetensors"
    result = train_config.model.load_pytorch(train_config, weights)

    # The constructor sees the training precision ...
    assert seen_configs[0].dtype == training_precision
    # ... while the original model config keeps its own dtype.
    assert train_config.model.dtype == model_dtype
    # The returned object is the one the weights were loaded into.
    assert load_calls == [(result, weights)]


def test_load_pytorch_rejects_unsupported_model_type(monkeypatch):
    """``load_pytorch`` raises ``ValueError`` for a Pi0-FAST config without building a model."""

    def _must_not_run(*args, **kwargs):
        raise AssertionError("PI0Pytorch and load_model should not be called for unsupported model types")

    # Both the constructor and the weight loader are tripwires here.
    tripwires = (
        (_model.pi0_pytorch, "PI0Pytorch"),
        (_model.safetensors.torch, "load_model"),
    )
    for owner, attr_name in tripwires:
        monkeypatch.setattr(owner, attr_name, _must_not_run)

    unsupported_model = pi0_fast.Pi0FASTConfig()
    train_config = _config.TrainConfig(
        name="test_config",
        exp_name="test_run",
        model=unsupported_model,
    )

    with pytest.raises(ValueError, match="PI0/PI05"):
        train_config.model.load_pytorch(train_config, "dummy.safetensors")


@pytest.mark.manual
def test_model_restore():
key = jax.random.key(0)
Expand Down
4 changes: 2 additions & 2 deletions src/openpi/policies/policy_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def create_trained_policy(
logging.info("Loading model...")
if is_pytorch:
model = train_config.model.load_pytorch(train_config, weight_path)
model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16")
model.paligemma_with_expert.to_bfloat16_for_selected_params(train_config.pytorch_training_precision)
else:
model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
Expand All @@ -66,7 +66,7 @@ def create_trained_policy(
# Determine the device to use for PyTorch models
if is_pytorch and pytorch_device is None:
try:
import torch
import torch # noqa: PLC0415

pytorch_device = "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:
Expand Down
51 changes: 51 additions & 0 deletions src/openpi/policies/policy_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,62 @@
from openpi_client import action_chunk_broker
import pytest

from openpi.models import pi0_config
from openpi.policies import aloha_policy
from openpi.policies import policy_config as _policy_config
from openpi.training import config as _config


def test_create_trained_policy_uses_configured_pytorch_precision(monkeypatch, tmp_path):
    """``create_trained_policy`` forwards ``pytorch_training_precision`` to the loaded model.

    ``Pi0Config.load_pytorch`` is patched to return a stub model that records
    the precision, device and ``eval`` calls made on it.
    """
    seen_precisions = []
    seen_loads = []

    class _StubPaliGemma:
        def to_bfloat16_for_selected_params(self, precision):
            seen_precisions.append(precision)

    class _StubModel:
        def __init__(self):
            self.eval_called = False
            self.device = None
            self.paligemma_with_expert = _StubPaliGemma()

        def to(self, device):
            self.device = device
            return self

        def eval(self):
            self.eval_called = True

        def sample_actions(self, *args, **kwargs):
            raise AssertionError("sample_actions should not be called when constructing the policy")

    stub = _StubModel()

    def _stub_load_pytorch(self, train_config, weight_path):
        seen_loads.append((self, train_config, weight_path))
        return stub

    monkeypatch.setattr(pi0_config.Pi0Config, "load_pytorch", _stub_load_pytorch)

    # An empty safetensors file is created in the checkpoint directory before the call.
    checkpoint_dir = tmp_path / "checkpoint"
    checkpoint_dir.mkdir()
    weights_file = checkpoint_dir / "model.safetensors"
    weights_file.touch()

    train_config = _config.TrainConfig(
        name="test_config",
        exp_name="test_run",
        model=pi0_config.Pi0Config(),
        pytorch_training_precision="float32",
    )

    _policy_config.create_trained_policy(train_config, checkpoint_dir, norm_stats={}, pytorch_device="cpu")

    expected_load = (train_config.model, train_config, str(weights_file))
    assert seen_loads == [expected_load]
    assert seen_precisions == ["float32"]
    assert stub.device == "cpu"
    assert stub.eval_called


@pytest.mark.manual
def test_infer():
config = _config.get_config("pi0_aloha_sim")
Expand Down