Skip to content

Commit 17a982b

Browse files
committed
Support re-use of accumulator states and multiple evaluations
1 parent 1fd0d27 commit 17a982b

File tree

30 files changed

+191
-103
lines changed

30 files changed

+191
-103
lines changed

datafusion/expr-common/src/accumulator.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@ pub trait Accumulator: Send + Sync + Debug {
5858
/// running sum.
5959
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
6060

61-
/// Returns the final aggregate value, consuming the internal state.
61+
/// Returns the final aggregate value, without modifying the internal state.
62+
/// This method is safe to call multiple times.
63+
fn evaluate(&self) -> Result<ScalarValue>;
64+
65+
/// Returns the final aggregate value, possibly consuming the internal state.
6266
///
6367
/// For example, the `SUM` accumulator maintains a running sum,
6468
/// and `evaluate` will produce that running sum as its output.
@@ -69,7 +73,9 @@ pub trait Accumulator: Send + Sync + Debug {
6973
/// This function gets `&mut self` to allow for the accumulator to build
7074
/// arrow-compatible internal state that can be returned without copying
7175
/// when possible (for example distinct strings)
72-
fn evaluate(&mut self) -> Result<ScalarValue>;
76+
fn evaluate_mut(&mut self) -> Result<ScalarValue> {
77+
self.evaluate()
78+
}
7379

7480
/// Returns the allocated size required for this accumulator, in
7581
/// bytes, including `Self`.
@@ -248,7 +254,13 @@ pub trait Accumulator: Send + Sync + Debug {
248254
/// group values group values
249255
/// in partition 0 in partition 1
250256
/// ```
251-
fn state(&mut self) -> Result<Vec<ScalarValue>>;
257+
fn state_mut(&mut self) -> Result<Vec<ScalarValue>> {
258+
self.state()
259+
}
260+
261+
/// Returns the internal state without consuming it; this method is safe to
262+
/// call multiple times.
263+
fn state(&self) -> Result<Vec<ScalarValue>>;
252264

253265
/// Updates the accumulator's state from an `Array` containing one
254266
/// or more intermediate values.

datafusion/ffi/src/udaf/accumulator.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ pub struct FFI_Accumulator {
4545

4646
// Evaluate and return a ScalarValue as protobuf bytes
4747
pub evaluate:
48-
unsafe extern "C" fn(accumulator: &mut Self) -> RResult<RVec<u8>, RString>,
48+
unsafe extern "C" fn(accumulator: &Self) -> RResult<RVec<u8>, RString>,
4949

5050
pub size: unsafe extern "C" fn(accumulator: &Self) -> usize,
5151

5252
pub state:
53-
unsafe extern "C" fn(accumulator: &mut Self) -> RResult<RVec<RVec<u8>>, RString>,
53+
unsafe extern "C" fn(accumulator: &Self) -> RResult<RVec<RVec<u8>>, RString>,
5454

5555
pub merge_batch: unsafe extern "C" fn(
5656
accumulator: &mut Self,
@@ -109,9 +109,9 @@ unsafe extern "C" fn update_batch_fn_wrapper(
109109
}
110110

111111
unsafe extern "C" fn evaluate_fn_wrapper(
112-
accumulator: &mut FFI_Accumulator,
112+
accumulator: &FFI_Accumulator,
113113
) -> RResult<RVec<u8>, RString> {
114-
let accumulator = accumulator.inner_mut();
114+
let accumulator = accumulator.inner();
115115

116116
let scalar_result = rresult_return!(accumulator.evaluate());
117117
let proto_result: datafusion_proto::protobuf::ScalarValue =
@@ -125,9 +125,9 @@ unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize {
125125
}
126126

127127
unsafe extern "C" fn state_fn_wrapper(
128-
accumulator: &mut FFI_Accumulator,
128+
accumulator: &FFI_Accumulator,
129129
) -> RResult<RVec<RVec<u8>>, RString> {
130-
let accumulator = accumulator.inner_mut();
130+
let accumulator = accumulator.inner();
131131

132132
let state = rresult_return!(accumulator.state());
133133
let state = state
@@ -237,10 +237,10 @@ impl Accumulator for ForeignAccumulator {
237237
}
238238
}
239239

240-
fn evaluate(&mut self) -> Result<ScalarValue> {
240+
fn evaluate(&self) -> Result<ScalarValue> {
241241
unsafe {
242242
let scalar_bytes =
243-
df_result!((self.accumulator.evaluate)(&mut self.accumulator))?;
243+
df_result!((self.accumulator.evaluate)(&self.accumulator))?;
244244

245245
let proto_scalar =
246246
datafusion_proto::protobuf::ScalarValue::decode(scalar_bytes.as_ref())
@@ -254,10 +254,10 @@ impl Accumulator for ForeignAccumulator {
254254
unsafe { (self.accumulator.size)(&self.accumulator) }
255255
}
256256

257-
fn state(&mut self) -> Result<Vec<ScalarValue>> {
257+
fn state(&self) -> Result<Vec<ScalarValue>> {
258258
unsafe {
259259
let state_protos =
260-
df_result!((self.accumulator.state)(&mut self.accumulator))?;
260+
df_result!((self.accumulator.state)(&self.accumulator))?;
261261

262262
state_protos
263263
.into_iter()

datafusion/ffi/src/udaf/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ unsafe extern "C" fn state_fields_fn_wrapper(
274274
return_field,
275275
ordering_fields,
276276
is_distinct,
277+
for_sliding: false,
277278
};
278279

279280
let state_fields = rresult_return!(udaf.state_fields(args));
@@ -682,6 +683,7 @@ mod tests {
682683
return_field: Field::new("f", DataType::Float64, true).into(),
683684
ordering_fields: &[Arc::clone(&a_field)],
684685
is_distinct: false,
686+
for_sliding: false,
685687
})?;
686688

687689
assert_eq!(state_fields.len(), 3);

datafusion/functions-aggregate-common/src/accumulator.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ pub struct StateFieldsArgs<'a> {
9999

100100
/// Whether the aggregate function is distinct.
101101
pub is_distinct: bool,
102+
103+
/// Whether the state fields are being requested for the sliding accumulator, if one exists.
104+
pub for_sliding: bool,
102105
}
103106

104107
impl StateFieldsArgs<'_> {

datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
use arrow::array::{ArrayRef, OffsetSizeTrait};
2121
use datafusion_common::cast::as_list_array;
2222
use datafusion_common::utils::SingleRowListArrayBuilder;
23-
use datafusion_common::ScalarValue;
23+
use datafusion_common::{exec_err, ScalarValue};
2424
use datafusion_expr_common::accumulator::Accumulator;
2525
use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType};
2626
use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewSet;
@@ -45,12 +45,16 @@ impl<O: OffsetSizeTrait> BytesDistinctCountAccumulator<O> {
4545
}
4646

4747
impl<O: OffsetSizeTrait> Accumulator for BytesDistinctCountAccumulator<O> {
48-
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
48+
fn state_mut(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
4949
let set = self.0.take();
5050
let arr = set.into_state();
5151
Ok(vec![SingleRowListArrayBuilder::new(arr).build_list_scalar()])
5252
}
5353

54+
fn state(&self) -> datafusion_common::Result<Vec<ScalarValue>> {
55+
exec_err!("immutable state not supported for BytesDistinctCount")
56+
}
57+
5458
fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
5559
if values.is_empty() {
5660
return Ok(());
@@ -80,7 +84,7 @@ impl<O: OffsetSizeTrait> Accumulator for BytesDistinctCountAccumulator<O> {
8084
})
8185
}
8286

83-
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
87+
fn evaluate(&self) -> datafusion_common::Result<ScalarValue> {
8488
Ok(ScalarValue::Int64(Some(self.0.non_null_len() as i64)))
8589
}
8690

@@ -104,12 +108,16 @@ impl BytesViewDistinctCountAccumulator {
104108
}
105109

106110
impl Accumulator for BytesViewDistinctCountAccumulator {
107-
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
111+
fn state_mut(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
108112
let set = self.0.take();
109113
let arr = set.into_state();
110114
Ok(vec![SingleRowListArrayBuilder::new(arr).build_list_scalar()])
111115
}
112116

117+
fn state(&self) -> datafusion_common::Result<Vec<ScalarValue>> {
118+
exec_err!("immutable state not supported for BytesViewDistinctCount")
119+
}
120+
113121
fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
114122
if values.is_empty() {
115123
return Ok(());
@@ -139,7 +147,7 @@ impl Accumulator for BytesViewDistinctCountAccumulator {
139147
})
140148
}
141149

142-
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
150+
fn evaluate(&self) -> datafusion_common::Result<ScalarValue> {
143151
Ok(ScalarValue::Int64(Some(self.0.non_null_len() as i64)))
144152
}
145153

datafusion/functions-aggregate-common/src/aggregate/count_distinct/dict.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,15 @@ impl Accumulator for DictionaryCountAccumulator {
5252
self.inner.update_batch(values.as_slice())
5353
}
5454

55-
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
55+
fn evaluate(&self) -> datafusion_common::Result<ScalarValue> {
5656
self.inner.evaluate()
5757
}
5858

5959
fn size(&self) -> usize {
6060
self.inner.size()
6161
}
6262

63-
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
63+
fn state(&self) -> datafusion_common::Result<Vec<ScalarValue>> {
6464
self.inner.state()
6565
}
6666

datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ where
6868
T: ArrowPrimitiveType + Send + Debug,
6969
T::Native: Eq + Hash,
7070
{
71-
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
71+
fn state(&self) -> datafusion_common::Result<Vec<ScalarValue>> {
7272
let arr = Arc::new(
7373
PrimitiveArray::<T>::from_iter_values(self.values.iter().cloned())
7474
.with_data_type(self.data_type.clone()),
@@ -111,7 +111,7 @@ where
111111
})
112112
}
113113

114-
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
114+
fn evaluate(&self) -> datafusion_common::Result<ScalarValue> {
115115
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
116116
}
117117

@@ -155,7 +155,7 @@ impl<T> Accumulator for FloatDistinctCountAccumulator<T>
155155
where
156156
T: ArrowPrimitiveType + Send + Debug,
157157
{
158-
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
158+
fn state(&self) -> datafusion_common::Result<Vec<ScalarValue>> {
159159
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
160160
self.values.iter().map(|v| v.0),
161161
)) as ArrayRef;
@@ -198,7 +198,7 @@ where
198198
})
199199
}
200200

201-
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
201+
fn evaluate(&self) -> datafusion_common::Result<ScalarValue> {
202202
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
203203
}
204204

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
325325

326326
let results: Vec<ScalarValue> = states
327327
.into_iter()
328-
.map(|mut state| {
328+
.map(|state| {
329329
self.free_allocation(state.size());
330330
state.accumulator.evaluate()
331331
})
@@ -347,7 +347,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
347347
// which we need to form into columns
348348
let mut results: Vec<Vec<ScalarValue>> = vec![];
349349

350-
for mut state in states {
350+
for state in states {
351351
self.free_allocation(state.size());
352352
let accumulator_state = state.accumulator.state()?;
353353
results.resize_with(accumulator_state.len(), Vec::new);

datafusion/functions-aggregate/src/approx_distinct.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,12 +186,12 @@ macro_rules! default_accumulator_impl {
186186
Ok(())
187187
}
188188

189-
fn state(&mut self) -> Result<Vec<ScalarValue>> {
189+
fn state(&self) -> Result<Vec<ScalarValue>> {
190190
let value = ScalarValue::from(&self.hll);
191191
Ok(vec![value])
192192
}
193193

194-
fn evaluate(&mut self) -> Result<ScalarValue> {
194+
fn evaluate(&self) -> Result<ScalarValue> {
195195
Ok(ScalarValue::UInt64(Some(self.hll.count() as u64)))
196196
}
197197

datafusion/functions-aggregate/src/approx_percentile_cont.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ impl ApproxPercentileAccumulator {
457457
}
458458

459459
impl Accumulator for ApproxPercentileAccumulator {
460-
fn state(&mut self) -> Result<Vec<ScalarValue>> {
460+
fn state(&self) -> Result<Vec<ScalarValue>> {
461461
Ok(self.digest.to_scalar_state().into_iter().collect())
462462
}
463463

@@ -473,7 +473,7 @@ impl Accumulator for ApproxPercentileAccumulator {
473473
Ok(())
474474
}
475475

476-
fn evaluate(&mut self) -> Result<ScalarValue> {
476+
fn evaluate(&self) -> Result<ScalarValue> {
477477
if self.digest.count() == 0 {
478478
return ScalarValue::try_from(self.return_type.clone());
479479
}

0 commit comments

Comments
 (0)