Skip to content

Commit a686eda

Browse files
committed
BaseOutput: Add SubMapping for nested output namespaces
A code's output namespace can grow to dozens or even hundreds of fields once every quantity it produces is exposed. Flattening all of them onto a single top-level mapping floods autocomplete, makes the API hard to skim, and forces unrelated quantities (forces, bands, occupations, magnetization, ...) to share one namespace with no grouping. Related fields belong together under a named sub-namespace so users can discover them by category instead of scrolling through one giant list. `SubMapping` is a sentinel default, paired with the existing `@output_mapping` decorator, that lets a mapping class declare a nested mapping as a regular field: @output_mapping class _PwMapping: magnetization: _MagnetizationMapping The decorator walks `get_type_hints` and injects a `SubMapping(mapping_cls=...)` default for any bare annotation whose type is itself an `@output_mapping` class. At runtime, `BaseOutput.outputs` recurses into each `SubMapping` field and builds the nested mapping from the matching slice of the parsed output dict, so `out.outputs.magnetization.total` works exactly like a top-level field — same lazy resolution, same `AttributeError` on missing data, same frozen-dataclass features. The constructor of `BaseOutput` now creates a (potentially) 2 level dictionary that still maps the output names to their corresponding `Spec` and stores that in `_output_spec_mapping`. The guard against non-`Spec` fields is still present for the top level and sub mapping. The `get_output` method is also updated. Currently, only top-level outputs are possible, and these return a dictionary with all the available outputs in the sub mapping. Conversion is also still possible for nested outputs via dot-separation, e.g. to convert the "total" sub output in "magnetization" you have to define a converter for "magnetization.total".
1 parent 2629559 commit a686eda

File tree

2 files changed

+179
-15
lines changed

2 files changed

+179
-15
lines changed

src/qe_tools/outputs/base.py

Lines changed: 93 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import abc
6+
import contextlib
67
import dataclasses
78
import typing
89
from functools import cached_property
@@ -16,34 +17,73 @@
1617
T = typing.TypeVar("T")
1718

1819

20+
class SubMapping:
21+
"""Sentinel marking a field as a nested output mapping.
22+
23+
`BaseOutput` resolves these at instantiation time. Nesting is intended to
24+
be one level only: a sub-mapping class should only contain `Spec` fields.
25+
"""
26+
27+
def __init__(self, mapping_cls: type):
28+
self.mapping_cls = mapping_cls
29+
30+
1931
def output_mapping(cls):
2032
"""Decorator that defines a typed, frozen output mapping for a Quantum ESPRESSO code.
2133
2234
Applies `@dataclass(frozen=True)` and injects `__getattribute__` and `__dir__` so that:
2335
24-
- Accessing a field whose value is still a `Spec` raises `AttributeError` with a clear
25-
message (i.e. the output was not parsed).
36+
- Accessing a field whose value is still a `Spec` or `SubMapping` raises
37+
`AttributeError` with a clear message (i.e. the output was not parsed).
2638
- `dir()` only lists fields that were successfully extracted.
2739
28-
Each field must declare a `Spec(...)` as its default value:
40+
Output fields declare a `Spec(...)` default. Sub-namespace fields are
41+
declared with a bare annotation whose type is another `@output_mapping`
42+
class — the decorator auto-injects a `SubMapping(hint)` default:
2943
3044
fermi_energy: float = Spec("path.to.fermi_energy")
3145
\"""Fermi energy in eV.\"""
46+
magnetization: _MagnetizationMapping
47+
\"""Nested magnetization outputs.\"""
3248
"""
3349

3450
def __getattribute__(self, name):
3551
value = object.__getattribute__(self, name)
36-
if isinstance(value, Spec):
52+
if isinstance(value, (Spec, SubMapping)):
3753
raise AttributeError(f"'{name}' is not available in the parsed outputs.")
3854
return value
3955

4056
def __dir__(self):
4157
return [
42-
name for name, value in self.__dict__.items() if not isinstance(value, Spec)
58+
name
59+
for name, value in self.__dict__.items()
60+
if not isinstance(value, (Spec, SubMapping))
4361
]
4462

