Skip to content

Commit 304bf72

Browse files
committed
Format mixed model formula
1 parent dd836aa commit 304bf72

2 files changed

Lines changed: 33 additions & 10 deletions

File tree

causalpy/formula.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from dataclasses import dataclass
15-
from typing import Any, Callable, Mapping, Optional
16-
14+
import re
1715
from collections import OrderedDict
16+
from collections.abc import Callable, Mapping
1817
from dataclasses import dataclass
19-
from formulaic import Formula
18+
from typing import Any
2019

21-
import pandas as pd
2220
import numpy as np
23-
import re
21+
import pandas as pd
22+
from formulaic import Formula
2423

2524
_RE_RANDOM_COMPONENT = re.compile(r"\(\s*([^|()]+?)\s*\|\s*([^|()]+?)\s*\)")
2625

@@ -77,11 +76,13 @@ def _elements(self) -> list[str]:
7776

7877
@property
7978
def has_intercept(self) -> bool:
79+
"""Whether this component includes a random intercept."""
8080
elements = set(self._elements)
8181
return "0" not in elements and "-1" not in elements
8282

8383
@property
8484
def slopes(self) -> list[str]:
85+
"""Return non-intercept random-slope elements."""
8586
return [
8687
element for element in self._elements if element not in ("0", "-1", "1")
8788
]
@@ -123,22 +124,27 @@ class MixedModelMatrices:
123124

124125
@property
125126
def y(self) -> pd.DataFrame:
127+
"""Alias for ``lhs`` outcome matrix."""
126128
return self.lhs
127129

128130
@property
129131
def X(self) -> pd.DataFrame:
132+
"""Alias for ``rhs`` fixed-effects design matrix."""
130133
return self.rhs
131134

132135
@property
133136
def model_spec(self):
137+
"""Fixed-effects Formulaic model specification used for ``rhs``."""
134138
return self.metadata["model_spec"]
135139

136140
@property
137141
def fixed_model_spec(self):
142+
"""Explicit fixed-effects model specification alias."""
138143
return self.metadata["fixed_model_spec"]
139144

140145
@property
141146
def random_model_spec(self):
147+
"""Random-effects model specification used for ``Z``."""
142148
return self.metadata["random_model_spec"]
143149

144150

@@ -175,10 +181,12 @@ class MixedModelFormula:
175181

176182
@property
177183
def has_random_effects(self) -> bool:
184+
"""Whether the parsed formula contains random components."""
178185
return len(self.random_components) > 0
179186

180187
@property
181188
def grouping_variables(self) -> list[str]:
189+
"""Unique grouping-variable names in declaration order."""
182190
return list(
183191
OrderedDict.fromkeys(
184192
component.grouping for component in self.random_components
@@ -187,10 +195,12 @@ def grouping_variables(self) -> list[str]:
187195

188196
@property
189197
def fixed_formula(self) -> str:
198+
"""Fixed-effects formula string built from ``lhs`` and ``rhs``."""
190199
return f"{self.lhs} ~ {self.rhs}"
191200

192201
@property
193202
def formula(self) -> Formula:
203+
"""Formulaic ``Formula`` instance for fixed-effects materialization."""
194204
return Formula(self.fixed_formula)
195205

196206
def __str__(self) -> str:
@@ -202,8 +212,8 @@ def __str__(self) -> str:
202212
def get_model_matrix(
203213
self,
204214
data: Any,
205-
context: Optional[Mapping[str, Any]] = None,
206-
drop_rows: Optional[set[int]] = None,
215+
context: Mapping[str, Any] | None = None,
216+
drop_rows: set[int] | None = None,
207217
**attr_overrides: Any,
208218
) -> MixedModelMatrices:
209219
"""
@@ -253,7 +263,7 @@ def get_model_matrix(
253263
Z = pd.DataFrame(index=rhs.index)
254264
random_model_spec = None
255265
random_effect_names: list[str] = []
256-
group = {
266+
group: dict[str, np.ndarray | int | list[Any] | str | None] = {
257267
"variable": None,
258268
"labels": [],
259269
"n_groups": 0,

causalpy/tests/test_formula.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
1-
import pytest
1+
# Copyright 2026 - 2026 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
214
import numpy as np
15+
import pytest
316

417
from causalpy.formula import Parser, parse_formula
518

0 commit comments

Comments
 (0)