Skip to content

Commit d871d22

Browse files
committed
Fix tabular imports and CI deps
1 parent 56216f6 commit d871d22

5 files changed

Lines changed: 32 additions & 42 deletions

File tree

.github/unittest/linux/scripts/environment.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,4 +22,6 @@ dependencies:
2222
- ninja
2323
- numpy<2.0.0
2424
- mosaicml-streaming
25+
- pandas
26+
- pyarrow
2527
- redis

.github/unittest/linux/scripts/install.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ printf "* Installing tensordict\n"
5353
# then install tensordict without resolving dependencies to avoid any solver changing
5454
# the PyTorch build (stable vs nightly).
5555
python -m pip install -U packaging pyvers importlib_metadata
56-
python -m pip install redis
56+
python -m pip install redis pandas pyarrow
5757
python -m pip install -e . --no-deps
5858

5959
# smoke test

tensordict/base.py

Lines changed: 11 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,17 @@
5757
from tensordict._nestedkey import NestedKey
5858
from tensordict._tensorcollection import TensorCollection
5959
from tensordict.memmap import MemoryMappedTensor
60+
from tensordict.tabular import (
61+
_columns_to_tensordict,
62+
_dataframe_to_tensordict,
63+
_read_csv,
64+
_read_json,
65+
_read_parquet,
66+
_tensordict_to_dataframe,
67+
_write_csv,
68+
_write_json,
69+
_write_parquet,
70+
)
6071
from tensordict.utils import (
6172
_as_context_manager,
6273
_CloudpickleWrapper,
@@ -15183,8 +15194,6 @@ def from_pandas(
1518315194
device=None,
1518415195
is_shared=False)
1518515196
"""
15186-
from tensordict.tabular import _dataframe_to_tensordict
15187-
1518815197
if cls is TensorDictBase:
1518915198
from tensordict._td import TensorDict
1519015199

@@ -15230,8 +15239,6 @@ def to_pandas(self, *, separator: str | None = None):
1523015239
1 1 0.0
1523115240
2 2 0.0
1523215241
"""
15233-
from tensordict.tabular import _tensordict_to_dataframe
15234-
1523515242
return _tensordict_to_dataframe(self, separator=separator)
1523615243

1523715244
@classmethod
@@ -15278,8 +15285,6 @@ def from_csv(
1527815285
>>> td = TensorDict.from_csv("data.csv")
1527915286
>>> td = TensorDict.from_csv("data.csv", separator=".", dtype=torch.float32)
1528015287
"""
15281-
from tensordict.tabular import _columns_to_tensordict, _read_csv
15282-
1528315288
if cls is TensorDictBase:
1528415289
from tensordict._td import TensorDict
1528515290

@@ -15315,8 +15320,6 @@ def to_csv(self, path, *, separator: str | None = None, **kwargs):
1531515320
**kwargs: Additional keyword arguments forwarded to
1531615321
``pandas.DataFrame.to_csv``.
1531715322
"""
15318-
from tensordict.tabular import _write_csv
15319-
1532015323
_write_csv(self, path, separator=separator, **kwargs)
1532115324

1532215325
@classmethod
@@ -15367,8 +15370,6 @@ def from_parquet(
1536715370
>>> td = TensorDict.from_parquet("data.parquet")
1536815371
>>> td = TensorDict.from_parquet("data.parquet", columns=["obs", "reward"])
1536915372
"""
15370-
from tensordict.tabular import _columns_to_tensordict, _read_parquet
15371-
1537215373
if cls is TensorDictBase:
1537315374
from tensordict._td import TensorDict
1537415375

@@ -15404,8 +15405,6 @@ def to_parquet(self, path, *, separator: str | None = None, **kwargs):
1540415405
**kwargs: Additional keyword arguments forwarded to the Parquet
1540515406
writer.
1540615407
"""
15407-
from tensordict.tabular import _write_parquet
15408-
1540915408
_write_parquet(self, path, separator=separator, **kwargs)
1541015409

1541115410
@classmethod
@@ -15459,8 +15458,6 @@ def from_json(
1545915458
>>> td = TensorDict.from_json("data.json")
1546015459
>>> td = TensorDict.from_json("data.jsonl", lines=True)
1546115460
"""
15462-
from tensordict.tabular import _columns_to_tensordict, _read_json
15463-
1546415461
if cls is TensorDictBase:
1546515462
from tensordict._td import TensorDict
1546615463

@@ -15503,8 +15500,6 @@ def to_json(
1550315500
**kwargs: Additional keyword arguments forwarded to the JSON
1550415501
writer.
1550515502
"""
15506-
from tensordict.tabular import _write_json
15507-
1550815503
_write_json(self, path, separator=separator, lines=lines, **kwargs)
1550915504

1551015505
def to_h5(

tensordict/tabular.py

Lines changed: 5 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -8,22 +8,21 @@
88

99
from __future__ import annotations
1010

11+
import importlib.util
1112
from pathlib import Path
1213
from typing import Any
1314

1415
import numpy as np
1516
import torch
17+
from tensordict._tensorcollection import TensorCollection
18+
from tensordict.utils import is_non_tensor
1619

1720

1821
def _has_pandas() -> bool:
19-
import importlib.util
20-
2122
return importlib.util.find_spec("pandas") is not None
2223

2324

2425
def _has_pyarrow() -> bool:
25-
import importlib.util
26-
2726
return importlib.util.find_spec("pyarrow") is not None
2827

2928

@@ -51,11 +50,9 @@ def _unflatten_columns(flat_dict: dict, separator: str) -> dict:
5150

5251
def _flatten_keys(td, separator: str) -> dict[str, Any]:
5352
"""Flatten a TensorDict into a dict with separated key names."""
54-
from tensordict.base import _is_tensor_collection, is_non_tensor
55-
5653
result = {}
5754
for key, value in td.items():
58-
if _is_tensor_collection(type(value)) and not is_non_tensor(value):
55+
if isinstance(value, TensorCollection) and not is_non_tensor(value):
5956
sub = _flatten_keys(value, separator)
6057
for sub_key, sub_val in sub.items():
6158
result[f"{key}{separator}{sub_key}"] = sub_val
@@ -132,14 +129,12 @@ def _tensordict_to_dataframe(td, *, separator: str | None):
132129
"""Convert a TensorDict to a pandas DataFrame."""
133130
import pandas as pd
134131

135-
from tensordict.base import _is_tensor_collection, is_non_tensor
136-
137132
if separator is not None:
138133
flat = _flatten_keys(td, separator)
139134
else:
140135
flat = {}
141136
for key, value in td.items():
142-
if _is_tensor_collection(type(value)) and not is_non_tensor(value):
137+
if isinstance(value, TensorCollection) and not is_non_tensor(value):
143138
raise ValueError(
144139
f"Nested TensorDict at key '{key}' requires a separator parameter "
145140
"to flatten to DataFrame columns. Use to_pandas(separator='.')."
@@ -319,8 +314,6 @@ def _write_json(td, path, separator: str | None, lines: bool = False, **kwargs):
319314
else:
320315
import json
321316

322-
from tensordict.base import is_non_tensor
323-
324317
if separator is not None:
325318
flat = _flatten_keys(td, separator)
326319
else:

test/test_tabular.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,16 @@
99
import pytest
1010
import torch
1111

12-
from tensordict import tensorclass, TensorDict, TensorDictBase
12+
from tensordict import (
13+
from_csv,
14+
from_json,
15+
from_pandas,
16+
from_parquet,
17+
tensorclass,
18+
TensorDict,
19+
TensorDictBase,
20+
)
21+
from tensordict.tensorclass import NonTensorData
1322

1423
_has_pandas = importlib.util.find_spec("pandas") is not None
1524
_has_pyarrow = importlib.util.find_spec("pyarrow") is not None
@@ -42,7 +51,6 @@ def test_string_columns(self):
4251
assert td.batch_size == torch.Size([3])
4352
assert td["age"].dtype == torch.int64
4453
name_val = td["name"]
45-
from tensordict.tensorclass import NonTensorData
4654

4755
assert isinstance(name_val, NonTensorData) or hasattr(name_val, "tolist")
4856

@@ -133,8 +141,6 @@ def test_categorical_columns(self):
133141
def test_module_level_function(self):
134142
import pandas as pd
135143

136-
from tensordict import from_pandas
137-
138144
df = pd.DataFrame({"x": [1, 2, 3]})
139145
td = from_pandas(df)
140146
assert td.batch_size == torch.Size([3])
@@ -245,8 +251,6 @@ def test_csv_roundtrip(self, tmp_path):
245251
def test_module_level_function(self, tmp_path):
246252
import pandas as pd
247253

248-
from tensordict import from_csv
249-
250254
csv_path = tmp_path / "test.csv"
251255
pd.DataFrame({"x": [1, 2]}).to_csv(csv_path, index=False)
252256
td = from_csv(csv_path)
@@ -307,8 +311,6 @@ def test_module_level_function(self, tmp_path):
307311
import pyarrow as pa
308312
import pyarrow.parquet as pq
309313

310-
from tensordict import from_parquet
311-
312314
path = tmp_path / "test.parquet"
313315
table = pa.table({"x": [1, 2]})
314316
pq.write_table(table, str(path))
@@ -349,8 +351,6 @@ def test_to_json_lines(self, tmp_path):
349351
assert len(lines) == 3
350352

351353
def test_module_level_function(self, tmp_path):
352-
from tensordict import from_json
353-
354354
path = tmp_path / "test.json"
355355
path.write_text(json.dumps([{"a": 1}, {"a": 2}]))
356356
td = from_json(path)
@@ -410,7 +410,7 @@ def test_csv_roundtrip(self, tmp_path):
410410
tc2 = TabularTensorClass.from_csv(csv_path)
411411
assert isinstance(tc2, TabularTensorClass)
412412
assert (tc2.x == tc.x).all()
413-
assert torch.allclose(tc2.y, tc.y)
413+
assert torch.allclose(tc2.y.to(tc.y.dtype), tc.y)
414414

415415
def test_json_roundtrip(self, tmp_path):
416416
path = tmp_path / "tensorclass.json"
@@ -423,7 +423,7 @@ def test_json_roundtrip(self, tmp_path):
423423
tc2 = TabularTensorClass.from_json(path)
424424
assert isinstance(tc2, TabularTensorClass)
425425
assert (tc2.x == tc.x).all()
426-
assert torch.allclose(tc2.y, tc.y)
426+
assert torch.allclose(tc2.y.to(tc.y.dtype), tc.y)
427427

428428
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not found")
429429
def test_parquet_roundtrip(self, tmp_path):
@@ -437,4 +437,4 @@ def test_parquet_roundtrip(self, tmp_path):
437437
tc2 = TabularTensorClass.from_parquet(path)
438438
assert isinstance(tc2, TabularTensorClass)
439439
assert (tc2.x == tc.x).all()
440-
assert torch.allclose(tc2.y, tc.y)
440+
assert torch.allclose(tc2.y.to(tc.y.dtype), tc.y)

0 commit comments

Comments
 (0)