Skip to content

Commit 3ec83a7

Browse files
committed
Merge remote-tracking branch 'origin/main' into docs/python-env-guidance-fallback
2 parents 87adc8c + eec45f7 commit 3ec83a7

6 files changed

Lines changed: 219 additions & 66 deletions

File tree

.github/skills/pr-to-green/SKILL.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,22 @@ Escalation report must include:
106106
- resolution options (1-2) with trade-offs
107107
- recommended next step and what decision is needed from maintainer
108108

109+
### Scope drift / upstream compatibility escalation (required)
110+
111+
Do not keep patching silently when greening a PR turns into broader compatibility work outside the PR's feature surface. Stop after root-cause identification and ask the maintainer whether to continue in the PR or split the work into a separate branch/PR from `main`.
112+
113+
Escalate by default when:
114+
- the failing checks are caused by third-party API or version drift in shared/core code
115+
- the fix would benefit multiple open PRs or the default branch, not just the current PR
116+
- the next fix would modify shared integrations beyond the feature the PR is introducing
117+
- successive CI reruns keep exposing new failure families outside the original PR scope
118+
119+
Scope-drift report must include:
120+
- the original PR goal and the newly discovered broader issue
121+
- which files are feature-specific versus shared/core compatibility files
122+
- options: patch in the PR, split a separate compatibility PR, or pause for maintainer direction
123+
- the recommended next step and why
124+
109125
## CausalPy guardrails
110126
- Never use destructive git commands unless explicitly requested
111127
- Do not create ad hoc test scripts; use `pytest` tests in `causalpy/tests/`

causalpy/pymc_models.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
"""Custom PyMC models for causal inference"""
1515

16+
import inspect
1617
import warnings
1718
from typing import Any, Literal
1819

@@ -30,6 +31,44 @@
3031
from causalpy.variable_selection_priors import VariableSelectionPrior
3132

3233

