From 39bdba4168d1de79ee24e409375f253d407bf00a Mon Sep 17 00:00:00 2001 From: Jia Yu Date: Sun, 31 May 2026 23:53:48 -0700 Subject: [PATCH] feat(python/sedonadb): add DataFrame.group_by + GroupedDataFrame.agg MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Grouped aggregation on top of the registry-driven function dispatch (#885) and the global-aggregation binding (#887). API: df.group_by("k").agg(total=sd.funcs.sum(sd.col("v"))) df.group_by("k1", "k2").agg( sd.funcs.sum(col("x")).alias("sum_x"), n=sd.funcs.count(col("y")), ) df.group_by(col("x") + col("y")).agg(...) df.group_by(col("k"), "other_key").agg(...) - `df.group_by(*keys)` — varargs of `str | Expr`. Strings auto-promote to `col(name)`; arbitrary `Expr` values are accepted as computed group keys. Empty keys → ValueError; non-str/non-Expr → TypeError. - Returns a new `GroupedDataFrame` — a thin holder for the parent df plus the resolved group exprs. Single method `.agg(*exprs, **named_exprs)` with the same shape as `DataFrame.agg`. Pure Python — the Rust `InternalDataFrame::aggregate(group_exprs, agg_exprs)` from #887 already handles the grouped case; this PR just populates `group_exprs` when constructing the aggregation. The `GroupedDataFrame` intermediate is kept minimal (one method beyond `__init__`) so it stays a clean place to add convenience aggregates (`count`, `size`, etc.) later without polluting `DataFrame`. Tests: 12 covering single/multi string keys, Expr keys, computed Expr keys, mixed str/Expr, positional + kwarg agg, lazy return type, and the empty/bad-type error paths for both `group_by` and its `.agg`. 
---
 python/sedonadb/python/sedonadb/dataframe.py  |  95 ++++++++++++
 .../tests/expr/test_dataframe_group_by.py     | 136 ++++++++++++++++++
 2 files changed, 231 insertions(+)
 create mode 100644 python/sedonadb/tests/expr/test_dataframe_group_by.py

diff --git a/python/sedonadb/python/sedonadb/dataframe.py b/python/sedonadb/python/sedonadb/dataframe.py
index 1e5ee3c5b..232a6aa9f 100644
--- a/python/sedonadb/python/sedonadb/dataframe.py
+++ b/python/sedonadb/python/sedonadb/dataframe.py
@@ -596,6 +596,50 @@ def agg(self, *exprs: Expr, **named_exprs: Expr) -> "DataFrame":
             self._options,
         )
 
