From 0393e995cc3fdc9528a34a69ed9b03463bd589af Mon Sep 17 00:00:00 2001 From: Zack Ulissi Date: Wed, 28 Jan 2026 01:23:40 +0000 Subject: [PATCH 01/10] first stab at formationenergycalculator --- src/quacc/recipes/mlp/_base.py | 44 ++++- src/quacc/recipes/mlp/core.py | 50 ++++- .../recipes/mlp_recipes/test_core_recipes.py | 186 ++++++++++++++++++ 3 files changed, 268 insertions(+), 12 deletions(-) diff --git a/src/quacc/recipes/mlp/_base.py b/src/quacc/recipes/mlp/_base.py index c7a2e1674a..c384b5f742 100644 --- a/src/quacc/recipes/mlp/_base.py +++ b/src/quacc/recipes/mlp/_base.py @@ -5,7 +5,7 @@ from functools import lru_cache, wraps from importlib.util import find_spec from logging import getLogger -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from ase.units import GPa as _GPa_to_eV_per_A3 from monty.dev import requires @@ -57,6 +57,8 @@ def pick_calculator( method: Literal[ "mace-mp", "m3gnet", "chgnet", "tensornet", "sevennet", "orb", "fairchem" ], + use_formation_energy: bool = False, + formation_energy_kwargs: Any = None, **calc_kwargs, ) -> BaseCalculator: """ @@ -71,6 +73,13 @@ def pick_calculator( ---------- method Name of the calculator to use. + use_formation_energy + If True, wrap the calculator with FormationEnergyCalculator to compute + formation energies. Currently only supported for FAIRChem UMA with + task_name='omat'. Default is False. + formation_energy_kwargs + Custom kwargs for the FormationEnergyCalculator wrapper. Only used if + use_formation_energy=True. Default is None. **calc_kwargs Custom kwargs for the underlying calculator. Set a value to `quacc.Remove` to remove a pre-existing key entirely. @@ -78,7 +87,7 @@ def pick_calculator( Returns ------- BaseCalculator - The instantiated calculator + The instantiated calculator (optionally wrapped with FormationEnergyCalculator) """ import torch @@ -139,4 +148,35 @@ def pick_calculator( calc.parameters["version"] = __version__ + # Wrap with FormationEnergyCalculator if requested + if use_formation_energy: + from fairchem.core import FAIRChemCalculator + from fairchem.core.calculate.ase_calculator import FormationEnergyCalculator + + if method.lower() != "fairchem": + raise ValueError( + "Formation energy calculations are currently only supported for " + "FAIRChem UMA with task_name='omat'. Please use method='fairchem' " + "with use_formation_energy=True." + ) + + if not isinstance(calc, FAIRChemCalculator): + raise ValueError( + "Expected FAIRChemCalculator but got a different calculator type." + ) + + # Check that omat task is being used + if not hasattr(calc, "task_name") or calc.task_name != "omat": + raise ValueError( + "Formation energy calculations are only supported for FAIRChem UMA " + "with task_name='omat'. Please ensure you are using " + "FAIRChemCalculator.from_model_checkpoint(..., task_name='omat')." + ) + + # Use provided kwargs or empty dict if None + fe_kwargs = formation_energy_kwargs or {} + + # Wrap with FormationEnergyCalculator using provided kwargs + calc = FormationEnergyCalculator(calculator=calc, **fe_kwargs) + return calc diff --git a/src/quacc/recipes/mlp/core.py b/src/quacc/recipes/mlp/core.py index 1adfe07de8..8c0a48d89f 100644 --- a/src/quacc/recipes/mlp/core.py +++ b/src/quacc/recipes/mlp/core.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from quacc import job from quacc.recipes.mlp._base import pick_calculator @@ -11,7 +11,7 @@ from quacc.utils.dicts import recursive_dict_merge if TYPE_CHECKING: - from typing import Any, Literal + from typing import Literal from ase.atoms import Atoms @@ -23,6 +23,8 @@ def static_job( atoms: Atoms, method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet", "orb", "fairchem"], additional_fields: dict[str, Any] | None = None, + use_formation_energy: bool = False, + formation_energy_kwargs: Any = None, **calc_kwargs, ) -> RunSchema: """ @@ -36,11 +38,19 @@ def static_job( Universal ML interatomic potential method to use additional_fields Additional fields to add to the results dictionary. + use_formation_energy + If True, wrap the calculator with FormationEnergyCalculator to compute + formation energies. Currently only supported for FAIRChem with + method='fairchem' and task_name='omat'. Default is False. + formation_energy_kwargs + Custom kwargs for the FormationEnergyCalculator wrapper. Only used if + use_formation_energy=True. Default is None. **calc_kwargs Custom kwargs for the underlying calculator. Set a value to - `quacc.Remove` to remove a pre-existing key entirely. For a list of available - keys, refer to the `mace.calculators.mace_mp`, `chgnet.model.dynamics.CHGNetCalculator`, - `matgl.ext.ase.M3GNetCalculator`, `sevenn.sevennet_calculator.SevenNetCalculator`, + `quacc.Remove` to remove a pre-existing key entirely. For a list of + available keys, refer to the `mace.calculators.mace_mp`, + `chgnet.model.dynamics.CHGNetCalculator`, `matgl.ext.ase.M3GNetCalculator`, + `sevenn.sevennet_calculator.SevenNetCalculator`, `orb_models.forcefield.calculator.ORBCalculator`, `fairchem.core.FAIRChemCalculator` calculators. @@ -50,7 +60,12 @@ def static_job( Dictionary of results from [quacc.schemas.ase.Summarize.run][]. See the type-hint for the data structure. """ - calc = pick_calculator(method, **calc_kwargs) + calc = pick_calculator( + method, + use_formation_energy=use_formation_energy, + formation_energy_kwargs=formation_energy_kwargs, + **calc_kwargs, + ) final_atoms = Runner(atoms, calc).run_calc() return Summarize( additional_fields={"name": f"{method} Static"} | (additional_fields or {}) @@ -64,6 +79,8 @@ def relax_job( relax_cell: bool = False, opt_params: OptParams | None = None, additional_fields: dict[str, Any] | None = None, + use_formation_energy: bool = False, + formation_energy_kwargs: Any = None, **calc_kwargs, ) -> OptSchema: """ @@ -82,11 +99,19 @@ def relax_job( of available keys, refer to [quacc.runners.ase.Runner.run_opt][]. additional_fields Additional fields to add to the results dictionary. + use_formation_energy + If True, wrap the calculator with FormationEnergyCalculator to compute + formation energies. Currently only supported for FAIRChem with + method='fairchem' and task_name='omat'. Default is False. + formation_energy_kwargs + Custom kwargs for the FormationEnergyCalculator wrapper. Only used if + use_formation_energy=True. Default is None. **calc_kwargs Custom kwargs for the underlying calculator. Set a value to - `quacc.Remove` to remove a pre-existing key entirely. For a list of available - keys, refer to the `mace.calculators.mace_mp`, `chgnet.model.dynamics.CHGNetCalculator`, - `matgl.ext.ase.M3GNetCalculator`, `sevenn.sevennet_calculator.SevenNetCalculator`, + `quacc.Remove` to remove a pre-existing key entirely. For a list of + available keys, refer to the `mace.calculators.mace_mp`, + `chgnet.model.dynamics.CHGNetCalculator`, `matgl.ext.ase.M3GNetCalculator`, + `sevenn.sevennet_calculator.SevenNetCalculator`, `orb_models.forcefield.calculator.ORBCalculator`, `fairchem.core.FAIRChemCalculator` calculators. @@ -99,7 +124,12 @@ def relax_job( opt_defaults = {"fmax": 0.05} opt_flags = recursive_dict_merge(opt_defaults, opt_params) - calc = pick_calculator(method, **calc_kwargs) + calc = pick_calculator( + method, + use_formation_energy=use_formation_energy, + formation_energy_kwargs=formation_energy_kwargs, + **calc_kwargs, + ) dyn = Runner(atoms, calc).run_opt(relax_cell=relax_cell, **opt_flags) diff --git a/tests/core/recipes/mlp_recipes/test_core_recipes.py b/tests/core/recipes/mlp_recipes/test_core_recipes.py index 46d7392d50..0d82a28a09 100644 --- a/tests/core/recipes/mlp_recipes/test_core_recipes.py +++ b/tests/core/recipes/mlp_recipes/test_core_recipes.py @@ -154,3 +154,189 @@ def test_relax_cell_job(tmp_path, monkeypatch, method): assert np.shape(output["results"]["forces"]) == (8, 3) assert output["atoms"] != atoms assert output["atoms"].get_volume() != pytest.approx(atoms.get_volume()) + +@pytest.mark.skipif(find_spec("fairchem") is None, reason="fairchem not installed") +def test_static_job_formation_energy_fairchem(tmp_path, monkeypatch): + """Test formation energy calculation with FAIRChem UMA omat.""" + monkeypatch.chdir(tmp_path) + _set_dtype(32) + + from huggingface_hub.utils._auth import get_token + + if not get_token(): + pytest.skip("HuggingFace token not available for FAIRChem") + + calc_kwargs = {"name_or_path": "uma-s-1", "task_name": "omat"} + + # Test Cu (elemental system - formation energy should be ~0) + atoms_cu = bulk("Cu") + output_cu = static_job( + atoms_cu, method="fairchem", use_formation_energy=True, **calc_kwargs + ) + # For pure elements, formation energy should be close to zero + assert abs(output_cu["results"]["energy"]) < 0.1 + assert np.shape(output_cu["results"]["forces"]) == (1, 3) + + # Test MgO (binary compound - formation energy should be negative) + atoms_mgo = bulk("MgO", crystalstructure="rocksalt", a=4.2) + output_mgo = static_job( + atoms_mgo, method="fairchem", use_formation_energy=True, **calc_kwargs + ) + # MgO has a substantial negative formation energy + assert output_mgo["results"]["energy"] < -2.0 # per formula unit + assert np.shape(output_mgo["results"]["forces"]) == (2, 3) + + +@pytest.mark.skipif(find_spec("fairchem") is None, reason="fairchem not installed") +def test_relax_job_formation_energy_fairchem(tmp_path, monkeypatch): + """Test formation energy calculation during relaxation with FAIRChem UMA omat.""" + monkeypatch.chdir(tmp_path) + _set_dtype(32) + + from huggingface_hub.utils._auth import get_token + + if not get_token(): + pytest.skip("HuggingFace token not available for FAIRChem") + + calc_kwargs = {"name_or_path": "uma-s-1", "task_name": "omat"} + + # Test Cu supercell + atoms = bulk("Cu") * (2, 2, 2) + atoms[0].position += 0.1 + output = relax_job( + atoms, method="fairchem", use_formation_energy=True, **calc_kwargs + ) + # 8 Cu atoms, formation energy should be near zero + assert abs(output["results"]["energy"]) < 0.8 # 8 * 0.1 eV tolerance per atom + assert np.shape(output["results"]["forces"]) == (8, 3) + assert output["atoms"] != atoms + + +@pytest.mark.skipif(find_spec("fairchem") is None, reason="fairchem not installed") +def test_formation_energy_multiple_compounds_fairchem(tmp_path, monkeypatch): + """Test formation energy for multiple compound types with FAIRChem UMA omat.""" + monkeypatch.chdir(tmp_path) + _set_dtype(32) + + from huggingface_hub.utils._auth import get_token + + if not get_token(): + pytest.skip("HuggingFace token not available for FAIRChem") + + calc_kwargs = {"name_or_path": "uma-s-1", "task_name": "omat"} + + # Test NaCl + atoms_nacl = bulk("NaCl", crystalstructure="rocksalt", a=5.64) + output_nacl = static_job( + atoms_nacl, method="fairchem", use_formation_energy=True, **calc_kwargs + ) + # NaCl has negative formation energy + assert output_nacl["results"]["energy"] < -1.0 + + # Test Si + atoms_si = bulk("Si") + output_si = static_job( + atoms_si, method="fairchem", use_formation_energy=True, **calc_kwargs + ) + # Pure Si should have near-zero formation energy + assert abs(output_si["results"]["energy"]) < 0.1 + + # Test Al + atoms_al = bulk("Al") + output_al = static_job( + atoms_al, method="fairchem", use_formation_energy=True, **calc_kwargs + ) + # Pure Al should have near-zero formation energy + assert abs(output_al["results"]["energy"]) < 0.1 + + +@pytest.mark.skipif(find_spec("fairchem") is None, reason="fairchem not installed") +def test_relax_job_formation_energy_cell_fairchem(tmp_path, monkeypatch): + """Test formation energy calculation with cell relaxation using FAIRChem UMA omat.""" + monkeypatch.chdir(tmp_path) + _set_dtype(32) + + from huggingface_hub.utils._auth import get_token + + if not get_token(): + pytest.skip("HuggingFace token not available for FAIRChem") + + calc_kwargs = {"name_or_path": "uma-s-1", "task_name": "omat"} + + # Test MgO with cell relaxation + atoms_mgo = bulk("MgO", crystalstructure="rocksalt", a=4.2) * (2, 2, 2) + atoms_mgo[0].position += 0.05 + output = relax_job( + atoms_mgo, + method="fairchem", + relax_cell=True, + use_formation_energy=True, + **calc_kwargs, + ) + # Should have relaxed and computed formation energy + assert output["results"]["energy"] < -8.0 # 8 formula units * ~-1 eV per formula unit + assert output["atoms"] != atoms_mgo + + +@pytest.mark.skipif(find_spec("fairchem") is None, reason="fairchem not installed") +def test_formation_energy_error_without_omat(tmp_path, monkeypatch): + """Test that formation energy raises error when not using omat task.""" + monkeypatch.chdir(tmp_path) + _set_dtype(32) + + from huggingface_hub.utils._auth import get_token + + if not get_token(): + pytest.skip("HuggingFace token not available for FAIRChem") + + # Try to use formation energy without omat task - should raise + atoms = bulk("Cu") + with pytest.raises(ValueError, match="task_name='omat'"): + static_job( + atoms, + method="fairchem", + name_or_path="uma-s-1", + task_name="omol", # Wrong task + use_formation_energy=True, + ) + + +@pytest.mark.skipif(find_spec("fairchem") is None, reason="fairchem not installed") +def test_formation_energy_error_without_fairchem(tmp_path, monkeypatch): + """Test that formation energy raises error when not using fairchem.""" + monkeypatch.chdir(tmp_path) + + if "mace-mp" not in methods: + pytest.skip("mace-mp not available") + + _set_dtype(64) + + # Try to use formation energy with non-FAIRChem method - should raise + atoms = bulk("Cu") + with pytest.raises(ValueError, match="FAIRChem UMA"): + static_job(atoms, method="mace-mp", use_formation_energy=True) + + +@pytest.mark.skipif(find_spec("fairchem") is None, reason="fairchem not installed") +def test_static_job_formation_energy_with_kwargs(tmp_path, monkeypatch): + """Test that formation_energy_kwargs are properly passed through.""" + monkeypatch.chdir(tmp_path) + _set_dtype(32) + + from huggingface_hub.utils._auth import get_token + + if not get_token(): + pytest.skip("HuggingFace token not available for FAIRChem") + + atoms = bulk("Cu") + output = static_job( + atoms, + method="fairchem", + name_or_path="uma-s-1", + task_name="omat", + use_formation_energy=True, + formation_energy_kwargs={}, # Pass empty dict as formation_energy_kwargs + ) + # Should succeed and compute formation energy + assert abs(output["results"]["energy"]) < 0.1 + assert np.shape(output["results"]["forces"]) == (1, 3) \ No newline at end of file From 8ced22845ecfd58c86a088d240d00134f126db6a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Jan 2026 01:27:40 +0000 Subject: [PATCH 02/10] pre-commit auto-fixes --- tests/core/recipes/mlp_recipes/test_core_recipes.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/core/recipes/mlp_recipes/test_core_recipes.py b/tests/core/recipes/mlp_recipes/test_core_recipes.py index 0d82a28a09..b434435c0a 100644 --- a/tests/core/recipes/mlp_recipes/test_core_recipes.py +++ b/tests/core/recipes/mlp_recipes/test_core_recipes.py @@ -155,6 +155,7 @@ def test_relax_cell_job(tmp_path, monkeypatch, method): assert output["atoms"] != atoms assert output["atoms"].get_volume() != pytest.approx(atoms.get_volume()) + @pytest.mark.skipif(find_spec("fairchem") is None, reason="fairchem not installed") def test_static_job_formation_energy_fairchem(tmp_path, monkeypatch): """Test formation energy calculation with FAIRChem UMA omat.""" @@ -274,7 +275,9 @@ def test_relax_job_formation_energy_cell_fairchem(tmp_path, monkeypatch): **calc_kwargs, ) # Should have relaxed and computed formation energy - assert output["results"]["energy"] < -8.0 # 8 formula units * ~-1 eV per formula unit + assert ( + output["results"]["energy"] < -8.0 + ) # 8 formula units * ~-1 eV per formula unit assert output["atoms"] != atoms_mgo @@ -339,4 +342,4 @@ def test_static_job_formation_energy_with_kwargs(tmp_path, monkeypatch): ) # Should succeed and compute formation energy assert abs(output["results"]["energy"]) < 0.1 - assert np.shape(output["results"]["forces"]) == (1, 3) \ No newline at end of file + assert np.shape(output["results"]["forces"]) == (1, 3) From 944536e9d4eb91647fe272d31a8bec83473eb01b Mon Sep 17 00:00:00 2001 From: Zack Ulissi Date: Wed, 28 Jan 2026 01:28:42 +0000 Subject: [PATCH 03/10] small changes to documentation --- src/quacc/recipes/mlp/core.py | 6 ++++-- .../recipes/mlp_recipes/test_core_recipes.py | 20 +++++++++---------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/quacc/recipes/mlp/core.py b/src/quacc/recipes/mlp/core.py index 8c0a48d89f..42c5fdc99e 100644 --- a/src/quacc/recipes/mlp/core.py +++ b/src/quacc/recipes/mlp/core.py @@ -41,7 +41,8 @@ def static_job( use_formation_energy If True, wrap the calculator with FormationEnergyCalculator to compute formation energies. Currently only supported for FAIRChem with - method='fairchem' and task_name='omat'. Default is False. + method='fairchem' and task_name='omat'. Default is False. The formation + energy is returned in eV per formula unit (not eV/atom). formation_energy_kwargs Custom kwargs for the FormationEnergyCalculator wrapper. Only used if use_formation_energy=True. Default is None. @@ -102,7 +103,8 @@ def relax_job( use_formation_energy If True, wrap the calculator with FormationEnergyCalculator to compute formation energies. Currently only supported for FAIRChem with - method='fairchem' and task_name='omat'. Default is False. + method='fairchem' and task_name='omat'. Default is False. The formation + energy is returned in eV per formula unit (not eV/atom). formation_energy_kwargs Custom kwargs for the FormationEnergyCalculator wrapper. Only used if use_formation_energy=True. Default is None. diff --git a/tests/core/recipes/mlp_recipes/test_core_recipes.py b/tests/core/recipes/mlp_recipes/test_core_recipes.py index 0d82a28a09..fbe8172c01 100644 --- a/tests/core/recipes/mlp_recipes/test_core_recipes.py +++ b/tests/core/recipes/mlp_recipes/test_core_recipes.py @@ -173,7 +173,7 @@ def test_static_job_formation_energy_fairchem(tmp_path, monkeypatch): output_cu = static_job( atoms_cu, method="fairchem", use_formation_energy=True, **calc_kwargs ) - # For pure elements, formation energy should be close to zero + # For pure elements, total formation energy should be close to zero (eV) assert abs(output_cu["results"]["energy"]) < 0.1 assert np.shape(output_cu["results"]["forces"]) == (1, 3) @@ -182,8 +182,8 @@ def test_static_job_formation_energy_fairchem(tmp_path, monkeypatch): output_mgo = static_job( atoms_mgo, method="fairchem", use_formation_energy=True, **calc_kwargs ) - # MgO has a substantial negative formation energy - assert output_mgo["results"]["energy"] < -2.0 # per formula unit + # MgO has a substantial negative formation energy (eV per formula unit) + assert output_mgo["results"]["energy"] < -2.0 assert np.shape(output_mgo["results"]["forces"]) == (2, 3) @@ -206,8 +206,8 @@ def test_relax_job_formation_energy_fairchem(tmp_path, monkeypatch): output = relax_job( atoms, method="fairchem", use_formation_energy=True, **calc_kwargs ) - # 8 Cu atoms, formation energy should be near zero - assert abs(output["results"]["energy"]) < 0.8 # 8 * 0.1 eV tolerance per atom + # 8 Cu atoms, total formation energy should be near zero (eV) + assert abs(output["results"]["energy"]) < 0.8 assert np.shape(output["results"]["forces"]) == (8, 3) assert output["atoms"] != atoms @@ -230,7 +230,7 @@ def test_formation_energy_multiple_compounds_fairchem(tmp_path, monkeypatch): output_nacl = static_job( atoms_nacl, method="fairchem", use_formation_energy=True, **calc_kwargs ) - # NaCl has negative formation energy + # NaCl has negative formation energy (eV per formula unit) assert output_nacl["results"]["energy"] < -1.0 # Test Si @@ -238,7 +238,7 @@ def test_formation_energy_multiple_compounds_fairchem(tmp_path, monkeypatch): output_si = static_job( atoms_si, method="fairchem", use_formation_energy=True, **calc_kwargs ) - # Pure Si should have near-zero formation energy + # Pure Si should have near-zero formation energy (eV) assert abs(output_si["results"]["energy"]) < 0.1 # Test Al @@ -246,7 +246,7 @@ def test_formation_energy_multiple_compounds_fairchem(tmp_path, monkeypatch): output_al = static_job( atoms_al, method="fairchem", use_formation_energy=True, **calc_kwargs ) - # Pure Al should have near-zero formation energy + # Pure Al should have near-zero formation energy (eV) assert abs(output_al["results"]["energy"]) < 0.1 @@ -273,8 +273,8 @@ def test_relax_job_formation_energy_cell_fairchem(tmp_path, monkeypatch): use_formation_energy=True, **calc_kwargs, ) - # Should have relaxed and computed formation energy - assert output["results"]["energy"] < -8.0 # 8 formula units * ~-1 eV per formula unit + # Should have relaxed and computed total formation energy (eV for 8 formula units) + assert output["results"]["energy"] < -8.0 assert output["atoms"] != atoms_mgo From c868ebbbca4ce3ae88b0719b2a7870d6f3b38776 Mon Sep 17 00:00:00 2001 From: Zack Ulissi Date: Wed, 28 Jan 2026 20:31:14 +0000 Subject: [PATCH 04/10] add support for different reference types --- src/quacc/recipes/mlp/_base.py | 132 ++++++--- src/quacc/recipes/mlp/core.py | 40 +-- ...-07-mp-elemental-reference-entries.json.gz | Bin 0 -> 6350 bytes .../recipes/mlp_recipes/test_core_recipes.py | 255 ++++++++++++++++-- 4 files changed, 362 insertions(+), 65 deletions(-) create mode 100644 src/quacc/recipes/mlp/references/2023-02-07-mp-elemental-reference-entries.json.gz diff --git a/src/quacc/recipes/mlp/_base.py b/src/quacc/recipes/mlp/_base.py index c384b5f742..f91467ee27 100644 --- a/src/quacc/recipes/mlp/_base.py +++ b/src/quacc/recipes/mlp/_base.py @@ -51,6 +51,86 @@ def wrapped(*args, **kwargs): return wrapped +@lru_cache +def _get_omat24_references() -> dict[str, float]: + """ + Fetch formation energy references for OMAT24-trained models from HuggingFace. + + These references come from https://huggingface.co/facebook/UMA/blob/main/references/form_elem_refs.yaml + + Returns + ------- + dict[str, float] + Dictionary mapping element symbols to reference energies (eV/atom). + """ + from huggingface_hub import hf_hub_download + import yaml + + LOGGER.info("Downloading OMAT24 formation energy references from HuggingFace...") + + # Download the form_elem_refs.yaml file from HuggingFace + refs_file = hf_hub_download( + repo_id="facebook/UMA", + filename="references/form_elem_refs.yaml", + repo_type="model" + ) + + # Load and extract the omat references + with open(refs_file) as f: + refs_data = yaml.safe_load(f) + + omat_refs = refs_data.get("refs", {}).get("omat", {}) + + if not omat_refs: + raise ValueError("Could not find 'refs.omat' in the downloaded reference file.") + + LOGGER.info(f"Loaded OMAT24 references for {len(omat_refs)} elements.") + return omat_refs + + +@lru_cache +def _get_mp20_references() -> dict[str, float]: + """ + Load formation energy references for MP-20 compatible models. + + These references come from matbench-discovery repository: + https://github.com/janosh/matbench-discovery + + Returns + ------- + dict[str, float] + Dictionary mapping element symbols to reference energies (eV/atom). + """ + import gzip + import json + from pathlib import Path + + LOGGER.info("Loading MP-20 formation energy references from local file...") + + # Load from local gzipped JSON file + refs_file = Path(__file__).parent / "references" / "2023-02-07-mp-elemental-reference-entries.json.gz" + + if not refs_file.exists(): + raise FileNotFoundError( + f"MP-20 reference file not found at {refs_file}. " + "Please ensure the file is in src/quacc/recipes/mlp/references/" + ) + + # Load the gzipped JSON file + with gzip.open(refs_file, "rt") as f: + refs_data = json.load(f) + + # Extract element references based on the expected structure + # The file should contain element references + if isinstance(refs_data, dict): + mp20_refs = refs_data + else: + raise ValueError(f"Unexpected format in MP-20 reference file: {type(refs_data)}") + + LOGGER.info(f"Loaded MP-20 references for {len(mp20_refs)} elements.") + return mp20_refs + + @freezeargs @lru_cache def pick_calculator( @@ -58,7 +138,7 @@ def pick_calculator( "mace-mp", "m3gnet", "chgnet", "tensornet", "sevennet", "orb", "fairchem" ], use_formation_energy: bool = False, - formation_energy_kwargs: Any = None, + references: Literal["MP20", "OMAT24"] | None = None, **calc_kwargs, ) -> BaseCalculator: """ @@ -75,11 +155,15 @@ def pick_calculator( Name of the calculator to use. use_formation_energy If True, wrap the calculator with FormationEnergyCalculator to compute - formation energies. Currently only supported for FAIRChem UMA with - task_name='omat'. Default is False. - formation_energy_kwargs - Custom kwargs for the FormationEnergyCalculator wrapper. Only used if - use_formation_energy=True. Default is None. + formation energies. Requires fairchem-core package to be installed. + Supported for all calculator types. Default is False. + references + Formation energy references to use. Only used if use_formation_energy=True. + Options: + - None: Use built-in references from FormationEnergyCalculator (FAIRChem models only) + - "OMAT24": Use OMAT24 references from https://huggingface.co/facebook/UMA + - "MP20": Use MP-20 references from matbench-discovery + Default is None. **calc_kwargs Custom kwargs for the underlying calculator. Set a value to `quacc.Remove` to remove a pre-existing key entirely. @@ -134,7 +218,7 @@ def pick_calculator( from orb_models.forcefield import pretrained from orb_models.forcefield.calculator import ORBCalculator - orb_model = calc_kwargs.get("model", "orb_v2") + orb_model = calc_kwargs.get("model", "orb_v3_conservative_inf_omat") orbff = getattr(pretrained, orb_model)() calc = ORBCalculator(model=orbff, **calc_kwargs) @@ -150,31 +234,19 @@ def pick_calculator( # Wrap with FormationEnergyCalculator if requested if use_formation_energy: - from fairchem.core import FAIRChemCalculator from fairchem.core.calculate.ase_calculator import FormationEnergyCalculator - if method.lower() != "fairchem": - raise ValueError( - "Formation energy calculations are currently only supported for " - "FAIRChem UMA with task_name='omat'. Please use method='fairchem' " - "with use_formation_energy=True." - ) - - if not isinstance(calc, FAIRChemCalculator): - raise ValueError( - "Expected FAIRChemCalculator but got a different calculator type." - ) - - # Check that omat task is being used - if not hasattr(calc, "task_name") or calc.task_name != "omat": - raise ValueError( - "Formation energy calculations are only supported for FAIRChem UMA " - "with task_name='omat'. Please ensure you are using " - "FAIRChemCalculator.from_model_checkpoint(..., task_name='omat')." - ) - - # Use provided kwargs or empty dict if None - fe_kwargs = formation_energy_kwargs or {} + # Determine which reference energies to use + fe_kwargs = {} + + if references == "OMAT24": + # Use OMAT24 references from HuggingFace + fe_kwargs["references"] = _get_omat24_references() + elif references == "MP20": + # Use MP-20 references from local file + fe_kwargs["references"] = _get_mp20_references() + # If references is None, use built-in references from FormationEnergyCalculator + # (works for FAIRChem models with task_name specified) # Wrap with FormationEnergyCalculator using provided kwargs calc = FormationEnergyCalculator(calculator=calc, **fe_kwargs) diff --git a/src/quacc/recipes/mlp/core.py b/src/quacc/recipes/mlp/core.py index 42c5fdc99e..53c57da543 100644 --- a/src/quacc/recipes/mlp/core.py +++ b/src/quacc/recipes/mlp/core.py @@ -24,7 +24,7 @@ def static_job( method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet", "orb", "fairchem"], additional_fields: dict[str, Any] | None = None, use_formation_energy: bool = False, - formation_energy_kwargs: Any = None, + references: Literal["MP20", "OMAT24"] | None = None, **calc_kwargs, ) -> RunSchema: """ @@ -40,12 +40,16 @@ def static_job( Additional fields to add to the results dictionary. use_formation_energy If True, wrap the calculator with FormationEnergyCalculator to compute - formation energies. Currently only supported for FAIRChem with - method='fairchem' and task_name='omat'. Default is False. The formation - energy is returned in eV per formula unit (not eV/atom). - formation_energy_kwargs - Custom kwargs for the FormationEnergyCalculator wrapper. Only used if - use_formation_energy=True. Default is None. + formation energies. Requires fairchem-core package to be installed. + Supported for all methods. Default is False. The formation energy is + returned in eV per formula unit (not eV/atom). + references + Formation energy references to use. Only used if use_formation_energy=True. + Options: + - None: Use built-in references from FormationEnergyCalculator (FAIRChem models only) + - "OMAT24": Use OMAT24 references from https://huggingface.co/facebook/UMA + - "MP20": Use MP-20 references from matbench-discovery + Default is None. **calc_kwargs Custom kwargs for the underlying calculator. Set a value to `quacc.Remove` to remove a pre-existing key entirely. For a list of @@ -64,7 +68,7 @@ def static_job( calc = pick_calculator( method, use_formation_energy=use_formation_energy, - formation_energy_kwargs=formation_energy_kwargs, + references=references, **calc_kwargs, ) final_atoms = Runner(atoms, calc).run_calc() @@ -81,7 +85,7 @@ def relax_job( opt_params: OptParams | None = None, additional_fields: dict[str, Any] | None = None, use_formation_energy: bool = False, - formation_energy_kwargs: Any = None, + references: Literal["MP20", "OMAT24"] | None = None, **calc_kwargs, ) -> OptSchema: """ @@ -102,12 +106,16 @@ def relax_job( Additional fields to add to the results dictionary. use_formation_energy If True, wrap the calculator with FormationEnergyCalculator to compute - formation energies. Currently only supported for FAIRChem with - method='fairchem' and task_name='omat'. Default is False. The formation - energy is returned in eV per formula unit (not eV/atom). - formation_energy_kwargs - Custom kwargs for the FormationEnergyCalculator wrapper. Only used if - use_formation_energy=True. Default is None. + formation energies. Requires fairchem-core package to be installed. + Supported for all methods. Default is False. The formation energy is + returned in eV per formula unit (not eV/atom). + references + Formation energy references to use. Only used if use_formation_energy=True. + Options: + - None: Use built-in references from FormationEnergyCalculator (FAIRChem models only) + - "OMAT24": Use OMAT24 references from https://huggingface.co/facebook/UMA + - "MP20": Use MP-20 references from matbench-discovery + Default is None. **calc_kwargs Custom kwargs for the underlying calculator. Set a value to `quacc.Remove` to remove a pre-existing key entirely. For a list of @@ -129,7 +137,7 @@ def relax_job( calc = pick_calculator( method, use_formation_energy=use_formation_energy, - formation_energy_kwargs=formation_energy_kwargs, + references=references, **calc_kwargs, ) diff --git a/src/quacc/recipes/mlp/references/2023-02-07-mp-elemental-reference-entries.json.gz b/src/quacc/recipes/mlp/references/2023-02-07-mp-elemental-reference-entries.json.gz new file mode 100644 index 0000000000000000000000000000000000000000..d5e66456ad763b01fca099ac4c79b1cc642d83f1 GIT binary patch literal 6350 zcmV;<7%}G`iwFoUicn+%|1vN#Gc7PREif@HZE!7RY-Md_ZggR6EplaMWpZV1V`VL6 zZgg^KWpgfSb8l_{++Asp97l5fD+7ME8j%r^SAR0e)rzttnh?2%WEe!A)Y!Z*hhbr{ z|J{)_^g`>+Xrl;J0W5+fq;Ae+Pv`51cyYY=%cJM*(fR5xkN)}HhkE~Z{P&}ezrIU% zZ`%7)dw+Lx-EL3i!@H07cdfqu^uVKkx_b1_|m@c1qtZm;jIKfGT~Tn;uT`0Mf|^0lw8>-c`}K7M74`Rbc*9zR;X>gJ}E zPhTY8?{DA#b*leye|z^%zI42dfBScN=3}}^@7i6vxm`~9_~EXko7cA=TN(f2FOTl7 z@7mjO#*4>)di~}zl`TCg5%a2`H-XhoiN(uJa<#jhYpYa zI*z-(ef`UQ&go{n?w;P>w(%_<9$8+&_|b0e-@m^5_2cr+j-P2fR?}TtUg3xTzOL^g#ZC@0=VU19dLFZz&*aX0{^mo#PNpz z25`*>?8RY3hm*s(o>>V=dfU~Os_xtVe z=1-1oluv=5;BB1CQ@B}Y5NIE~{r(EVg*gqVu@xdS-Au}1mKJp<923u@(dc_Lf{9l`Hz5mxTe*Bd-`@u3rCTS z+!jv&w;Vh-aRwmBSMCG3r*Gw*e+c5}?=a_Mju8y7D?$#MHO1B-@ee!3)bf{47QkH@ zceXTwIOjE=8ocMI_|#7qwHkEKx5J!sCl^h$afZ3)s*%H;4BYd%fy*)U3g~NXxj0{w z#Ld(@mXvb=J_6Xh{N#g~GK)tg4QzoxsOvm{Q;pklukdiMQ@nk*Ep4#Qil?ln-IzB$7j6t90UERTCGGo>oKd<%TAX z@H8x3w0=e%#>w(8rkL9aZ!hOWRm@tb*@YOZgwHY9m~(3!j7dUNL-^9+#$OR!p|Dgs zlQ|hRh$HrlN;WQj2GlmRZej;co+N%_JF{=g(N0PKdXTH#32<#5!#OXo0&ca%Zj3>X zu2D7Qq7SW)(&LwoIE+`uoXMSsQD7~GBMjauKK0?A)~Dk(%>TMm9LYEnXPA4cX5-Kr zXTx6ExOo&;AO+)m%~@iC-ZX<9*79z?wvFY|rdZ!sm2;N#_DCZ)w!LK_^kc%#SVyD~%!&5$sgBf019; zIo?7dHYEmepQe~n8l|3rOzmikbN&#eRB@x^&lXSKl3KgB5~3-dbX?SG$o=@srj=C3 z1b*5pvK%~zoZl&I`|%gmz%l0>osjB~TM&=Zed-uvA_v`Y%*UwR^I3kQ6)$8}z)^?_ zJ@xCOwikCB)=*F2jUl30+_oI7H2C69nUEKE^Tx~Mgq0>LF2z!iil>;p!|uekb>?^s zXNqTZXQUHDP@O3|BBf@J>vJJLZb%VKI9Uz>EX>-;my_pdH`>XU{&=t2HXqS;uAE%} z@2ZE?eW}5<)?-eg9)>79^oTgn@y4hS1zPV(U)1g3W>2GSFPJcED_;&)g14P4TU9I@ zMcXnZ@twV-Y_O+7KtdA?yY9>ZyeNz^7KmkYCu@!FMJ+3q4cB*D-fo%WJ+Wm^9Netk z@>aEL9Czxqyd8E8Vb=RLLbj4hGXY}DEprvDRc`>f*xV0b*Wyy77rSOo!Rw^dhN$#g za$H}a->@enoB{1gMvqkZwCj_T9l9P#6aiimw)K<%9|4*ysi!D8>KamX{EY5H@ zPG(dlPh;@76ZBpv{p7U}jm}qZ$fV}H4>m!msbpuW?T6?mhb7T#=@hi0f*BJGTB7&U zJ*_WcZ5YSJ6PbYIri9_K9>(o8?ELsvN%R6(a$GZ_O=!7UEIE;d-0GOyI>wNe6=HfG z{s_Hl>f|ZSuda8iKYZBGufj<_9nON8VQo2COT6vG8GiV%<);z}vBFQydZ@k3Tk_MTs|9%m zwdG*Njqb%5US@@#1}t(Z8L$^i%q`elO@!EEYt;NvW>3x1>=uJwFhwR!I|SCpY)>}C z8l21s2n^1~6~@6v0h@!Jreo$BZX?-2**_$mVu^f0Be`dMf zHy83B!e#AV>E1VUycO(5ylJ+*TB7KRW%Q106OJ8+Cpkp$HYy7v3E~{OM)yH zwVo}!-;yni`(c_|xaD9CbUS&}m+#dEOO6?Qw9Rnnk&K}PF1H{E9ODY+7!8&%dP0nq z*hkO1F={jQa!u=Vyl=LI8%OV4fN75RjhY9EA$r;gagP-fC&=na$Z4Za$r_@Nib*A! z)S?SH9>t~7V&Z~xIzRQOr$f+*lrfg3Vis|E6k+rOdSr53E*V~q|EI8o} zv+lR$WCd`0wZFZ*o_AVi%+QkSEjw(!)P@k95j#yebl>eU>R2??8wgMSf?$XNO})W< zM)TW`Th7z-DTD}C2*a$0T+fK?q`&oJQ;l1Uz+*zz)R0}Sz9n|XcAv$o+Wf^k(~*M1 zp`|)1r{;pRB{Nv9Y`R{3>*v#K%tlRI~rfColAUt7~2KoJk z#U6Xfk?U^MMn`Onr3i$r5XR7x^EMF{+Z-PwYFn!GqsR3^)D~Es{98{EzDb*EHW-5x z7-p7jIa!g1=Dq4zNlWq?3e-AgpSd(>##1!Km70a6c59BH*Bgh#8{-)KV>Rg(ar886 zBw>9a>)ZF+b|6`>#?0Wh987v43wDCsw`yG>2XAdit#d|5R5$^I-l+xet$}ciT0M?h z{g6O%TIrO*1j83fS+E{@JdPL%SlKS(8t`$$# zL=GOuU5;Nf;gFy8vc2Ca=lk<}rG`qjHS|4jpyVj$p{QDrTq35_=@=C)@vVFYI)w|< z_2OO?uIksPeE+y*3*6W|A|`H@g!`lJRCyDJoBc@!KA5?(_Z@buik3R!3P_m=NwZl z=mcW5CoCf|RT^64W2`+DxWO`c8bBqh(FCvRk<({L;6Hx5smvf->us1dUM>e~*|?oD zzmMNu0Z&!9dX0%)YQ+TGIaN%hLdlI&!=8?@;vIK9|;_Ag&e?%99%iW2inlQ(z1@90PBQq$T(h>H8LVEKF9S zxAmI=pR_F(D+i#b4|T1v zYheiXT6G?9)U}*~wy?g}@`o+ysk0{+12gfAE-Xi5oEPnPGH*Y;R|!K;7JG=MBXQ-J zGnZ^ziY}!VdJG4sHP|u3FsfTY0&UH_1;b<%a!TvX+p7(`-;ho&IQ9-^1E|Z%r0m+= zi7Z^}{1md7v=nl8p&MwvnZ_QC@CPGNaymvXgvU-m4kr_Ii6jKC?%G`sw=FBFX6!_l zUvB13my?wl)Sa@uuaXM4k#K6sXP=`|Y&v4Th6*;qENl>lMvbTF`bn^Jk2mmQ7od#I8jO(+;?o`k0K{YwWZvymV&TYN6n| z1)C}qCJHurZ>+%1n3BCcN)|-mwltEP3t13N%fWrd4B*EtlWsID87sA zE6e-*qtZ(bB@<+CM*g#K%|n>eLdFnM;Zl!*H@DQWqQuwwmeXK@rolC?X?>-B%f=9Y zf@m2G&Vsj}^~MlGJJqp%zEMqFz$9i(98K@Gg|XwbiiH#7Mh%SdaDcs&i-hWEyc`mP znu#M2%b`JzK10)T!_j(k5`v9GjC7WHdvk#-Sm*cZy8Pjd+E|`*>8=M;3UMNk_-*uv zR3~oL2hZ*h{M3v}){x{vl`#al59c+;u-w)9oXDmdv)G8S2vZ5TSNg`Rz2qET{h~0` zmVE4lDDQYru0!t0bIF*Db+z>#g#k7aW9nRy4_STEwHVn6pNKPI3J@iUBa5=kI`!>fE(MK zzkf}NiDO6+a?0Hxat@*Q z%!P)cYN4T@e2{)5#+C|ST4@hDYPhBJ=UQ*@z<8uYD3~ z)aF?!bvap|5;SHfTle*iYU@&RHUS6k$@m@$6&tD6mQqOO5L-WR8J@vqbK8Z1Est6+ z+^4Hr&mlf+i#zs+&dn+h%gLI(b9SfcnP(p~)XA$3LV+A7^ey<(eG{y4t@1Es?~g%R z2V*_MSXZmo9w^FfXx8r2Q>!rd)3!5$F|r|=g>*k%vUFlxcEa3GY9A&yL6>jdmuk72 z7_vlma$D#_mW!yz7)^C&6PKtCE=c=EwafhBs#c@!cfW3G=_Do~2Xzlf1!oh7%%<7yyo;5L{OQO(MocZo!*rVuf!~vj=2Ex0R(`PS zW4W*IH=U(H>j9?QQNPv&;vACQYemD?_j9~ecCG@uR6CmPB&KUEI(hdwwuCMmMZKJP zL~_oA3)wkM+xOv~);BI~nJy6y?wOeFES6hNo@*(&z0_7mrerSnU>kZKN~dIZash3wRfQxPP zDv3``!8(U7w~CQN66j@#z z3oqVnYux~ii#QXCd!hT%;D!k3PKZstt8?XIsN{|^CLaM2o*kHAoQ_)}O` z-214{r=qvoeCl#j>lM4&wliGL86Rkt0Dh{EL|f-+FV^t1U16Bj^gAVp9ea-wIfrIL z%bcSGG%XI*hhu0H!pQ#`e_EsowX6t4R`&?fXRwB!HWYo}L4X2sm?nVpMOhoSQ~UW( z`88?@U$dSMK6p$7+y6h_K=k&V@4$dFBg-E;zNk zDi^iBcPVWe(Rm)-Ry^Ym-|I@#xRYAk_emkEnTi;*YC`t~G6Q>)g7L%8W(@u)&A=l8 zEH~iPeGTM;QQ(H_rQ#lMIYQ`8{GjPMmF+4Q*aNz~)K?!DrMl%LWD~Ad5gOIzn(Nh5 zj26g8Q_W%aLC2v@$ohT|0Izk*J>1iJ?vS>WUj?u>#xN^5Jkg zOhdy3SnQn{O3N*nObLBGQoSLbEa~+I<5^kI`x&j7-*3CQ5}kKp)_eZ__2tIQy@po5 zzaEG~ny3?w(rHc^45vV&Nz+-b-b7bID?J^9HjM0V7(O*dP!R{}3i3~PwZ6o4vnA}n z1>t)%T_||5G1G(JsqXMnA6_lVC2_8?7=#2Vd$2;iDY)caZ}tEyLx%3J1afeOOM_jJ zEEv-JmE^A07rQoet`ttA+)HF_%i>LiR%1-Tn-Wbg zCACQG<>l4zg^S)i2*|msA;xNlKBiSWY3R^3Rx+|_!K%q?l_DVh-de22r63pz`C zjXylm<2s4uyu>t=dUjNMYQz+~e2}>34qv#+`ESPxg6C zJ#^PbVSlznHD3@T@zL3xQYAlT6>}-oVw|(+ZN}70YZYqgkxH~ZmSgm0j*eh6W*<;p zrAl5tYwDip<*rua?z>Ipy%2o#VQS~b!DEh68M%qD6Nmjyg&noPJ-AA~BTy0p2c|Z1 zsaxOp5F32|cht*9Ar#ua8sZS~!pO}hZR&MdliLk(XHU_|ui&SfRc|i_c5L?A>+tk; zLEfxc6QawOlCJ?u0SGR7mrG3+swLO#0rYU 0 + print(f"\nLoaded MP-20 references for {len(mp20_refs)} elements") + + # Test with FAIRChem using MP20 references (explicit) + if "fairchem" in methods: + _set_dtype(32) + atoms = bulk("Cu") + + # Test with MP-20 references + output = static_job( + atoms, + method="fairchem", + name_or_path="uma-s-1", + task_name="omat", + use_formation_energy=True, + references="MP20", + ) + energy_per_atom = output["results"]["energy"] / len(atoms) + print(f"Cu formation energy with MP-20 refs: {energy_per_atom:.4f} eV/atom") + + # Should be close to zero for elemental system + assert abs(energy_per_atom) < 0.2 + + except Exception as e: + pytest.skip(f"Could not load MP-20 references: {e}") From f0053a13220189eaebc21e740890a9b6b936b5b7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Jan 2026 20:32:01 +0000 Subject: [PATCH 05/10] pre-commit auto-fixes --- src/quacc/recipes/mlp/_base.py | 46 ++++++----- .../recipes/mlp_recipes/test_core_recipes.py | 76 +++++++------------ 2 files changed, 53 insertions(+), 69 deletions(-) diff --git a/src/quacc/recipes/mlp/_base.py b/src/quacc/recipes/mlp/_base.py index f91467ee27..661d1af545 100644 --- a/src/quacc/recipes/mlp/_base.py +++ b/src/quacc/recipes/mlp/_base.py @@ -5,7 +5,7 @@ from functools import lru_cache, wraps from importlib.util import find_spec from logging import getLogger -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from ase.units import GPa as _GPa_to_eV_per_A3 from monty.dev import requires @@ -55,35 +55,35 @@ def wrapped(*args, **kwargs): def _get_omat24_references() -> dict[str, float]: """ Fetch formation energy references for OMAT24-trained models from HuggingFace. - + These references come from https://huggingface.co/facebook/UMA/blob/main/references/form_elem_refs.yaml - + Returns ------- dict[str, float] Dictionary mapping element symbols to reference energies (eV/atom). """ - from huggingface_hub import hf_hub_download import yaml + from huggingface_hub import hf_hub_download LOGGER.info("Downloading OMAT24 formation energy references from HuggingFace...") - + # Download the form_elem_refs.yaml file from HuggingFace refs_file = hf_hub_download( repo_id="facebook/UMA", filename="references/form_elem_refs.yaml", - repo_type="model" + repo_type="model", ) - + # Load and extract the omat references with open(refs_file) as f: refs_data = yaml.safe_load(f) - + omat_refs = refs_data.get("refs", {}).get("omat", {}) - + if not omat_refs: raise ValueError("Could not find 'refs.omat' in the downloaded reference file.") - + LOGGER.info(f"Loaded OMAT24 references for {len(omat_refs)} elements.") return omat_refs @@ -92,10 +92,10 @@ def _get_omat24_references() -> dict[str, float]: def _get_mp20_references() -> dict[str, float]: """ Load formation energy references for MP-20 compatible models. - + These references come from matbench-discovery repository: https://github.com/janosh/matbench-discovery - + Returns ------- dict[str, float] @@ -106,27 +106,33 @@ def _get_mp20_references() -> dict[str, float]: from pathlib import Path LOGGER.info("Loading MP-20 formation energy references from local file...") - + # Load from local gzipped JSON file - refs_file = Path(__file__).parent / "references" / "2023-02-07-mp-elemental-reference-entries.json.gz" - + refs_file = ( + Path(__file__).parent + / "references" + / "2023-02-07-mp-elemental-reference-entries.json.gz" + ) + if not refs_file.exists(): raise FileNotFoundError( f"MP-20 reference file not found at {refs_file}. " "Please ensure the file is in src/quacc/recipes/mlp/references/" ) - + # Load the gzipped JSON file with gzip.open(refs_file, "rt") as f: refs_data = json.load(f) - + # Extract element references based on the expected structure # The file should contain element references if isinstance(refs_data, dict): mp20_refs = refs_data else: - raise ValueError(f"Unexpected format in MP-20 reference file: {type(refs_data)}") - + raise ValueError( + f"Unexpected format in MP-20 reference file: {type(refs_data)}" + ) + LOGGER.info(f"Loaded MP-20 references for {len(mp20_refs)} elements.") return mp20_refs @@ -238,7 +244,7 @@ def pick_calculator( # Determine which reference energies to use fe_kwargs = {} - + if references == "OMAT24": # Use OMAT24 references from HuggingFace fe_kwargs["references"] = _get_omat24_references() diff --git a/tests/core/recipes/mlp_recipes/test_core_recipes.py b/tests/core/recipes/mlp_recipes/test_core_recipes.py index 0842e7d2de..ddf8db7dcc 100644 --- a/tests/core/recipes/mlp_recipes/test_core_recipes.py +++ b/tests/core/recipes/mlp_recipes/test_core_recipes.py @@ -319,12 +319,9 @@ def test_formation_energy_with_mace(tmp_path, monkeypatch): # Formation energy should work with mace-mp when references are provided atoms = bulk("Cu") # Use OMAT24 references for testing - + output = static_job( - atoms, - method="mace-mp", - use_formation_energy=True, - references="OMAT24", + atoms, method="mace-mp", use_formation_energy=True, references="OMAT24" ) # Should complete successfully assert "energy" in output["results"] @@ -369,7 +366,6 @@ def test_builtin_vs_omat24_references(tmp_path, monkeypatch): f"Built-in references ({energy_builtin:.6f} eV) differ from OMAT24 " f"references ({energy_omat24:.6f} eV) by {abs(energy_builtin - energy_omat24):.6e} eV" ) - print(f"\nBuilt-in: {energy_builtin:.6f} eV, OMAT24: {energy_omat24:.6f} eV") @pytest.mark.skipif(find_spec("fairchem") is None, reason="fairchem not installed") @@ -400,31 +396,31 @@ def test_static_job_formation_energy_with_references(tmp_path, monkeypatch): @pytest.mark.skipif(find_spec("fairchem") is None, reason="fairchem not installed") def test_formation_energy_consistency_across_models(tmp_path, monkeypatch): """Test that formation energies are consistent across different models. - + This test ensures that when different ML potentials calculate formation - energies with appropriate references, they yield similar results + energies with appropriate references, they yield similar results (within ~0.1 eV/atom). This helps debug issues with formation energy references and MP corrections. """ monkeypatch.chdir(tmp_path) - + from huggingface_hub.utils._auth import get_token if not get_token(): pytest.skip("HuggingFace token not available for FAIRChem") - + # Get OMAT24 references for non-FAIRChem models # (FAIRChem will use built-in references with None) - + # Test elemental system (Cu) - formation energy should be near zero test_structures = [ ("Cu", bulk("Cu"), 0.0, 0.2), # (name, structure, expected_per_atom, tolerance) ("MgO", bulk("MgO", crystalstructure="rocksalt", a=4.2), -3.0, 0.5), ] - + for struct_name, atoms, expected_per_atom, tolerance in test_structures: results = {} - + # Test with FAIRChem (uses built-in references with None) if "fairchem" in methods: _set_dtype(32) @@ -438,21 +434,16 @@ def test_formation_energy_consistency_across_models(tmp_path, monkeypatch): ) energy_per_atom = output["results"]["energy"] / len(atoms) results["fairchem"] = energy_per_atom - print(f"\n{struct_name} - fairchem: {energy_per_atom:.4f} eV/atom") - + # Test with MACE (using OMAT24 references) if "mace-mp" in methods: _set_dtype(64) output = static_job( - atoms, - method="mace-mp", - use_formation_energy=True, - references="OMAT24", + atoms, method="mace-mp", use_formation_energy=True, references="OMAT24" ) energy_per_atom = output["results"]["energy"] / len(atoms) results["mace-mp"] = energy_per_atom - print(f"{struct_name} - mace-mp: {energy_per_atom:.4f} eV/atom") - + # Test with TensorNet (using OMAT24 references) if "tensornet" in methods: _set_dtype(32) @@ -464,53 +455,42 @@ def test_formation_energy_consistency_across_models(tmp_path, monkeypatch): ) energy_per_atom = output["results"]["energy"] / len(atoms) results["tensornet"] = energy_per_atom - print(f"{struct_name} - tensornet: {energy_per_atom:.4f} eV/atom") - + # Test with SevenNet (using OMAT24 references) if "sevennet" in methods: _set_dtype(32) output = static_job( - atoms, - method="sevennet", - use_formation_energy=True, - references="OMAT24", + atoms, method="sevennet", use_formation_energy=True, references="OMAT24" ) energy_per_atom = output["results"]["energy"] / len(atoms) results["sevennet"] = energy_per_atom - print(f"{struct_name} - sevennet: {energy_per_atom:.4f} eV/atom") - + # Test with ORB (using OMAT24 references) if "orb" in methods: _set_dtype(32) output = static_job( - atoms, - method="orb", - use_formation_energy=True, - references="OMAT24", + atoms, method="orb", use_formation_energy=True, references="OMAT24" ) energy_per_atom = output["results"]["energy"] / len(atoms) results["orb"] = energy_per_atom - print(f"{struct_name} - orb: {energy_per_atom:.4f} eV/atom") - + # Verify we have at least 2 results to compare if len(results) < 2: pytest.skip(f"Not enough models available to compare for {struct_name}") - + # Check that all results are within expected range for method_name, energy in results.items(): assert abs(energy - expected_per_atom) < tolerance, ( f"{struct_name} - {method_name}: {energy:.4f} eV/atom is outside " f"expected range {expected_per_atom} +/- {tolerance} eV/atom" ) - + # Check consistency across models (within 0.1 eV/atom) energies_list = list(results.values()) max_energy = max(energies_list) min_energy = min(energies_list) spread = max_energy - min_energy - - print(f"{struct_name} - Energy spread: {spread:.4f} eV/atom (min: {min_energy:.4f}, max: {max_energy:.4f})") - + # Allow slightly larger spread for compounds than elements max_spread = 0.15 if "Cu" in struct_name else 0.2 assert spread < max_spread, ( @@ -523,26 +503,25 @@ def test_formation_energy_consistency_across_models(tmp_path, monkeypatch): def test_formation_energy_mp20_references(tmp_path, monkeypatch): """Test that MP-20 references can be loaded and used.""" monkeypatch.chdir(tmp_path) - + from huggingface_hub.utils._auth import get_token if not get_token(): pytest.skip("HuggingFace token not available for FAIRChem") - + # Test loading MP-20 references from quacc.recipes.mlp._base import _get_mp20_references - + try: mp20_refs = _get_mp20_references() assert isinstance(mp20_refs, dict) assert len(mp20_refs) > 0 - print(f"\nLoaded MP-20 references for {len(mp20_refs)} elements") - + # Test with FAIRChem using MP20 references (explicit) if "fairchem" in methods: _set_dtype(32) atoms = bulk("Cu") - + # Test with MP-20 references output = static_job( atoms, @@ -553,10 +532,9 @@ def test_formation_energy_mp20_references(tmp_path, monkeypatch): references="MP20", ) energy_per_atom = output["results"]["energy"] / len(atoms) - print(f"Cu formation energy with MP-20 refs: {energy_per_atom:.4f} eV/atom") - + # Should be close to zero for elemental system assert abs(energy_per_atom) < 0.2 - + except Exception as e: pytest.skip(f"Could not load MP-20 references: {e}") From 99208ff166f7249319e85788741f4a9550767db8 Mon Sep 17 00:00:00 2001 From: Zack Ulissi Date: Wed, 28 Jan 2026 20:42:42 +0000 Subject: [PATCH 06/10] update orb expect values from test runs in quacc repo --- tests/core/recipes/mlp_recipes/test_core_recipes.py | 6 +++--- tests/core/recipes/mlp_recipes/test_elastic_recipes.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/core/recipes/mlp_recipes/test_core_recipes.py b/tests/core/recipes/mlp_recipes/test_core_recipes.py index 0842e7d2de..e37fbd0c0d 100644 --- a/tests/core/recipes/mlp_recipes/test_core_recipes.py +++ b/tests/core/recipes/mlp_recipes/test_core_recipes.py @@ -56,7 +56,7 @@ def test_static_job(tmp_path, monkeypatch, method): "tensornet": -3.7593491077423096, "mace-mp": -4.097862720291976, "sevennet": -4.096191883087158, - "orb": -4.093477725982666, + "orb": -3.7420763969421387, "fairchem": -3.7579006783217954, } atoms = bulk("Cu") @@ -95,7 +95,7 @@ def test_relax_job(tmp_path, monkeypatch, method): "mace-mp": -32.78264569638644, "tensornet": -30.074462890625, "sevennet": -32.76924133300781, - "orb": -32.7361946105957, + "orb": -29.93630599975586, "fairchem": -30.004380887389797, } @@ -143,7 +143,7 @@ def test_relax_cell_job(tmp_path, monkeypatch, method): "mace-mp": -32.8069374165035, "tensornet": -30.079431533813477, "sevennet": -32.76963806152344, - "orb": -32.73428726196289, + "orb": -29.93630599975586, "fairchem": -30.005004590392726, } diff --git a/tests/core/recipes/mlp_recipes/test_elastic_recipes.py b/tests/core/recipes/mlp_recipes/test_elastic_recipes.py index 18ba46be05..d85b5639a0 100644 --- a/tests/core/recipes/mlp_recipes/test_elastic_recipes.py +++ b/tests/core/recipes/mlp_recipes/test_elastic_recipes.py @@ -56,7 +56,7 @@ def test_elastic_jobs(tmp_path, monkeypatch, method): "tensornet": 138.172, "mace-mp": 130.727, "sevennet": 142.296, - "orb": 190.195, + "orb": 135.25, "fairchem": 151.367, } From a5607d2c4c01d48e36f5c8ee45b39931a6137063 Mon Sep 17 00:00:00 2001 From: Zack Ulissi Date: Wed, 28 Jan 2026 20:45:16 +0000 Subject: [PATCH 07/10] pre-commit checks --- src/quacc/recipes/mlp/_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/quacc/recipes/mlp/_base.py b/src/quacc/recipes/mlp/_base.py index 661d1af545..3a370f27c1 100644 --- a/src/quacc/recipes/mlp/_base.py +++ b/src/quacc/recipes/mlp/_base.py @@ -2,6 +2,8 @@ from __future__ import annotations +import Path + from functools import lru_cache, wraps from importlib.util import find_spec from logging import getLogger @@ -76,7 +78,7 @@ def _get_omat24_references() -> dict[str, float]: ) # Load and extract the omat references - with open(refs_file) as f: + with Path.open(refs_file) as f: refs_data = yaml.safe_load(f) omat_refs = refs_data.get("refs", {}).get("omat", {}) From 39683beeefc66c592d8d3211d626d53f15f37003 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Jan 2026 20:45:27 +0000 Subject: [PATCH 08/10] pre-commit auto-fixes --- src/quacc/recipes/mlp/_base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/quacc/recipes/mlp/_base.py b/src/quacc/recipes/mlp/_base.py index 3a370f27c1..777921c6e3 100644 --- a/src/quacc/recipes/mlp/_base.py +++ b/src/quacc/recipes/mlp/_base.py @@ -2,13 +2,12 @@ from __future__ import annotations -import Path - from functools import lru_cache, wraps from importlib.util import find_spec from logging import getLogger from typing import TYPE_CHECKING +import Path from ase.units import GPa as _GPa_to_eV_per_A3 from monty.dev import requires From 09a434038853c76ab7c3cd7d7d9906f36254d578 Mon Sep 17 00:00:00 2001 From: Zack Ulissi Date: Wed, 28 Jan 2026 20:45:43 +0000 Subject: [PATCH 09/10] pre-commit --- src/quacc/recipes/mlp/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/quacc/recipes/mlp/_base.py b/src/quacc/recipes/mlp/_base.py index 3a370f27c1..771fc938fa 100644 --- a/src/quacc/recipes/mlp/_base.py +++ b/src/quacc/recipes/mlp/_base.py @@ -2,7 +2,7 @@ from __future__ import annotations -import Path +from pathlib import Path from functools import lru_cache, wraps from importlib.util import find_spec From 64c104e26afebe7fc37cda616da1ec546bd9c6a1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Jan 2026 20:46:54 +0000 Subject: [PATCH 10/10] pre-commit auto-fixes --- src/quacc/recipes/mlp/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/quacc/recipes/mlp/_base.py b/src/quacc/recipes/mlp/_base.py index 2d80582246..990d95cc11 100644 --- a/src/quacc/recipes/mlp/_base.py +++ b/src/quacc/recipes/mlp/_base.py @@ -5,9 +5,9 @@ from functools import lru_cache, wraps from importlib.util import find_spec from logging import getLogger +from pathlib import Path from typing import TYPE_CHECKING -from pathlib import Path from ase.units import GPa as _GPa_to_eV_per_A3 from monty.dev import requires