4563
cls.__getattribute__ = __getattribute__
4664
cls.__dir__ = __dir__
65+
66+
# Inject `SubMapping(hint)` defaults for bare annotations whose type is
67+
# itself an `@output_mapping`-decorated class. Note: `get_type_hints`
68+
# evaluates annotations, which is fine as long as the module does not use
69+
# `from __future__ import annotations` together with `TYPE_CHECKING`-only
70+
# sub-mapping imports — in that case the eval would raise `NameError` here.
71+
# Sub-mapping classes must be defined *before* the parent that references
72+
# them (Python enforces this anyway for non-future-annotations modules).
73+
for name, hint in typing.get_type_hints(cls).items():
74+
if hasattr(cls, name): # already has a default
75+
continue
76+
77+
if not (isinstance(hint, type) and getattr(hint, "_is_output_mapping", False)):
78+
raise TypeError(
79+
f"{cls.__name__}.{name}: needs a Spec(...) default, or a bare "
80+
f"annotation whose type is an @output_mapping class "
81+
f"(which must be defined before this class)"
82+
)
83+
84+
setattr(cls, name, SubMapping(hint))
85+
86+
cls._is_output_mapping = True
4787
return dataclasses.dataclass(frozen=True)(cls)
4888

4989

@@ -68,15 +108,25 @@ def _get_mapping_class(cls) -> type:
68108

69109
def __init__(self, raw_outputs: dict):
70110
self.raw_outputs = raw_outputs
71-
self._output_spec_mapping = {}
72111

73-
for field in dataclasses.fields(self._get_mapping_class()):
74-
if not isinstance(field.default, Spec):
75-
raise TypeError(
76-
f"{type(self).__name__}.{field.name}: expected a Spec(...) default, "
77-
f"got {field.default!r}"
78-
)
79-
self._output_spec_mapping[field.name] = field.default
112+
def build(mapping_cls: type) -> dict:
113+
"""Build the nested spec dict from a mapping class."""
114+
result: dict = {}
115+
116+
for field in dataclasses.fields(mapping_cls):
117+
if isinstance(field.default, SubMapping):
118+
result[field.name] = build(field.default.mapping_cls)
119+
elif isinstance(field.default, Spec):
120+
result[field.name] = field.default
121+
else:
122+
raise TypeError(
123+
f"{mapping_cls.__name__}.{field.name}: expected a Spec(...) or "
124+
f"SubMapping(...) default, got {field.default!r}"
125+
)
126+
127+
return result
128+
129+
self._output_spec_mapping = build(self._get_mapping_class())
80130

