99import pytest
1010import torch
1111
12- from tensordict import tensorclass , TensorDict , TensorDictBase
12+ from tensordict import (
13+ from_csv ,
14+ from_json ,
15+ from_pandas ,
16+ from_parquet ,
17+ tensorclass ,
18+ TensorDict ,
19+ TensorDictBase ,
20+ )
21+ from tensordict .tensorclass import NonTensorData
1322
1423_has_pandas = importlib .util .find_spec ("pandas" ) is not None
1524_has_pyarrow = importlib .util .find_spec ("pyarrow" ) is not None
@@ -42,7 +51,6 @@ def test_string_columns(self):
4251 assert td .batch_size == torch .Size ([3 ])
4352 assert td ["age" ].dtype == torch .int64
4453 name_val = td ["name" ]
45- from tensordict .tensorclass import NonTensorData
4654
4755 assert isinstance (name_val , NonTensorData ) or hasattr (name_val , "tolist" )
4856
@@ -133,8 +141,6 @@ def test_categorical_columns(self):
133141 def test_module_level_function (self ):
134142 import pandas as pd
135143
136- from tensordict import from_pandas
137-
138144 df = pd .DataFrame ({"x" : [1 , 2 , 3 ]})
139145 td = from_pandas (df )
140146 assert td .batch_size == torch .Size ([3 ])
@@ -245,8 +251,6 @@ def test_csv_roundtrip(self, tmp_path):
245251 def test_module_level_function (self , tmp_path ):
246252 import pandas as pd
247253
248- from tensordict import from_csv
249-
250254 csv_path = tmp_path / "test.csv"
251255 pd .DataFrame ({"x" : [1 , 2 ]}).to_csv (csv_path , index = False )
252256 td = from_csv (csv_path )
@@ -307,8 +311,6 @@ def test_module_level_function(self, tmp_path):
307311 import pyarrow as pa
308312 import pyarrow .parquet as pq
309313
310- from tensordict import from_parquet
311-
312314 path = tmp_path / "test.parquet"
313315 table = pa .table ({"x" : [1 , 2 ]})
314316 pq .write_table (table , str (path ))
@@ -349,8 +351,6 @@ def test_to_json_lines(self, tmp_path):
349351 assert len (lines ) == 3
350352
351353 def test_module_level_function (self , tmp_path ):
352- from tensordict import from_json
353-
354354 path = tmp_path / "test.json"
355355 path .write_text (json .dumps ([{"a" : 1 }, {"a" : 2 }]))
356356 td = from_json (path )
@@ -410,7 +410,7 @@ def test_csv_roundtrip(self, tmp_path):
410410 tc2 = TabularTensorClass .from_csv (csv_path )
411411 assert isinstance (tc2 , TabularTensorClass )
412412 assert (tc2 .x == tc .x ).all ()
413- assert torch .allclose (tc2 .y , tc .y )
413+ assert torch .allclose (tc2 .y . to ( tc . y . dtype ) , tc .y )
414414
415415 def test_json_roundtrip (self , tmp_path ):
416416 path = tmp_path / "tensorclass.json"
@@ -423,7 +423,7 @@ def test_json_roundtrip(self, tmp_path):
423423 tc2 = TabularTensorClass .from_json (path )
424424 assert isinstance (tc2 , TabularTensorClass )
425425 assert (tc2 .x == tc .x ).all ()
426- assert torch .allclose (tc2 .y , tc .y )
426+ assert torch .allclose (tc2 .y . to ( tc . y . dtype ) , tc .y )
427427
428428 @pytest .mark .skipif (not _has_pyarrow , reason = "pyarrow not found" )
429429 def test_parquet_roundtrip (self , tmp_path ):
@@ -437,4 +437,4 @@ def test_parquet_roundtrip(self, tmp_path):
437437 tc2 = TabularTensorClass .from_parquet (path )
438438 assert isinstance (tc2 , TabularTensorClass )
439439 assert (tc2 .x == tc .x ).all ()
440- assert torch .allclose (tc2 .y , tc .y )
440+ assert torch .allclose (tc2 .y . to ( tc . y . dtype ) , tc .y )
0 commit comments