Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
618 changes: 547 additions & 71 deletions sidemantic/adapters/metricflow.py

Large diffs are not rendered by default.

96 changes: 94 additions & 2 deletions sidemantic/sql/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,15 +1208,32 @@ def extract_from_metric(metric):
"""Recursively extract filter columns from a metric and its dependencies."""
# Extract from the metric's own filters
if metric.filters:
filter_model_name = None
deps = metric.get_dependencies(self.graph)
for dep in deps:
try:
dep_model_name, _ = self.graph.resolve_metric_reference(dep)
except KeyError:
dep_model_name = dep.split(".", 1)[0] if "." in dep else None
if dep_model_name:
add_filter_columns(dep_model_name, metric.filters)
filter_model_name = dep_model_name
break
# A graph-level simple aggregate (``agg`` + ``sql``) has no metric
# dependencies, so resolve its owning model from the SQL's column
# refs. Without this the filter columns are never projected into
# the CTE and the CASE WHEN filter references a missing column.
if filter_model_name is None and metric.agg and metric.sql:
try:
parsed = sqlglot.parse_one(metric.sql, dialect=self.dialect)
for col in parsed.find_all(exp.Column):
candidate = col.table.replace("_cte", "") if col.table else None
if candidate and candidate in self.graph.models:
filter_model_name = candidate
break
except Exception:
filter_model_name = None
if filter_model_name:
add_filter_columns(filter_model_name, metric.filters)

# For ratio metrics, check numerator and denominator
if metric.type == "ratio":
Expand Down Expand Up @@ -2655,6 +2672,55 @@ def _rewrite_model_refs_to_ctes(self, sql_expr: str) -> str:
)
return rewritten

def _infer_metric_filter_model(self, metric, model_context: str | None) -> str | None:
"""Infer the owning model of a graph-level simple aggregate metric.

Used to qualify the metric's unqualified filter columns. Prefers an
explicit ``model_context``; otherwise resolves the model from the first
qualified column in the metric SQL (e.g. ``orders.amount`` -> ``orders``).
"""
if model_context and model_context in self.graph.models:
return model_context
if metric.sql:
try:
parsed = sqlglot.parse_one(metric.sql, dialect=self.dialect)
for col in parsed.find_all(exp.Column):
candidate = col.table.replace("_cte", "") if col.table else None
if candidate and candidate in self.graph.models:
return candidate
except Exception:
return None
return None

def _qualify_metric_filter_sql(self, filter_expr: str, model_name: str | None) -> str:
"""Qualify a metric filter's columns with the owning model's CTE.

Already-qualified ``model.col`` refs are rewritten to ``model_cte.col``;
unqualified columns are anchored to ``model_name``'s CTE so the predicate
is unambiguous when other joined CTEs expose same-named columns. Falls
back to plain CTE rewriting when the owning model is unknown.
"""
if model_name is None:
return self._rewrite_model_refs_to_ctes(filter_expr)
cte_name = self._cte_name(model_name)
cte_identifier = exp.to_identifier(cte_name, quoted=not self._is_simple_identifier(cte_name))
try:
parsed = sqlglot.parse_one(filter_expr, dialect=self.dialect)
except Exception:
return self._rewrite_model_refs_to_ctes(filter_expr)
for col in parsed.find_all(exp.Column):
if col.table:
ref_model = col.table.replace("_cte", "")
if ref_model in self.graph.models:
rewritten = self._cte_name(ref_model)
col.set(
"table",
exp.to_identifier(rewritten, quoted=not self._is_simple_identifier(rewritten)),
)
else:
col.set("table", cte_identifier.copy())
return parsed.sql(dialect=self.dialect)

