Skip to content

Commit a0b6416

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 686d0e2 commit a0b6416

1 file changed

Lines changed: 13 additions & 16 deletions

File tree

tensordict/utils.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -506,23 +506,20 @@ def _dtype(tensor: Tensor) -> torch.dtype:
506506

507507

508508
def _get_item(tensor: Tensor, index: IndexType) -> Tensor:
509-
if isinstance(tensor, Tensor):
510-
try:
511-
return tensor[index]
512-
except IndexError as err:
513-
# try to map list index to tensor, and assess type. If bool, we
514-
# likely have a nested list of booleans which is not supported by pytorch
515-
if _is_lis_of_list_of_bools(index):
516-
index = torch.tensor(index, device=tensor.device)
517-
if index.dtype is torch.bool:
518-
raise RuntimeError(
519-
"Indexing a tensor with a nested list of boolean values is "
520-
"not supported by PyTorch.",
521-
)
522-
return tensor[index]
523-
raise err
524-
else:
509+
try:
525510
return tensor[index]
511+
except IndexError as err:
512+
# try to map list index to tensor, and assess type. If bool, we
513+
# likely have a nested list of booleans which is not supported by pytorch
514+
if _is_lis_of_list_of_bools(index):
515+
index = torch.tensor(index, device=tensor.device)
516+
if index.dtype is torch.bool:
517+
raise RuntimeError(
518+
"Indexing a tensor with a nested list of boolean values is "
519+
"not supported by PyTorch.",
520+
)
521+
return tensor[index]
522+
raise err
526523

527524

528525
def _set_item(

0 commit comments

Comments
 (0)