+    def group_by(self, *keys: Union[str, Expr]) -> "GroupedDataFrame":
+        """Group rows by one or more keys for aggregation.
+
+        Returns a `GroupedDataFrame` whose `.agg(...)` method runs the
+        aggregation. Strings are auto-promoted to column references
+        (same pattern as `sort`); arbitrary `Expr` values are accepted
+        as computed group keys.
+
+        Args:
+            *keys: One or more `str` column names or `Expr` group keys.
+                At least one is required.
+
+        Examples:
+
+            >>> sd = sedona.db.connect()
+            >>> df = sd.sql(
+            ...     "SELECT * FROM (VALUES ('a', 1), ('a', 2), ('b', 3)) AS t(k, v)"
+            ... )
+            >>> df.group_by("k").agg(total=sd.funcs.sum(sd.col("v"))).sort("k").show()
+            ┌──────┬───────┐
+            │   k  ┆ total │
+            │ utf8 ┆ int64 │
+            ╞══════╪═══════╡
+            │ a    ┆     3 │
+            ├╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
+            │ b    ┆     3 │
+            └──────┴───────┘
+        """
+        if not keys:
+            raise ValueError("group_by() requires at least one key")
+
+        coerced: List[Expr] = []
+        for k in keys:
+            if isinstance(k, Expr):
+                coerced.append(k)
+            elif isinstance(k, str):
+                coerced.append(_col(k))
+            else:
+                raise TypeError(
+                    f"group_by() expects str or Expr arguments, got {type(k).__name__}"
+                )
+
+        return GroupedDataFrame(self, coerced)
+
     def limit(self, n: Optional[int], /, *, offset: int = 0) -> "DataFrame":
         """Limit result to n rows starting at offset
 
@@ -1226,6 +1270,57 @@ def _scan_collected_default(ctx_impl, obj, schema, options):
     return _scan_default(ctx_impl, obj, schema, options).to_memtable()
 
 
+class GroupedDataFrame:
+    """A `DataFrame` partitioned by one or more group keys.
+
+    Produced by `DataFrame.group_by(...)`. Holds the parent frame and the
+    group key expressions until `.agg(...)` builds the aggregation.
+    """
+
+    __slots__ = ("_df", "_group_exprs")
+
+    def __init__(self, df: DataFrame, group_exprs: List[Expr]):
+        self._df = df
+        self._group_exprs = group_exprs
+
+    def agg(self, *exprs: Expr, **named_exprs: Expr) -> DataFrame:
+        """Aggregate within each group.
+
+        Same signature as `DataFrame.agg`: positional aggregate `Expr`s
+        and/or keyword aggregates where the keyword is the output
+        column name.
+
+        Args:
+            *exprs: Positional aggregate expressions.
+            **named_exprs: Keyword aggregate expressions; each keyword
+                becomes the output alias.
+        """
+        if not exprs and not named_exprs:
+            raise ValueError("agg() requires at least one aggregate expression")
+
+        for e in exprs:
+            if not isinstance(e, Expr):
+                raise TypeError(f"agg() expects Expr arguments, got {type(e).__name__}")
+
+        all_exprs: List[Expr] = list(exprs)
+        for name, e in named_exprs.items():
+            if not isinstance(e, Expr):
+                raise TypeError(
+                    f"agg() expects Expr keyword values, got {type(e).__name__} "
+                    f"for keyword {name!r}"
+                )
+            all_exprs.append(e.alias(name))
+
+        return DataFrame(
+            self._df._ctx,
+            self._df._impl.aggregate(
+                [g._impl for g in self._group_exprs],
+                [e._impl for e in all_exprs],
+            ),
+            self._df._options,
+        )
+
+
 def _scan_geopandas(ctx_impl, obj, schema, options):
     return _scan_collected_default(
         ctx_impl, obj.to_arrow(geometry_encoding="WKB"), schema, options
diff --git a/python/sedonadb/tests/expr/test_dataframe_group_by.py b/python/sedonadb/tests/expr/test_dataframe_group_by.py
new file mode 100644
index 000000000..2601b81a5
--- /dev/null
+++ b/python/sedonadb/tests/expr/test_dataframe_group_by.py
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pandas as pd
+import pandas.testing as pdt
+import pytest
+
+from sedonadb.dataframe import DataFrame, GroupedDataFrame
+from sedonadb.expr import col
+
+
+def test_group_by_single_key_string(con):
+    df = con.create_data_frame(
+        pd.DataFrame({"k": ["a", "a", "b", "b"], "v": [1, 2, 3, 4]})
+    )
+    out = df.group_by("k").agg(total=con.funcs.sum(col("v"))).sort("k").to_pandas()
+    pdt.assert_frame_equal(out, pd.DataFrame({"k": ["a", "b"], "total": [3, 7]}))
+
+
+def test_group_by_returns_grouped_dataframe(con):
+    df = con.create_data_frame(pd.DataFrame({"k": ["a"], "v": [1]}))
+    g = df.group_by("k")
+    assert isinstance(g, GroupedDataFrame)
+
+
+def test_group_by_multiple_keys(con):
+    df = con.create_data_frame(
+        pd.DataFrame(
+            {
+                "k1": ["a", "a", "a", "b"],
+                "k2": ["x", "x", "y", "y"],
+                "v": [1, 2, 3, 4],
+            }
+        )
+    )
+    out = (
+        df.group_by("k1", "k2")
+        .agg(total=con.funcs.sum(col("v")))
+        .sort("k1", "k2")
+        .to_pandas()
+    )
+    pdt.assert_frame_equal(
+        out,
+        pd.DataFrame(
+            {"k1": ["a", "a", "b"], "k2": ["x", "y", "y"], "total": [3, 3, 4]}
+        ),
+    )
+
+
+def test_group_by_expr_key(con):
+    # A col("k") Expr key should aggregate just like the string key "k" does.
+    df = con.create_data_frame(pd.DataFrame({"k": ["a", "a", "b"], "v": [1, 2, 3]}))
+    out = df.group_by(col("k")).agg(total=con.funcs.sum(col("v"))).sort("k").to_pandas()
+    pdt.assert_frame_equal(out, pd.DataFrame({"k": ["a", "b"], "total": [3, 3]}))
+
+
+def test_group_by_computed_expr_key(con):
+    # Group by an arithmetic expression — rows whose x+y matches are in
+    # the same group. (1,9), (4,6), (5,5) all sum to 10.
+    df = con.create_data_frame(pd.DataFrame({"x": [1, 4, 5, 2], "y": [9, 6, 5, 3]}))
+    out = (
+        df.group_by((col("x") + col("y")).alias("xy"))
+        .agg(n=con.funcs.count(col("x")))
+        .sort("xy")
+        .to_pandas()
+    )
+    pdt.assert_frame_equal(out, pd.DataFrame({"xy": [5, 10], "n": [1, 3]}))
+
+
+def test_group_by_mixed_string_and_expr(con):
+    df = con.create_data_frame(pd.DataFrame({"k": ["a", "a", "b"], "v": [1, 2, 3]}))
+    out = df.group_by("k", col("v") > 1).agg(n=con.funcs.count(col("v"))).to_pandas()
+    # Three distinct (k, v>1) tuples: (a, false), (a, true), (b, true).
+    assert len(out) == 3
+    assert sorted(out["n"].tolist()) == [1, 1, 1]
+
+
+def test_group_by_agg_positional_and_kwarg(con):
+    df = con.create_data_frame(pd.DataFrame({"k": ["a", "a", "b"], "v": [1, 2, 3]}))
+    out = (
+        df.group_by("k")
+        .agg(
+            con.funcs.sum(col("v")).alias("sum_v"),
+            n=con.funcs.count(col("v")),
+        )
+        .sort("k")
+        .to_pandas()
+    )
+    pdt.assert_frame_equal(
+        out,
+        pd.DataFrame({"k": ["a", "b"], "sum_v": [3, 3], "n": [2, 1]}),
+    )
+
+
+def test_group_by_agg_returns_lazy_dataframe(con):
+    df = con.create_data_frame(pd.DataFrame({"k": ["a"], "v": [1]}))
+    out = df.group_by("k").agg(total=con.funcs.sum(col("v")))
+    assert isinstance(out, DataFrame)
+
+
+def test_group_by_empty_raises(con):
+    df = con.create_data_frame(pd.DataFrame({"k": ["a"], "v": [1]}))
+    with pytest.raises(ValueError, match="at least one key"):
+        df.group_by()
+
+
+def test_group_by_bad_key_type_raises(con):
+    df = con.create_data_frame(pd.DataFrame({"k": ["a"], "v": [1]}))
+    with pytest.raises(TypeError, match="str or Expr"):
+        df.group_by(123)
+
+
+def test_grouped_agg_empty_raises(con):
+    df = con.create_data_frame(pd.DataFrame({"k": ["a"], "v": [1]}))
+    with pytest.raises(ValueError, match="at least one aggregate expression"):
+        df.group_by("k").agg()
+
+
+def test_grouped_agg_non_expr_raises(con):
+    df = con.create_data_frame(pd.DataFrame({"k": ["a"], "v": [1]}))
+    with pytest.raises(TypeError, match="agg\\(\\) expects Expr arguments"):
+        df.group_by("k").agg("not an expr")