Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions python/sedonadb/python/sedonadb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,50 @@ def agg(self, *exprs: Expr, **named_exprs: Expr) -> "DataFrame":
self._options,
)

def group_by(self, *keys: Union[str, Expr]) -> "GroupedDataFrame":
    """Start a grouped aggregation over one or more keys.

    No aggregation happens here: the returned `GroupedDataFrame` runs
    it when its `.agg(...)` method is called. A `str` key is promoted
    to a column reference (same pattern as `sort`), while an `Expr`
    key is passed through unchanged so that computed values can be
    used as group keys.

    Args:
        *keys: One or more `str` column names or `Expr` group keys.
            At least one is required.

    Raises:
        ValueError: If no keys are given.
        TypeError: If a key is neither a `str` nor an `Expr`.

    Examples:

        >>> sd = sedona.db.connect()
        >>> df = sd.sql(
        ...     "SELECT * FROM (VALUES ('a', 1), ('a', 2), ('b', 3)) AS t(k, v)"
        ... )
        >>> df.group_by("k").agg(total=sd.funcs.sum(sd.col("v"))).sort("k").show()
        ┌──────┬───────┐
        │   k  ┆ total │
        │ utf8 ┆ int64 │
        ╞══════╪═══════╡
        │ a    ┆     3 │
        ├╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
        │ b    ┆     3 │
        └──────┴───────┘
    """
    if len(keys) == 0:
        raise ValueError("group_by() requires at least one key")

    group_exprs: List[Expr] = []
    for key in keys:
        # Reject unsupported key types up front; everything that gets
        # past this guard is either an Expr or a column name.
        if not isinstance(key, (Expr, str)):
            raise TypeError(
                f"group_by() expects str or Expr arguments, got {type(key).__name__}"
            )
        group_exprs.append(key if isinstance(key, Expr) else _col(key))

    return GroupedDataFrame(self, group_exprs)

