Skip to content

Commit 25b99b1

Browse files
committed
Split base factory helpers
1 parent 8c46f0c commit 25b99b1

4 files changed

Lines changed: 456 additions & 326 deletions

File tree

tensordict/_base/__init__.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""Internal implementation modules for :mod:`tensordict.base`."""
7+
8+
from tensordict._base.factories import (
9+
from_any,
10+
from_csv,
11+
from_dict,
12+
from_h5,
13+
from_json,
14+
from_list,
15+
from_namedtuple,
16+
from_pandas,
17+
from_parquet,
18+
from_struct_array,
19+
from_tuple,
20+
)
21+
22+
__all__ = [
23+
"from_any",
24+
"from_csv",
25+
"from_dict",
26+
"from_h5",
27+
"from_json",
28+
"from_list",
29+
"from_namedtuple",
30+
"from_pandas",
31+
"from_parquet",
32+
"from_struct_array",
33+
"from_tuple",
34+
]

tensordict/_base/factories.py

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from __future__ import annotations
7+
8+
from collections.abc import Mapping
9+
from typing import Type
10+
11+
import numpy as np
12+
import torch
13+
from tensordict._tensorcollection import TensorCollection
14+
from tensordict.base import TensorDictBase
15+
16+
__all__ = [
17+
"from_any",
18+
"from_csv",
19+
"from_dict",
20+
"from_h5",
21+
"from_json",
22+
"from_list",
23+
"from_namedtuple",
24+
"from_pandas",
25+
"from_parquet",
26+
"from_struct_array",
27+
"from_tuple",
28+
]
29+
30+
31+
def from_any(
32+
obj,
33+
*,
34+
auto_batch_size: bool = False,
35+
batch_dims: int | None = None,
36+
device: torch.device | None = None,
37+
batch_size: torch.Size | None = None,
38+
):
39+
"""Converts any object to a TensorDict.
40+
41+
.. seealso:: :meth:`~tensordict.TensorDictBase.from_any` for more information.
42+
"""
43+
return TensorDictBase.from_any(
44+
obj,
45+
auto_batch_size=auto_batch_size,
46+
batch_dims=batch_dims,
47+
device=device,
48+
batch_size=batch_size,
49+
)
50+
51+
52+
def from_tuple(
53+
obj,
54+
*,
55+
auto_batch_size: bool = False,
56+
batch_dims: int | None = None,
57+
device: torch.device | None = None,
58+
batch_size: torch.Size | None = None,
59+
) -> "TensorDictBase":
60+
"""Converts a tuple to a TensorDict.
61+
62+
.. seealso:: :meth:`TensorDictBase.from_tuple` for more information.
63+
"""
64+
return TensorDictBase.from_tuple(
65+
obj,
66+
auto_batch_size=auto_batch_size,
67+
batch_dims=batch_dims,
68+
device=device,
69+
batch_size=batch_size,
70+
)
71+
72+
73+
def from_namedtuple(
74+
named_tuple,
75+
*,
76+
auto_batch_size: bool = False,
77+
batch_dims: int | None = None,
78+
device: torch.device | None = None,
79+
batch_size: torch.Size | None = None,
80+
) -> "TensorDictBase":
81+
"""Converts a namedtuple to a TensorDict.
82+
83+
.. seealso:: :meth:`TensorDictBase.from_namedtuple` for more information.
84+
"""
85+
from tensordict import TensorDict
86+
87+
return TensorDict.from_namedtuple(
88+
named_tuple,
89+
auto_batch_size=auto_batch_size,
90+
batch_dims=batch_dims,
91+
device=device,
92+
batch_size=batch_size,
93+
)
94+
95+
96+
def from_struct_array(
97+
struct_array,
98+
*,
99+
auto_batch_size: bool = False,
100+
batch_dims: int | None = None,
101+
device: torch.device | None = None,
102+
batch_size: torch.Size | None = None,
103+
) -> "TensorDictBase":
104+
"""Converts a structured numpy array to a TensorDict.
105+
106+
.. seealso:: :meth:`TensorDictBase.from_struct_array` for more information.
107+
108+
Examples:
109+
>>> x = np.array(
110+
... [("Rex", 9, 81.0), ("Fido", 3, 27.0)],
111+
... dtype=[("name", "U10"), ("age", "i4"), ("weight", "f4")],
112+
... )
113+
>>> td = from_struct_array(x)
114+
>>> x_recon = td.to_struct_array()
115+
>>> assert (x_recon == x).all()
116+
>>> assert x_recon.shape == x.shape
117+
>>> # Try modifying x age field and check effect on td
118+
>>> x["age"] += 1
119+
>>> assert (td["age"] == np.array([10, 4])).all()
120+
121+
"""
122+
return TensorDictBase.from_struct_array(
123+
struct_array,
124+
auto_batch_size=auto_batch_size,
125+
batch_dims=batch_dims,
126+
device=device,
127+
batch_size=batch_size,
128+
)
129+
130+
131+
def from_list(
132+
input: list[TensorCollection | Mapping],
133+
*,
134+
auto_batch_size: bool = False,
135+
batch_dims: int | None = None,
136+
device: torch.device | None = None,
137+
batch_size: torch.Size | None = None,
138+
cls: Type | None = None,
139+
lazy_stack: bool = None,
140+
) -> TensorCollection:
141+
"""Converts a list of dictionaries or TensorDicts to a TensorDict.
142+
143+
.. seealso:: :meth:`TensorDictBase.from_dict` for more information.
144+
"""
145+
if cls is not None:
146+
cls = TensorDictBase
147+
return cls.from_list(
148+
input,
149+
auto_batch_size=auto_batch_size,
150+
batch_dims=batch_dims,
151+
device=device,
152+
batch_size=batch_size,
153+
type=type,
154+
lazy_stack=lazy_stack,
155+
)
156+
157+
158+
def from_dict(
159+
d,
160+
*,
161+
auto_batch_size: bool = False,
162+
batch_dims: int | None = None,
163+
device: torch.device | None = None,
164+
batch_size: torch.Size | None = None,
165+
) -> "TensorDictBase":
166+
"""Converts a dictionary to a TensorDict.
167+
168+
.. seealso:: :meth:`TensorDictBase.from_dict` for more information.
169+
170+
171+
Examples:
172+
>>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)}
173+
>>> print(from_dict(input_dict))
174+
TensorDict(
175+
fields={
176+
a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
177+
b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
178+
batch_size=torch.Size([3]),
179+
device=None,
180+
is_shared=False)
181+
>>> # nested dict: the nested TensorDict can have a different batch-size
182+
>>> # as long as its leading dims match.
183+
>>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}
184+
>>> print(from_dict(input_dict))
185+
TensorDict(
186+
fields={
187+
a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
188+
b: TensorDict(
189+
fields={
190+
c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
191+
batch_size=torch.Size([3, 4]),
192+
device=None,
193+
is_shared=False)},
194+
batch_size=torch.Size([3]),
195+
device=None,
196+
is_shared=False)
197+
>>> # we can also use this to work out the batch sie of a tensordict
198+
>>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, [])
199+
>>> print(
200+
from_dict(input_td))
201+
TensorDict(
202+
fields={
203+
a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
204+
b: TensorDict(
205+
fields={
206+
c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
207+
batch_size=torch.Size([3, 4]),
208+
device=None,
209+
is_shared=False)},
210+
batch_size=torch.Size([3]),
211+
device=None,
212+
is_shared=False)
213+
214+
"""
215+
from tensordict import TensorDict
216+
217+
return TensorDict.from_dict(
218+
d,
219+
auto_batch_size=auto_batch_size,
220+
batch_dims=batch_dims,
221+
device=device,
222+
batch_size=batch_size,
223+
)
224+
225+
226+
def from_h5(
227+
h5_file,
228+
*,
229+
auto_batch_size: bool = False,
230+
batch_dims: int | None = None,
231+
device: torch.device | None = None,
232+
batch_size: torch.Size | None = None,
233+
) -> "TensorDictBase":
234+
"""Converts an HDF5 file to a TensorDict.
235+
236+
.. seealso:: :meth:`TensorDictBase.from_h5` for more information.
237+
"""
238+
from tensordict import TensorDict
239+
240+
return TensorDict.from_h5(
241+
h5_file,
242+
auto_batch_size=auto_batch_size,
243+
batch_dims=batch_dims,
244+
device=device,
245+
batch_size=batch_size,
246+
)
247+
248+
249+
def from_pandas(
250+
dataframe,
251+
*,
252+
auto_batch_size: bool = False,
253+
batch_dims: int | None = None,
254+
device: torch.device | None = None,
255+
batch_size: torch.Size | None = None,
256+
separator: str | None = None,
257+
dtype: torch.dtype | None = None,
258+
) -> "TensorDictBase":
259+
"""Converts a pandas DataFrame to a TensorDict.
260+
261+
.. seealso:: :meth:`TensorDictBase.from_pandas` for more information.
262+
"""
263+
return TensorDictBase.from_pandas(
264+
dataframe,
265+
auto_batch_size=auto_batch_size,
266+
batch_dims=batch_dims,
267+
device=device,
268+
batch_size=batch_size,
269+
separator=separator,
270+
dtype=dtype,
271+
)
272+
273+
274+
def from_csv(
275+
path,
276+
*,
277+
auto_batch_size: bool = False,
278+
batch_dims: int | None = None,
279+
device: torch.device | None = None,
280+
batch_size: torch.Size | None = None,
281+
separator: str | None = None,
282+
dtype: torch.dtype | None = None,
283+
**kwargs,
284+
) -> "TensorDictBase":
285+
"""Creates a TensorDict from a CSV file.
286+
287+
.. seealso:: :meth:`TensorDictBase.from_csv` for more information.
288+
"""
289+
return TensorDictBase.from_csv(
290+
path,
291+
auto_batch_size=auto_batch_size,
292+
batch_dims=batch_dims,
293+
device=device,
294+
batch_size=batch_size,
295+
separator=separator,
296+
dtype=dtype,
297+
**kwargs,
298+
)
299+
300+
301+
def from_parquet(
302+
path,
303+
*,
304+
auto_batch_size: bool = False,
305+
batch_dims: int | None = None,
306+
device: torch.device | None = None,
307+
batch_size: torch.Size | None = None,
308+
separator: str | None = None,
309+
dtype: torch.dtype | None = None,
310+
columns: list[str] | None = None,
311+
**kwargs,
312+
) -> "TensorDictBase":
313+
"""Creates a TensorDict from a Parquet file.
314+
315+
.. seealso:: :meth:`TensorDictBase.from_parquet` for more information.
316+
"""
317+
return TensorDictBase.from_parquet(
318+
path,
319+
auto_batch_size=auto_batch_size,
320+
batch_dims=batch_dims,
321+
device=device,
322+
batch_size=batch_size,
323+
separator=separator,
324+
dtype=dtype,
325+
columns=columns,
326+
**kwargs,
327+
)
328+
329+
330+
def from_json(
331+
path,
332+
*,
333+
auto_batch_size: bool = False,
334+
batch_dims: int | None = None,
335+
device: torch.device | None = None,
336+
batch_size: torch.Size | None = None,
337+
separator: str | None = None,
338+
dtype: torch.dtype | None = None,
339+
lines: bool = False,
340+
**kwargs,
341+
) -> "TensorDictBase":
342+
"""Creates a TensorDict from a JSON file.
343+
344+
.. seealso:: :meth:`TensorDictBase.from_json` for more information.
345+
"""
346+
return TensorDictBase.from_json(
347+
path,
348+
auto_batch_size=auto_batch_size,
349+
batch_dims=batch_dims,
350+
device=device,
351+
batch_size=batch_size,
352+
separator=separator,
353+
dtype=dtype,
354+
lines=lines,
355+
**kwargs,
356+
)
357+
358+
359+
for _name in __all__:
360+
globals()[_name].__module__ = "tensordict.base"

0 commit comments

Comments
 (0)