34+
def _as_xtensor_obs_ind(x: Any) -> Any:
35+
"""Convert a tensor-like input to an obs_ind xtensor."""
36+
import pytensor.xtensor as ptx
37+
38+
return ptx.as_xtensor(x, dims=("obs_ind",))
39+
40+
41+
def _uses_xtensor_api(function: Any) -> bool:
42+
"""Return True when an upstream transform expects xtensor inputs."""
43+
try:
44+
return "as_xtensor" in inspect.getsource(function)
45+
except (OSError, TypeError):
46+
code = getattr(function, "__code__", None)
47+
return code is not None and "as_xtensor" in code.co_names
48+
49+
50+
def _call_time_component_apply(
51+
component: Any,
52+
t: Any,
53+
) -> Any:
54+
"""Call time components across tensor and xtensor variants."""
55+
parameters = inspect.signature(component.apply).parameters
56+
if _uses_xtensor_api(component.apply) or (
57+
"sum" in parameters and "result_callback" not in parameters
58+
):
59+
t = _as_xtensor_obs_ind(t)
60+
result = component.apply(t)
61+
return getattr(result, "values", result)
62+
63+
64+
def _call_seasonality_component_apply(
65+
seasonality_component: Any,
66+
dayofperiod: Any,
67+
) -> Any:
68+
"""Call seasonality components across tensor and xtensor variants."""
69+
return _call_time_component_apply(seasonality_component, dayofperiod)
70+
71+
3372
class PyMCModel(pm.Model):
3473
"""A wrapper class for PyMC models. This provides a scikit-learn like interface with
3574
methods like `fit`, `predict`, and `score`. It also provides other methods which are
@@ -1441,12 +1480,16 @@ def build_model(
14411480
# Seasonal component
14421481
season_component = pm.Deterministic(
14431482
"season_component",
1444-
seasonality_component_instance.apply(t_season_data),
1483+
_call_seasonality_component_apply(
1484+
seasonality_component_instance, t_season_data
1485+
),
14451486
dims="obs_ind",
14461487
)
14471488

14481489
# Trend component
1449-
trend_component_values = trend_component_instance.apply(t_trend_data)
1490+
trend_component_values = _call_time_component_apply(
1491+
trend_component_instance, t_trend_data
1492+
)
14501493
trend_component = pm.Deterministic(
14511494
"trend_component",
14521495
trend_component_values,
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright 2022 - 2026 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Compatibility tests for optional pymc-marketing components in pymc_models."""
15+
16+
import pytensor.tensor as pt
17+
18+
from causalpy.pymc_models import (
19+
_call_seasonality_component_apply,
20+
_call_time_component_apply,
21+
_uses_xtensor_api,
22+
)
23+
24+
25+
def test_call_seasonality_component_apply_supports_tensor_input():
26+
calls: dict[str, object] = {}
27+
28+
class FakeSeasonalityComponent:
29+
def apply(self, dayofperiod, result_callback=None): # noqa: ANN001, ARG002
30+
calls["dayofperiod"] = dayofperiod
31+
return "seasonality-result"
32+
33+
signal = pt.vector("signal")
34+
35+
result = _call_seasonality_component_apply(FakeSeasonalityComponent(), signal)
36+
37+
assert result == "seasonality-result"
38+
assert calls == {"dayofperiod": signal}
39+
40+
41+
def test_call_seasonality_component_apply_supports_xtensor_signature():
42+
calls: dict[str, object] = {}
43+
44+
class FakeXTensorResult:
45+
def __init__(self, values: str) -> None:
46+
self.values = values
47+
48+
class FakeSeasonalityComponent:
49+
def apply(self, dayofperiod, sum=True): # noqa: ANN001, FBT002
50+
calls["has_dims"] = hasattr(dayofperiod.type, "dims")
51+
calls["sum"] = sum
52+
return FakeXTensorResult("xtensor-result")
53+
54+
result = _call_seasonality_component_apply(
55+
FakeSeasonalityComponent(),
56+
pt.vector("signal"),
57+
)
58+
59+
assert result == "xtensor-result"
60+
assert calls == {
61+
"has_dims": True,
62+
"sum": True,
63+
}
64+
65+
66+
def test_call_time_component_apply_supports_tensor_input():
67+
calls: dict[str, object] = {}
68+
69+
class FakeTimeComponent:
70+
def apply(self, t): # noqa: ANN001
71+
calls["t"] = t
72+
return "time-result"
73+
74+
signal = pt.vector("signal")
75+
76+
result = _call_time_component_apply(FakeTimeComponent(), signal)
77+
78+
assert result == "time-result"
79+
assert calls == {"t": signal}
80+
81+
82+
def test_call_time_component_apply_supports_xtensor_source():
83+
calls: dict[str, object] = {}
84+
85+
class FakeXTensorResult:
86+
def __init__(self, values: str) -> None:
87+
self.values = values
88+
89+
class FakeTimeComponent:
90+
def apply(self, t): # noqa: ANN001
91+
# as_xtensor compatibility path
92+
calls["has_dims"] = hasattr(t.type, "dims")
93+
return FakeXTensorResult("xtensor-time-result")
94+
95+
result = _call_time_component_apply(FakeTimeComponent(), pt.vector("signal"))
96+
97+
assert result == "xtensor-time-result"
98+
assert calls == {"has_dims": True}
99+
100+
101+
def test_uses_xtensor_api_falls_back_to_code_names(monkeypatch):
102+
namespace: dict[str, object] = {}
103+
exec( # noqa: S102
104+
"def fake_transform(x):\n return as_xtensor(x)\n",
105+
{},
106+
namespace,
107+
)
108+
fake_transform = namespace["fake_transform"]
109+
110+
def raise_oserror(_function): # noqa: ANN001
111+
raise OSError("source unavailable")
112+
113+
monkeypatch.setattr("causalpy.pymc_models.inspect.getsource", raise_oserror)
114+
115+
assert _uses_xtensor_api(fake_transform)

0 commit comments

Comments
 (0)