def _build_measure_aggregation_sql(self, model_name: str, measure) -> str:
"""Build SQL aggregation expression for a measure.

Expand Down Expand Up @@ -2858,13 +2924,39 @@ def resolve_ratio_ref(ref: str) -> str:
if inner_expr != "*":
inner_expr = self._rewrite_model_refs_to_ctes(inner_expr)

# Metric-level filters scope the aggregation to matching rows. Apply
# them via CASE WHEN so the filter only affects this metric (mirrors
# the model-scoped measure path in _build_model_cte). Without this a
# filtered simple metric like a MetricFlow inline measure with
# ``filter: status = 'completed'`` would silently aggregate every row.
filter_sql = None
if metric.filters:
# Qualify unqualified filter columns with the metric's owning model
# CTE so a filter like ``status = 'completed'`` is not ambiguous
# when a joined model's CTE also exposes a ``status`` column.
filter_model = self._infer_metric_filter_model(metric, model_context)
conditions = []
for filter_str in metric.filters:
condition = filter_str.replace("{model}.", "").replace("{model}", "")
conditions.append(self._qualify_metric_filter_sql(condition, filter_model))
if conditions:
filter_sql = " AND ".join(conditions)

if metric.agg == "count":
if filter_sql is not None:
# COUNT ignores NULLs, so emit 1 for matching rows and NULL otherwise.
return f"COUNT(CASE WHEN {filter_sql} THEN 1 ELSE NULL END)"
Comment thread
nicosuave marked this conversation as resolved.
Outdated
if inner_expr == "*":
return "COUNT(*)"
return f"COUNT({inner_expr})"
if metric.agg == "count_distinct":
if filter_sql is not None:
return f"COUNT(DISTINCT CASE WHEN {filter_sql} THEN {inner_expr} ELSE NULL END)"
return f"COUNT(DISTINCT {inner_expr})"
return f"{self._agg_sql_name(metric.agg)}({inner_expr})"
agg_arg = (
f"CASE WHEN {filter_sql} THEN {inner_expr} ELSE NULL END" if filter_sql is not None else inner_expr
)
return f"{self._agg_sql_name(metric.agg)}({agg_arg})"

elif metric.type == "derived" or (not metric.type and not metric.agg and metric.sql):
# Parse formula and replace metric references (handles both typed "derived" and untyped metrics with sql)
Expand Down
112 changes: 89 additions & 23 deletions tests/adapters/metricflow/test_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,17 @@ def test_model_table_from_ref(self, graph):
assert model.table == "fct_bookings"

def test_measure_count(self, graph):
"""All 14 measures are imported."""
"""Representable measures are imported; percentile measures are skipped.

The fixture declares 14 measures, but the 4 ``agg: percentile`` measures
have no Sidemantic equivalent and are skipped rather than coerced to sum,
leaving 10.
"""
model = graph.models["bookings_source"]
assert len(model.metrics) == 14
assert len(model.metrics) == 10

def test_measure_names(self, graph):
"""All measure names are present."""
"""Representable measure names are present; percentile measures are absent."""
model = graph.models["bookings_source"]
names = {m.name for m in model.metrics}
expected = {
Expand All @@ -66,12 +71,13 @@ def test_measure_names(self, graph):
"booking_payments",
"referred_bookings",
"median_booking_value",
"booking_value_p99",
"discrete_booking_value_p99",
"approximate_continuous_booking_value_p99",
"approximate_discrete_booking_value_p99",
}
assert names == expected
# ``agg: percentile`` measures cannot be represented and are skipped.
assert "booking_value_p99" not in names
assert "discrete_booking_value_p99" not in names
assert "approximate_continuous_booking_value_p99" not in names
assert "approximate_discrete_booking_value_p99" not in names

def test_sum_boolean_maps_to_sum(self, graph):
"""sum_boolean aggregation maps to sum."""
Expand All @@ -85,11 +91,15 @@ def test_median_aggregation(self, graph):
median = model.get_metric("median_booking_value")
assert median.agg == "median"

def test_percentile_aggregation_fallback(self, graph):
"""percentile aggregation falls through to default (sum) since not in type_mapping."""
def test_percentile_aggregation_skipped(self, graph):
"""percentile aggregation is skipped, not silently coerced to sum.

Sidemantic has no percentile aggregation. Coercing it to ``sum`` would
register ``SUM(booking_value)`` under the percentile measure's name and
return a wrong value, so the measure is dropped instead.
"""
model = graph.models["bookings_source"]
p99 = model.get_metric("booking_value_p99")
assert p99.agg == "sum"
assert model.get_metric("booking_value_p99") is None

def test_count_distinct(self, graph):
"""count_distinct aggregation is mapped correctly."""
Expand Down Expand Up @@ -659,9 +669,13 @@ def test_model_exists(self, graph):
assert "bookings_source" in graph.models

def test_measure_count(self, graph):
"""All 14 measures are imported from the semantic model."""
"""Representable measures are imported; percentile measures are skipped.