81131
@classmethod
82132
@abc.abstractmethod
@@ -109,7 +159,16 @@ def get_output(
109159
>>> pw_out.get_output(name="structure")
110160
>>> pw_out.get_output(name="structure", to="pymatgen")
111161
"""
112-
output_data = glom(self.raw_outputs, self._output_spec_mapping[name])
162+
entry = self._output_spec_mapping[name]
163+
164+
if isinstance(entry, dict):
165+
output_data: typing.Any = {}
166+
167+
for sub_name, sub_spec in entry.items():
168+
with contextlib.suppress(GlomError):
169+
output_data[sub_name] = glom(self.raw_outputs, sub_spec)
170+
else:
171+
output_data = glom(self.raw_outputs, entry)
113172

114173
if to is None:
115174
return output_data
@@ -126,6 +185,14 @@ def get_output(
126185
else:
127186
raise ValueError(f"Library '{to}' is not supported.")
128187

188+
if isinstance(entry, dict):
189+
return {
190+
sub_name: Converter().convert(f"{name}.{sub_name}", sub_value)
191+
if sub_name in Converter.conversion_mapping
192+
else sub_value
193+
for sub_name, sub_value in output_data.items()
194+
}
195+
129196
return (
130197
Converter().convert(name, output_data)
131198
if name in Converter.conversion_mapping
@@ -180,4 +247,15 @@ def list_outputs(self, only_available: bool = True) -> list[str]:
180247
@cached_property
181248
def outputs(self) -> T:
182249
"""Namespace with available outputs."""
183-
return self._get_mapping_class()(**self.get_output_dict())
250+
251+
def build(mapping_cls: type, data: dict):
252+
defaults = {f.name: f.default for f in dataclasses.fields(mapping_cls)}
253+
kwargs = {
254+
name: build(defaults[name].mapping_cls, value) # type: ignore[union-attr]
255+
if isinstance(defaults[name], SubMapping)
256+
else value
257+
for name, value in data.items()
258+
}
259+
return mapping_cls(**kwargs)
260+
261+
return build(self._get_mapping_class(), self.get_output_dict())

tests/outputs/test_base.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,89 @@ def test_get_output_dict(raw_outputs):
7878
"B",
7979
]
8080
)
81+
82+
83+
# --- SubMapping (nested output namespaces) ----------------------------------
84+
85+
86+
@output_mapping
87+
class _NestedMapping:
88+
c: int = Spec("b.c")
89+
d: int = Spec("b.d")
90+
missing: int = Spec("b.nope")
91+
92+
93+
@output_mapping
94+
class _ParentMapping:
95+
A: float = Spec("a")
96+
nested: _NestedMapping
97+
98+
99+
class _ParentOutput(BaseOutput[_ParentMapping]):
100+
@classmethod
101+
def from_dir(cls, _: str):
102+
pass
103+
104+
105+
def test_submapping_output_access(raw_outputs):
106+
"""Resolved outputs on a sub-namespace are accessible via attribute."""
107+
outputs = _ParentOutput(raw_outputs).outputs
108+
assert outputs.nested.c == 3
109+
assert outputs.nested.d == 4
110+
111+
112+
def test_submapping_missing_output_raises(raw_outputs):
113+
outputs = _ParentOutput(raw_outputs).outputs
114+
with pytest.raises(AttributeError, match="missing.*not available"):
115+
outputs.nested.missing
116+
117+
118+
def test_submapping_list_outputs_top_level_only(raw_outputs):
119+
"""`list_outputs` yields top-level field names only — sub-namespaces as a single entry."""
120+
pw_out = _ParentOutput(raw_outputs)
121+
# Both modes return the same result here: sub-namespaces are *always* listed
122+
# (even if every output is missing), and there are no unresolvable top-level
123+
# outputs in `_ParentMapping` for `only_available=True` to filter out.
124+
assert pw_out.list_outputs() == ["A", "nested"]
125+
assert pw_out.list_outputs(only_available=False) == ["A", "nested"]
126+
127+
128+
def test_submapping_get_output_namespace_returns_dict(raw_outputs):
129+
"""`get_output(<sub-namespace>)` returns a partial dict of available outputs."""
130+
pw_out = _ParentOutput(raw_outputs)
131+
assert pw_out.get_output("nested") == {"c": 3, "d": 4}
132+
# Users index the dict directly.
133+
assert pw_out.get_output("nested")["c"] == 3
134+
135+
136+
def test_submapping_get_output_dict_shape(raw_outputs):
137+
"""`get_output_dict()` is flat at the top level; sub-namespaces are nested dicts."""
138+
assert _ParentOutput(raw_outputs).get_output_dict() == {
139+
"A": 1,
140+
"nested": {"c": 3, "d": 4},
141+
}
142+
143+
144+
def test_validation_rejects_non_spec_top_level_default():
145+
"""Top-level field defaults that aren't `Spec`/`SubMapping` raise in `__init__`."""
146+
147+
@output_mapping
148+
class _BadParent:
149+
bad: int = 42 # type: ignore[assignment]
150+
151+
class _BadOutput(BaseOutput[_BadParent]):
152+
@classmethod
153+
def from_dir(cls, _: str):
154+
pass
155+
156+
with pytest.raises(TypeError, match="_BadParent.bad"):
157+
_BadOutput(raw_outputs={})
158+
159+
160+
def test_decorator_rejects_bare_annotation_non_output_mapping():
161+
"""Bare annotation whose type isn't `@output_mapping`-decorated is rejected at decoration time."""
162+
with pytest.raises(TypeError, match="bad.*@output_mapping class"):
163+
164+
@output_mapping
165+
class _BadBare:
166+
bad: int

0 commit comments

Comments
 (0)