Skip to content

Commit 5729045

Browse files
az0uzAntoine de Maleprade
andauthored
[BugFix] ONNX Export Syntax in Tests and Add Optional Dependencies (#1478)
Co-authored-by: Antoine de Maleprade <antoine.demaleprade@helsing.ai>
1 parent 8d09fa1 commit 5729045

2 files changed

Lines changed: 6 additions & 13 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ tests = [
5858
h5 = ["h5py>=3.8"]
5959
dev = ["pybind11", "ninja"]
6060
typecheck = ["mypy>=1.0.0"]
61+
onnx = ["onnx", "onnxscript", "onnxruntime"]
6162

6263
[tool.setuptools]
6364
include-package-data = false

test/test_compile.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -950,9 +950,7 @@ def test_onnx_export_module(self, tmpdir):
950950
x = torch.randn(3)
951951
y = torch.randn(3)
952952
torch_input = {"x": x, "y": y}
953-
onnx_program = torch.onnx.dynamo_export(tdm, **torch_input)
954-
955-
onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input)
953+
onnx_program = torch.onnx.export(tdm, kwargs=torch_input, dynamo=True)
956954

957955
path = Path(tmpdir) / "file.onnx"
958956
onnx_program.save(str(path))
@@ -969,9 +967,7 @@ def to_numpy(tensor):
969967
else tensor.cpu().numpy()
970968
)
971969

972-
onnxruntime_input = {
973-
k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
974-
}
970+
onnxruntime_input = {k: to_numpy(v) for k, v in torch_input.items()}
975971

976972
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
977973
torch.testing.assert_close(
@@ -986,10 +982,8 @@ def test_onnx_export_seq(self, tmpdir):
986982
x = torch.randn(3)
987983
y = torch.randn(3)
988984
torch_input = {"x": x, "y": y}
989-
torch.onnx.dynamo_export(tdm, x=x, y=y)
990-
onnx_program = torch.onnx.dynamo_export(tdm, **torch_input)
991-
992-
onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input)
985+
torch.onnx.export(tdm, kwargs=torch_input, dynamo=True)
986+
onnx_program = torch.onnx.export(tdm, kwargs=torch_input, dynamo=True)
993987

994988
path = Path(tmpdir) / "file.onnx"
995989
onnx_program.save(str(path))
@@ -1006,9 +1000,7 @@ def to_numpy(tensor):
10061000
else tensor.cpu().numpy()
10071001
)
10081002

1009-
onnxruntime_input = {
1010-
k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
1011-
}
1003+
onnxruntime_input = {k: to_numpy(v) for k, v in torch_input.items()}
10121004

10131005
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
10141006
torch.testing.assert_close(

0 commit comments

Comments
 (0)