Of the 14 declared measures, the 4 ``agg: percentile`` flavors have no
Sidemantic equivalent and are skipped rather than coerced to sum.
"""
model = graph.models["bookings_source"]
assert len(model.metrics) == 14
assert len(model.metrics) == 10

def test_simple_metrics_parsed(self, graph):
"""Simple metrics are parsed as untyped with measure references."""
Expand Down Expand Up @@ -773,11 +787,14 @@ def test_derived_nested(self, graph):
assert "booking_value_sub_instant" in nested.sql

def test_derived_with_alias(self, graph):
"""Derived metric with alias on input parses."""
"""Derived metric with non-offset alias rewrites the alias to its input metric."""
pct = graph.get_metric("non_referred_bookings_pct")
assert pct is not None
assert pct.type == "derived"
assert "ref_bookings" in pct.sql
# The non-offset alias ``ref_bookings`` is rewritten back to its real
# input metric ``referred_bookings`` so the metric is queryable.
assert pct.sql == "(bookings - referred_bookings) * 1.0 / bookings"
assert pct.get_dependencies(graph) == {"bookings", "referred_bookings"}

def test_derived_with_filtered_input(self, graph):
"""Derived metric with filter on input metric parses."""
Expand Down Expand Up @@ -806,14 +823,17 @@ def test_derived_offset_once_twice(self, graph):
assert twice.type == "derived"

def test_derived_shared_aliases(self, graph):
"""Derived metrics with shared alias names parse independently."""
"""Derived metrics with shared alias names rewrite to their own input metrics."""
# Same alias name (``shared_alias``) maps to a different underlying metric
# in each derived metric; each is rewritten independently.
a = graph.get_metric("derived_shared_alias_1a")
assert a is not None
assert a.type == "derived"
assert "shared_alias" in a.sql
assert a.sql == "bookings - 10"

b = graph.get_metric("derived_shared_alias_2")
assert b is not None
assert b.sql == "instant_bookings + 10"

def test_derived_fill_nulls(self, graph):
"""Derived metrics with fill_nulls inputs parse."""
Expand Down Expand Up @@ -850,16 +870,39 @@ def test_ratio_with_metric_filter(self, graph):
assert ratio.type == "ratio"
assert ratio.filters is not None

def test_conversion_metrics_skipped(self, graph):
"""Conversion metrics are skipped (unsupported type)."""
def test_conversion_metrics_parsed(self, graph):
"""Conversion metrics are retained as non-queryable metadata.

MetricFlow conversion metrics reference base/conversion *measures*, which
cannot be faithfully mapped to Sidemantic's event-filter conversion
funnel, so they are captured in metadata rather than registered as
broken queryable metrics.
"""
assert "visit_buy_conversion_rate_7days" not in graph.metrics
assert "visit_buy_conversion_rate" not in graph.metrics
assert "visit_buy_conversions" not in graph.metrics
assert "visit_buy_conversion_rate_by_session" not in graph.metrics

conv_specs = graph.metadata["metricflow_conversion_metrics"]

rate_7d = conv_specs["visit_buy_conversion_rate_7days"]
assert rate_7d["entity"] == "user"
assert rate_7d["base_measure"] == "visits"
assert rate_7d["conversion_measure"] == "buys"
assert rate_7d["window"] == "7 days"
assert rate_7d["calculation"] == "conversion_rate"

# conversions count flavor with a dict conversion_measure (fill_nulls_with)
conversions = conv_specs["visit_buy_conversions"]
assert conversions["conversion_measure"] == "buys"
assert conversions["calculation"] == "conversions"

# constant_properties retained in metadata
by_session = conv_specs["visit_buy_conversion_rate_by_session"]
assert by_session["constant_properties"] == [{"base_property": "session", "conversion_property": "session_id"}]

def test_total_metric_count(self, graph):
"""Verify total number of parsed metrics (simple + cumulative + derived + ratio, excluding conversion)."""
# Conversion metrics are skipped, so we count only supported types
"""Verify total number of parsed metrics (simple + cumulative + derived + ratio + conversion)."""
assert len(graph.metrics) >= 50


Expand Down Expand Up @@ -942,18 +985,41 @@ def graph(self):
return adapter.parse(FIXTURES / "simple_manifest_saved_queries.yaml")

def test_parse_succeeds(self, graph):
"""Fixture parses without errors (saved_queries key is ignored gracefully)."""
"""Fixture parses without errors."""
assert graph is not None

def test_model_exists(self, graph):
"""The placeholder semantic model is imported."""
assert "sales_for_saved_queries" in graph.models

def test_saved_queries_not_in_metrics(self, graph):
"""saved_queries are not parsed into graph.metrics (not yet supported)."""
"""saved_queries are kept separate from graph.metrics."""
assert "p0_booking" not in graph.metrics
assert "p0_booking_with_order_by_and_limit" not in graph.metrics

def test_saved_queries_captured_in_metadata(self, graph):
"""saved_queries are parsed into graph.metadata['saved_queries']."""
saved = graph.metadata.get("saved_queries")
assert saved is not None
assert "p0_booking" in saved
assert "p0_booking_with_order_by_and_limit" in saved
assert "dimensions_only" in saved

p0 = saved["p0_booking"]
assert p0["metrics"] == ["bookings", "instant_bookings"]
assert p0["group_by"] == [
"TimeDimension('metric_time', 'day')",
"Dimension('listing__capacity_latest')",
]
assert p0["where"] == ["{{ Dimension('listing__capacity_latest') }} > 3"]

def test_saved_query_order_by_and_limit(self, graph):
"""order_by and limit are retained on the saved query."""
saved = graph.metadata["saved_queries"]
ordered = saved["p0_booking_with_order_by_and_limit"]
assert ordered["limit"] == 10
assert ordered["order_by"] is not None


# =============================================================================
# SCD listings: validity_params from upstream test suite
Expand Down
Loading
Loading