def limit(self, n: Optional[int], /, *, offset: int = 0) -> "DataFrame":
"""Limit result to n rows starting at offset

Expand Down Expand Up @@ -1226,6 +1270,57 @@ def _scan_collected_default(ctx_impl, obj, schema, options):
return _scan_default(ctx_impl, obj, schema, options).to_memtable()


class GroupedDataFrame:
    """Intermediate result of `DataFrame.group_by(...)`.

    Holds the source `DataFrame` together with its group key
    expressions so that a following `.agg(...)` call can run the
    grouped aggregation. It is only a step in the method chain.
    """

    __slots__ = ("_df", "_group_exprs")

    def __init__(self, df: DataFrame, group_exprs: List[Expr]):
        self._df, self._group_exprs = df, group_exprs

    def agg(self, *exprs: Expr, **named_exprs: Expr) -> DataFrame:
        """Compute aggregates for every group.

        Takes the same arguments as `DataFrame.agg`: aggregate `Expr`s
        given positionally and/or as keywords, where each keyword is
        used as the name of the resulting output column.

        Args:
            *exprs: Positional aggregate expressions.
            **named_exprs: Keyword aggregate expressions; each keyword
                becomes the output alias.

        Raises:
            ValueError: If no aggregate expression is given at all.
            TypeError: If any positional or keyword value is not an
                `Expr`.
        """
        if not (exprs or named_exprs):
            raise ValueError("agg() requires at least one aggregate expression")

        for positional in exprs:
            if not isinstance(positional, Expr):
                raise TypeError(
                    f"agg() expects Expr arguments, got {type(positional).__name__}"
                )

        def _aliased(name, value):
            # Validate one keyword aggregate, then name it after its keyword.
            if not isinstance(value, Expr):
                raise TypeError(
                    f"agg() expects Expr keyword values, got {type(value).__name__} "
                    f"for keyword {name!r}"
                )
            return value.alias(name)

        # Positional aggregates come first, followed by the keyword
        # aggregates in the order they were passed.
        aggregates: List[Expr] = [
            *exprs,
            *(_aliased(name, value) for name, value in named_exprs.items()),
        ]

        parent = self._df
        ctx = parent._ctx
        aggregated_impl = parent._impl.aggregate(
            [key._impl for key in self._group_exprs],
            [aggregate._impl for aggregate in aggregates],
        )
        return DataFrame(ctx, aggregated_impl, parent._options)


def _scan_geopandas(ctx_impl, obj, schema, options):
return _scan_collected_default(
ctx_impl, obj.to_arrow(geometry_encoding="WKB"), schema, options
Expand Down
136 changes: 136 additions & 0 deletions python/sedonadb/tests/expr/test_dataframe_group_by.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pandas as pd
import pandas.testing as pdt
import pytest

from sedonadb.dataframe import DataFrame, GroupedDataFrame
from sedonadb.expr import col


def test_group_by_single_key_string(con):
    """Grouping by one column name sums `v` within each `k`."""
    source = pd.DataFrame({"k": ["a", "a", "b", "b"], "v": [1, 2, 3, 4]})
    grouped = con.create_data_frame(source).group_by("k")
    totals = grouped.agg(total=con.funcs.sum(col("v")))
    result = totals.sort("k").to_pandas()

    expected = pd.DataFrame({"k": ["a", "b"], "total": [3, 7]})
    pdt.assert_frame_equal(result, expected)


def test_group_by_returns_grouped_dataframe(con):
    """`group_by` returns a `GroupedDataFrame` instance."""
    frame = con.create_data_frame(pd.DataFrame({"k": ["a"], "v": [1]}))
    assert isinstance(frame.group_by("k"), GroupedDataFrame)


def test_group_by_multiple_keys(con):
    """Two string keys produce one output row per distinct (k1, k2) pair."""
    source = pd.DataFrame(
        {
            "k1": ["a", "a", "a", "b"],
            "k2": ["x", "x", "y", "y"],
            "v": [1, 2, 3, 4],
        }
    )
    grouped = con.create_data_frame(source).group_by("k1", "k2")
    totals = grouped.agg(total=con.funcs.sum(col("v")))
    result = totals.sort("k1", "k2").to_pandas()

    expected = pd.DataFrame(
        {"k1": ["a", "a", "b"], "k2": ["x", "y", "y"], "total": [3, 3, 4]}
    )
    pdt.assert_frame_equal(result, expected)


def test_group_by_expr_key(con):
    """A `col("k")` key yields the per-key totals expected for column `k`."""
    source = pd.DataFrame({"k": ["a", "a", "b"], "v": [1, 2, 3]})
    grouped = con.create_data_frame(source).group_by(col("k"))
    totals = grouped.agg(total=con.funcs.sum(col("v")))
    result = totals.sort("k").to_pandas()

    expected = pd.DataFrame({"k": ["a", "b"], "total": [3, 3]})
    pdt.assert_frame_equal(result, expected)


def test_group_by_computed_expr_key(con):
    """An arithmetic expression can serve as the group key.

    Rows are grouped by `x + y`: (1, 9), (4, 6) and (5, 5) all sum to
    10, while (2, 3) sums to 5 and forms a group of its own.
    """
    source = pd.DataFrame({"x": [1, 4, 5, 2], "y": [9, 6, 5, 3]})
    frame = con.create_data_frame(source)
    sum_key = (col("x") + col("y")).alias("xy")
    counts = frame.group_by(sum_key).agg(n=con.funcs.count(col("x")))
    result = counts.sort("xy").to_pandas()

    expected = pd.DataFrame({"xy": [5, 10], "n": [1, 3]})
    pdt.assert_frame_equal(result, expected)


def test_group_by_mixed_string_and_expr(con):
    """String and `Expr` keys can be combined in a single `group_by` call."""
    source = pd.DataFrame({"k": ["a", "a", "b"], "v": [1, 2, 3]})
    grouped = con.create_data_frame(source).group_by("k", col("v") > 1)
    result = grouped.agg(n=con.funcs.count(col("v"))).to_pandas()

    # The (k, v > 1) pairs are (a, false), (a, true) and (b, true):
    # three distinct groups holding one row each.
    assert len(result) == 3
    group_sizes = sorted(result["n"].tolist())
    assert group_sizes == [1, 1, 1]


def test_group_by_agg_positional_and_kwarg(con):
    """Positional and keyword aggregates can be mixed in one `agg` call.

    The expected frame lists the group key first, then the positional
    aggregate, then the keyword aggregate.
    """
    source = pd.DataFrame({"k": ["a", "a", "b"], "v": [1, 2, 3]})
    grouped = con.create_data_frame(source).group_by("k")
    aggregated = grouped.agg(
        con.funcs.sum(col("v")).alias("sum_v"),
        n=con.funcs.count(col("v")),
    )
    result = aggregated.sort("k").to_pandas()

    expected = pd.DataFrame({"k": ["a", "b"], "sum_v": [3, 3], "n": [2, 1]})
    pdt.assert_frame_equal(result, expected)


def test_group_by_agg_returns_lazy_dataframe(con):
    """`GroupedDataFrame.agg` returns a `DataFrame` instance."""
    frame = con.create_data_frame(pd.DataFrame({"k": ["a"], "v": [1]}))
    aggregated = frame.group_by("k").agg(total=con.funcs.sum(col("v")))
    assert isinstance(aggregated, DataFrame)


def test_group_by_empty_raises(con):
    """Calling `group_by` without any key raises `ValueError`."""
    frame = con.create_data_frame(pd.DataFrame({"k": ["a"], "v": [1]}))
    expect_error = pytest.raises(ValueError, match="at least one key")
    with expect_error:
        frame.group_by()


def test_group_by_bad_key_type_raises(con):
    """A key that is neither `str` nor `Expr` raises `TypeError`."""
    frame = con.create_data_frame(pd.DataFrame({"k": ["a"], "v": [1]}))
    expect_error = pytest.raises(TypeError, match="str or Expr")
    with expect_error:
        frame.group_by(123)


def test_grouped_agg_empty_raises(con):
    """Calling `agg` on a grouped frame with no expressions raises `ValueError`."""
    frame = con.create_data_frame(pd.DataFrame({"k": ["a"], "v": [1]}))
    expect_error = pytest.raises(
        ValueError, match="at least one aggregate expression"
    )
    with expect_error:
        frame.group_by("k").agg()


def test_grouped_agg_non_expr_raises(con):
    """A positional `agg` argument that is not an `Expr` raises `TypeError`."""
    frame = con.create_data_frame(pd.DataFrame({"k": ["a"], "v": [1]}))
    expect_error = pytest.raises(
        TypeError, match="agg\\(\\) expects Expr arguments"
    )
    with expect_error:
        frame.group_by("k").agg("not an expr")