Apple GPU support & float32 dtype #14

@TomaSusi

Hi,

Just getting started with MACE but am really digging it! I was excited to see that you support Apple GPUs, but is that only for training? When I try to use a mace_off() or mace_mp() ASE calculator and specify both the dtype and the device, I get an error:

Using MACE-OFF23 MODEL for MACECalculator with /Users/tomasusi/.cache/mace/MACE-OFF23_medium.model
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[95], line 13
---> 13 calc = mace_off(device='mps', default_dtype='float32')

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/calculators/foundations_models.py:206, in mace_off(model, device, default_dtype, return_raw_model, **kwargs)
    202 if default_dtype == "float32":
    203     print(
    204         "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization."
    205     )
--> 206 mace_calc = MACECalculator(
    207     model_paths=model, device=device, default_dtype=default_dtype, **kwargs
    208 )
    209 return mace_calc

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/calculators/mace.py:127, in MACECalculator.__init__(self, model_paths, device, energy_units_to_eV, length_units_to_A, default_dtype, charges_key, model_type, compile_mode, fullgraph, **kwargs)
    125     self.use_compile = True
    126 else:
--> 127     self.models = [
    128         torch.load(f=model_path, map_location=device)
    129         for model_path in model_paths
    130     ]
    131     self.use_compile = False
    132 for model in self.models:

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/calculators/mace.py:128, in <listcomp>(.0)
    125     self.use_compile = True
    126 else:
    127     self.models = [
--> 128         torch.load(f=model_path, map_location=device)
    129         for model_path in model_paths
    130     ]
    131     self.use_compile = False
    132 for model in self.models:

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/serialization.py:1097, in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
   1095             except RuntimeError as e:
   1096                 raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
-> 1097         return _load(
   1098             opened_zipfile,
   1099             map_location,
   1100             pickle_module,
   1101             overall_storage=overall_storage,
   1102             **pickle_load_args,
   1103         )
   1104 if mmap:
   1105     f_name = "" if not isinstance(f, str) else f"{f}, "

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/serialization.py:1525, in _load(zip_file, map_location, pickle_module, pickle_file, overall_storage, **pickle_load_args)
   1522 # Needed for tensors where storage device and rebuild tensor device are
   1523 # not connected (wrapper subclasses and tensors rebuilt using numpy)
   1524 torch._utils._thread_local_state.map_location = map_location
-> 1525 result = unpickler.load()
   1526 del torch._utils._thread_local_state.map_location
   1528 torch._utils._validate_loaded_sparse_tensors()

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/_utils.py:200, in _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
    197 def _rebuild_tensor_v2(
    198     storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None
    199 ):
--> 200     tensor = _rebuild_tensor(storage, storage_offset, size, stride)
    201     tensor.requires_grad = requires_grad
    202     if metadata:

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/_utils.py:178, in _rebuild_tensor(storage, storage_offset, size, stride)
    176 def _rebuild_tensor(storage, storage_offset, size, stride):
    177     # first construct a tensor with the correct dtype/device
--> 178     t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device)
    179     return t.set_(storage._untyped_storage, storage_offset, size, stride)

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

Or maybe this is just a simple bug? I am running PyTorch 2.4.1.
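For what it's worth, the failure seems to come from `torch.load(..., map_location='mps')` trying to materialize the checkpoint's float64 tensors directly on the MPS device, which only supports float32, so the crash happens before `default_dtype='float32'` can take effect. A possible workaround (just a sketch of the general PyTorch pattern, not necessarily how MACE should fix it, and using a tiny dummy module in place of the real MACE-OFF checkpoint) is to load on CPU first, downcast, and only then move to MPS:

```python
import io
import torch

# Stand-in for the MACE checkpoint: a small float64 module serialized the
# same way (torch.save of the whole model object).
dummy = torch.nn.Linear(4, 4).to(torch.float64)
buf = io.BytesIO()
torch.save(dummy, buf)
buf.seek(0)

# Workaround: materialize on CPU (which supports float64), downcast to
# float32, then move to MPS only if it is actually available.
model = torch.load(buf, map_location="cpu", weights_only=False)
model = model.to(torch.float32)
device = "mps" if torch.backends.mps.is_available() else "cpu"
model = model.to(device)

print(model.weight.dtype)  # torch.float32
```

If something like this is what's needed, the fix might belong inside `MACECalculator.__init__`, i.e. loading with `map_location='cpu'` and deferring the device move until after the dtype conversion.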
