diff --git a/python/sedonadb/python/sedonadb/dataframe.py b/python/sedonadb/python/sedonadb/dataframe.py index 8777a0ee3..1e5ee3c5b 100644 --- a/python/sedonadb/python/sedonadb/dataframe.py +++ b/python/sedonadb/python/sedonadb/dataframe.py @@ -539,6 +539,63 @@ def drop(self, *cols: str) -> "DataFrame": return DataFrame(self._ctx, self._impl.drop_columns(list(cols)), self._options) + def agg(self, *exprs: Expr, **named_exprs: Expr) -> "DataFrame": + """Aggregate the entire DataFrame to a single row. + + Aggregate expressions can be passed positionally or as keyword + arguments. With keyword arguments the keyword becomes the + output column name — `df.agg(total=sd.funcs.sum(sd.col("x")))` + is shorthand for + `df.agg(sd.funcs.sum(sd.col("x")).alias("total"))`. The two + forms can be mixed in a single call. + + Args: + *exprs: Positional aggregate expressions. + **named_exprs: Keyword aggregate expressions; each keyword + is applied as the output alias of the corresponding + expression. + + Examples: + + >>> sd = sedona.db.connect() + >>> df = sd.sql("SELECT * FROM (VALUES (1), (2), (3), (4)) AS t(x)") + >>> df.agg(sd.funcs.sum(sd.col("x")).alias("total")).show() + ┌───────┐ + │ total │ + │ int64 │ + ╞═══════╡ + │ 10 │ + └───────┘ + >>> df.agg(total=sd.funcs.sum(sd.col("x"))).show() + ┌───────┐ + │ total │ + │ int64 │ + ╞═══════╡ + │ 10 │ + └───────┘ + """ + if not exprs and not named_exprs: + raise ValueError("agg() requires at least one aggregate expression") + + for e in exprs: + if not isinstance(e, Expr): + raise TypeError(f"agg() expects Expr arguments, got {type(e).__name__}") + + all_exprs: List[Expr] = list(exprs) + for name, e in named_exprs.items(): + if not isinstance(e, Expr): + raise TypeError( + f"agg() expects Expr keyword values, got {type(e).__name__} " + f"for keyword {name!r}" + ) + all_exprs.append(e.alias(name)) + + return DataFrame( + self._ctx, + self._impl.aggregate([], [e._impl for e in all_exprs]), + self._options, + ) + def limit(self, n: Optional[int], /, *, offset: int = 0) -> "DataFrame": """Limit result to n rows starting at offset diff --git a/python/sedonadb/src/dataframe.rs b/python/sedonadb/src/dataframe.rs index 96cb6f8b3..c33141b66 100644 --- a/python/sedonadb/src/dataframe.rs +++ b/python/sedonadb/src/dataframe.rs @@ -235,6 +235,28 @@ impl InternalDataFrame { Ok(InternalDataFrame::new(inner, self.runtime.clone())) } + /// Aggregate the rows of the DataFrame, optionally partitioned by + /// `group_exprs`. Both inputs are `Vec` so the same Rust + /// method serves global aggregation (`group_exprs` empty, called + /// from `DataFrame.agg`) and grouped aggregation. + /// + /// The Python side guarantees `agg_exprs` is non-empty and that + /// every entry is an `Expr` (vs. a string or other type). It does + /// not verify that each entry is an aggregate-shaped expression — + /// e.g. `col("x")` would pass the Python `isinstance` check but is + /// not a valid aggregate. DataFusion's plan-build catches that case + /// with a clear error, so we don't reimplement the check here. + fn aggregate( + &self, + group_exprs: Vec, + agg_exprs: Vec, + ) -> Result { + let group_exprs: Vec = group_exprs.into_iter().map(|e| e.inner).collect(); + let agg_exprs: Vec = agg_exprs.into_iter().map(|e| e.inner).collect(); + let inner = self.inner.clone().aggregate(group_exprs, agg_exprs)?; + Ok(InternalDataFrame::new(inner, self.runtime.clone())) + } + fn execute<'py>(&self, py: Python<'py>) -> Result { let df = self.inner.clone(); let count = wait_for_future(py, &self.runtime, async move { diff --git a/python/sedonadb/tests/expr/test_dataframe_agg.py b/python/sedonadb/tests/expr/test_dataframe_agg.py new file mode 100644 index 000000000..17d676844 --- /dev/null +++ b/python/sedonadb/tests/expr/test_dataframe_agg.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for DataFrame.agg(*exprs) — global (ungrouped) aggregation. +# Aggregate expressions are built via `con.funcs.(col(...))` +# which walks the engine's aggregate-UDF registry (added in #885). + +import pandas as pd +import pandas.testing as pdt +import pytest + +from sedonadb.dataframe import DataFrame +from sedonadb.expr import col + + +def test_agg_single_sum(con): + df = con.create_data_frame(pd.DataFrame({"x": [1, 2, 3, 4]})) + out = df.agg(con.funcs.sum(col("x")).alias("total")).to_pandas() + pdt.assert_frame_equal(out, pd.DataFrame({"total": [10]})) + + +def test_agg_single_count(con): + df = con.create_data_frame(pd.DataFrame({"x": [1, 2, 3]})) + out = df.agg(con.funcs.count(col("x")).alias("n")).to_pandas() + pdt.assert_frame_equal(out, pd.DataFrame({"n": [3]})) + + +def test_agg_min_max(con): + df = con.create_data_frame(pd.DataFrame({"x": [3, 1, 4, 1, 5, 9, 2, 6]})) + out = df.agg( + con.funcs.min(col("x")).alias("lo"), + con.funcs.max(col("x")).alias("hi"), + ).to_pandas() + pdt.assert_frame_equal(out, pd.DataFrame({"lo": [1], "hi": [9]})) + + +def test_agg_avg_over_compound_expr(con): + # con.funcs.avg over an arithmetic Expr exercises the path where + # aggregate exprs are built on top of operator-composed columns. + df = con.create_data_frame(pd.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]})) + out = df.agg(con.funcs.avg(col("x") + col("y")).alias("avg_xy")).to_pandas() + # (11 + 22 + 33) / 3 = 22.0 + pdt.assert_frame_equal(out, pd.DataFrame({"avg_xy": [22.0]})) + + +def test_agg_multiple_aggregates_one_row(con): + df = con.create_data_frame(pd.DataFrame({"x": [1, 2, 3, 4]})) + out = df.agg( + con.funcs.sum(col("x")).alias("sum_x"), + con.funcs.count(col("x")).alias("n"), + con.funcs.min(col("x")).alias("lo"), + con.funcs.max(col("x")).alias("hi"), + ).to_pandas() + pdt.assert_frame_equal( + out, pd.DataFrame({"sum_x": [10], "n": [4], "lo": [1], "hi": [4]}) + ) + + +def test_agg_returns_lazy_dataframe(con): + df = con.create_data_frame(pd.DataFrame({"x": [1, 2, 3]})) + out = df.agg(con.funcs.sum(col("x"))) + assert isinstance(out, DataFrame) + + +def test_agg_empty_raises(con): + df = con.create_data_frame(pd.DataFrame({"x": [1, 2, 3]})) + with pytest.raises(ValueError, match="at least one aggregate expression"): + df.agg() + + +def test_agg_non_expr_arg_raises(con): + df = con.create_data_frame(pd.DataFrame({"x": [1, 2, 3]})) + with pytest.raises(TypeError, match="agg\\(\\) expects Expr arguments"): + df.agg("x") + + +def test_agg_kwarg_aliases_output_column(con): + # `df.agg(total=sd.funcs.sum(col("x")))` is shorthand for + # `df.agg(sd.funcs.sum(col("x")).alias("total"))`. + df = con.create_data_frame(pd.DataFrame({"x": [1, 2, 3, 4]})) + out = df.agg(total=con.funcs.sum(col("x"))).to_pandas() + pdt.assert_frame_equal(out, pd.DataFrame({"total": [10]})) + + +def test_agg_mixed_positional_and_kwarg(con): + df = con.create_data_frame(pd.DataFrame({"x": [1, 2, 3, 4]})) + out = df.agg( + con.funcs.sum(col("x")).alias("sum_x"), + n=con.funcs.count(col("x")), + ).to_pandas() + pdt.assert_frame_equal(out, pd.DataFrame({"sum_x": [10], "n": [4]})) + + +def test_agg_kwarg_non_expr_value_raises(con): + df = con.create_data_frame(pd.DataFrame({"x": [1, 2, 3]})) + with pytest.raises(TypeError, match="agg\\(\\) expects Expr keyword values"): + df.agg(total="not an expr") + + +def test_agg_chains_with_filter(con): + df = con.create_data_frame(pd.DataFrame({"x": [1, 2, 3, 4]})) + out = ( + df.filter(col("x") > 1).agg(con.funcs.sum(col("x")).alias("total")).to_pandas() + ) + pdt.assert_frame_equal(out, pd.DataFrame({"total": [9]}))