@@ -661,7 +661,6 @@ def __getstate__(self) -> dict[str, Any]:
661661 "_last_op",
662662 "_cache",
663663 "__lock_parents_weakrefs",
664- "_validate_value_cached",
665664 ):
666665 result.pop(key, None)
667666 return result
@@ -3503,8 +3502,6 @@ def dtype(self):
35033502 return self._dtype()
35043503
35053504 def _batch_size_setter(self, new_batch_size: torch.Size) -> None:
3506- if self._validate_value_cached is not None:
3507- delattr(self, "_validate_value_cached")
35083505 if new_batch_size == self.batch_size:
35093506 return
35103507 if self._lazy:
@@ -5768,17 +5765,13 @@ def clear_device_(self) -> Self:
57685765
57695766 """
57705767 self._device = None
5771- if self._validate_value_cached is not None:
5772- delattr(self, "_validate_value_cached")
57735768 for value in self.values():
57745769 if _is_tensor_collection(type(value)):
57755770 value.clear_device_()
57765771 return self
57775772
57785773 def _set_device(self, device: torch.device) -> Self:
57795774 self._device = device
5780- if self._validate_value_cached is not None:
5781- delattr(self, "_validate_value_cached")
57825775 for value in self.values():
57835776 if _is_tensor_collection(type(value)):
57845777 value._set_device(device=device)
@@ -12964,26 +12957,21 @@ def _validate_key(self, key: NestedKey) -> NestedKey:
1296412957 raise KeyError(_GENERIC_NESTED_ERR.format(key))
1296512958 return key
1296612959
12967- _validate_value_cached: str | None = None
12968-
1296912960 @property
1297012961 def _validate_value(self):
1297112962 if is_compiling():
1297212963 return self._validate_value_generic
12973- _validate_value_cached = self._validate_value_cached
12974- if _validate_value_cached is None:
12975- if self.device:
12976- if self.batch_size:
12977- _validate_value_cached = "_validate_value_generic"
12978- else:
12979- _validate_value_cached = "_validate_value_batchfree"
12964+ if self.device:
12965+ if self.batch_size:
12966+ method_name = "_validate_value_generic"
1298012967 else:
12981- if self.batch_size:
12982- _validate_value_cached = "_validate_value_devicefree"
12983- else:
12984- _validate_value_cached = "_validate_value_batchfree_devicefree"
12985- self._validate_value_cached = _validate_value_cached
12986- return getattr(self, _validate_value_cached)
12968+ method_name = "_validate_value_batchfree"
12969+ else:
12970+ if self.batch_size:
12971+ method_name = "_validate_value_devicefree"
12972+ else:
12973+ method_name = "_validate_value_batchfree_devicefree"
12974+ return getattr(self, method_name)
1298712975
1298812976 def _validate_value_generic(
1298912977 self,
0 commit comments