Skip to content
Open
988 changes: 960 additions & 28 deletions datafusion/optimizer/src/eliminate_join.rs

Large diffs are not rendered by default.

28 changes: 4 additions & 24 deletions datafusion/optimizer/src/optimize_projections/required_indices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

//! [`RequiredIndices`] helper for OptimizeProjection

use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use crate::utils::for_each_referenced_index;
use datafusion_common::tree_node::TreeNodeRecursion;
use datafusion_common::{Column, DFSchemaRef, Result};
use datafusion_expr::{Expr, LogicalPlan};

Expand Down Expand Up @@ -112,29 +113,8 @@ impl RequiredIndices {
/// * `input_schema`: The input schema to analyze for index requirements.
/// * `expr`: An expression for which we want to find necessary field indices.
fn add_expr(&mut self, input_schema: &DFSchemaRef, expr: &Expr) {
// `apply` does not descend into subqueries, so recurse manually to
// handle those cases.
expr.apply(|e| {
match e {
Expr::Column(c) | Expr::OuterReferenceColumn(_, c) => {
if let Some(idx) = input_schema.maybe_index_of_column(c) {
self.indices.push(idx);
}
}
Expr::ScalarSubquery(sub) => {
self.add_exprs(input_schema, &sub.outer_ref_columns);
}
Expr::Exists(ex) => {
self.add_exprs(input_schema, &ex.subquery.outer_ref_columns);
}
Expr::InSubquery(isq) => {
self.add_exprs(input_schema, &isq.subquery.outer_ref_columns);
}
_ => {}
}
Ok(TreeNodeRecursion::Continue)
})
.expect("traversal is infallible");
for_each_referenced_index(expr, input_schema, |idx| self.indices.push(idx))
.expect("traversal is infallible");
}

/// Like [`Self::add_expr`], but for multiple expressions.
Expand Down
53 changes: 52 additions & 1 deletion datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ use arrow::array::{Array, RecordBatch, new_null_array};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::TableReference;
use datafusion_common::cast::as_boolean_array;
use datafusion_common::tree_node::{TransformedResult, TreeNode};
use datafusion_common::tree_node::{TransformedResult, TreeNode, TreeNodeRecursion};
use datafusion_common::{Column, DFSchema, Result, ScalarValue};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr::{Exists, InSubquery, SetComparison};
use datafusion_expr::expr_rewriter::replace_col;
use datafusion_expr::{ColumnarValue, Expr, logical_plan::LogicalPlan};
use datafusion_physical_expr::create_physical_expr;
Expand All @@ -37,6 +38,56 @@ use std::sync::Arc;
/// as it was initially placed here and then moved elsewhere.
pub use datafusion_expr::expr_rewriter::NamePreserver;

/// Invokes `f` with the index, within `schema`, of every column referenced by
/// `expr` — including columns reached through a correlated subquery's outer
/// references. Columns absent from `schema` are skipped.
///
/// A subquery's own plan is intentionally not traversed: its internal columns
/// index into its own schema, not `schema`; only the outer (correlated) columns
/// it references from `schema` are relevant. The comparison expression of an
/// `IN`/set-comparison subquery is reached by the normal expression walk.
///
/// This is the shared primitive behind the top-down "which of a node's output
/// columns does an ancestor still need" analyses, namely
/// [`OptimizeProjections`](crate::optimize_projections::OptimizeProjections)
/// and [`EliminateJoin`](crate::eliminate_join::EliminateJoin). The two keep
/// their own required-index containers (an ordered set vs. a hash set), so this
/// reports indices through a callback rather than populating a shared type.
pub(crate) fn for_each_referenced_index(
expr: &Expr,
schema: &DFSchema,
mut f: impl FnMut(usize),
) -> Result<()> {
visit_referenced_indices(expr, schema, &mut f)
}

fn visit_referenced_indices(
expr: &Expr,
schema: &DFSchema,
f: &mut dyn FnMut(usize),
) -> Result<()> {
expr.apply(|expr| {
match expr {
Expr::Column(column) | Expr::OuterReferenceColumn(_, column) => {
if let Some(idx) = schema.maybe_index_of_column(column) {
f(idx);
}
}
Expr::Exists(Exists { subquery, .. })
| Expr::InSubquery(InSubquery { subquery, .. })
| Expr::SetComparison(SetComparison { subquery, .. })
| Expr::ScalarSubquery(subquery) => {
for outer in &subquery.outer_ref_columns {
visit_referenced_indices(outer, schema, f)?;
}
}
_ => {}
}
Ok(TreeNodeRecursion::Continue)
})?;
Ok(())
}

/// Returns true if `expr` contains all columns in `schema_cols`
pub(crate) fn has_all_column_refs(
expr: &Expr,
Expand Down
57 changes: 47 additions & 10 deletions datafusion/sqllogictest/test_files/joins.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1333,19 +1333,57 @@ inner join join_t2 on join_t1.t1_id = join_t2.t2_id
----
logical_plan
01)Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[]]
02)--Projection: join_t1.t1_id
03)----Inner Join: join_t1.t1_id = join_t2.t2_id
04)------TableScan: join_t1 projection=[t1_id]
05)------TableScan: join_t2 projection=[t2_id]
02)--LeftSemi Join: join_t1.t1_id = join_t2.t2_id
03)----TableScan: join_t1 projection=[t1_id]
04)----TableScan: join_t2 projection=[t2_id]
physical_plan
01)AggregateExec: mode=FinalPartitioned, gby=[t1_id@0 as t1_id], aggr=[]
02)--RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
03)----AggregateExec: mode=Partial, gby=[t1_id@0 as t1_id], aggr=[]
04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0]
04)------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)]
05)--------DataSourceExec: partitions=1, partition_sizes=[1]
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)----------DataSourceExec: partitions=1, partition_sizes=[1]

statement ok
set datafusion.explain.logical_plan_only = true;

# A single `count(DISTINCT col)` over a join whose other side is used only as an
# existence filter can be rewritten to a semi join.
query TT
EXPLAIN
select join_t1.t1_id, count(distinct join_t1.t1_int)
from join_t1
inner join join_t2 on join_t1.t1_id = join_t2.t2_id
group by join_t1.t1_id
----
logical_plan
01)Projection: join_t1.t1_id, count(alias1) AS count(DISTINCT join_t1.t1_int)
02)--Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[count(alias1)]]
03)----Aggregate: groupBy=[[join_t1.t1_id, join_t1.t1_int AS alias1]], aggr=[[]]
04)------LeftSemi Join: join_t1.t1_id = join_t2.t2_id
05)--------TableScan: join_t1 projection=[t1_id, t1_int]
06)--------TableScan: join_t2 projection=[t2_id]

# A similar query with two DISTINCT aggregates is currently not rewritten
# TODO: https://github.com/apache/datafusion/issues/22644
query TT
EXPLAIN
select join_t1.t1_id, count(distinct join_t1.t1_int), count(distinct join_t1.t1_name)
from join_t1
inner join join_t2 on join_t1.t1_id = join_t2.t2_id
group by join_t1.t1_id
----
logical_plan
01)Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[count(DISTINCT join_t1.t1_int), count(DISTINCT join_t1.t1_name)]]
02)--Projection: join_t1.t1_id, join_t1.t1_name, join_t1.t1_int
03)----Inner Join: join_t1.t1_id = join_t2.t2_id
04)------TableScan: join_t1 projection=[t1_id, t1_name, t1_int]
05)------TableScan: join_t2 projection=[t2_id]

statement ok
set datafusion.explain.logical_plan_only = false;

# Join on struct
query TT
explain select join_t3.s3, join_t4.s4
Expand Down Expand Up @@ -1411,10 +1449,9 @@ logical_plan
01)Projection: count(alias1) AS count(DISTINCT join_t1.t1_id)
02)--Aggregate: groupBy=[[]], aggr=[[count(alias1)]]
03)----Aggregate: groupBy=[[join_t1.t1_id AS alias1]], aggr=[[]]
04)------Projection: join_t1.t1_id
05)--------Inner Join: join_t1.t1_id = join_t2.t2_id
06)----------TableScan: join_t1 projection=[t1_id]
07)----------TableScan: join_t2 projection=[t2_id]
04)------LeftSemi Join: join_t1.t1_id = join_t2.t2_id
05)--------TableScan: join_t1 projection=[t1_id]
06)--------TableScan: join_t2 projection=[t2_id]
physical_plan
01)ProjectionExec: expr=[count(alias1)@0 as count(DISTINCT join_t1.t1_id)]
02)--AggregateExec: mode=Final, gby=[], aggr=[count(alias1)]
Expand All @@ -1423,7 +1460,7 @@ physical_plan
05)--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[]
06)----------RepartitionExec: partitioning=Hash([alias1@0], 2), input_partitions=2
07)------------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[]
08)--------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0]
08)--------------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)]
09)----------------DataSourceExec: partitions=1, partition_sizes=[1]
10)----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
11)------------------DataSourceExec: partitions=1, partition_sizes=[1]
Expand Down
14 changes: 7 additions & 7 deletions datafusion/sqllogictest/test_files/subquery.slt
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,13 @@ where c_acctbal < (
logical_plan
01)Sort: customer.c_custkey ASC NULLS LAST
02)--Projection: customer.c_custkey
03)----Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.sum(orders.o_totalprice)
03)----LeftSemi Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.sum(orders.o_totalprice)
04)------TableScan: customer projection=[c_custkey, c_acctbal]
05)------SubqueryAlias: __scalar_sq_1
06)--------Projection: sum(orders.o_totalprice), orders.o_custkey
07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]]
08)------------Projection: orders.o_custkey, orders.o_totalprice
09)--------------Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.price
09)--------------LeftSemi Join: orders.o_orderkey = __scalar_sq_2.l_orderkey Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.price
10)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice]
11)----------------SubqueryAlias: __scalar_sq_2
12)------------------Projection: sum(lineitem.l_extendedprice) AS price, lineitem.l_orderkey
Expand Down Expand Up @@ -555,7 +555,7 @@ logical_plan
02)--TableScan: t0 projection=[t0_id, t0_name]
03)--SubqueryAlias: __correlated_sq_2
04)----Projection: t1.t1_name
05)------Inner Join: t1.t1_id = t2.t2_id
05)------LeftSemi Join: t1.t1_id = t2.t2_id
06)--------TableScan: t1 projection=[t1_id, t1_name]
07)--------TableScan: t2 projection=[t2_id]

Expand All @@ -568,7 +568,7 @@ logical_plan
02)--TableScan: t0 projection=[t0_id, t0_name]
03)--SubqueryAlias: __correlated_sq_1
04)----Projection: t2.t2_name
05)------Inner Join: t1.t1_id = t2.t2_id
05)------RightSemi Join: t1.t1_id = t2.t2_id
06)--------TableScan: t1 projection=[t1_id]
07)--------SubqueryAlias: t2
08)----------TableScan: t2 projection=[t2_id, t2_name]
Expand Down Expand Up @@ -1675,7 +1675,7 @@ where c_acctbal < (
logical_plan
01)Sort: customer.c_custkey ASC NULLS LAST
02)--Projection: customer.c_custkey
03)----Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice)
03)----LeftSemi Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice)
04)------TableScan: customer projection=[c_custkey, c_acctbal]
05)------SubqueryAlias: __scalar_sq_2
06)--------Projection: sum(orders.o_totalprice), orders.o_custkey
Expand All @@ -1701,7 +1701,7 @@ where c_acctbal < (
logical_plan
01)Sort: customer.c_custkey ASC NULLS LAST
02)--Projection: customer.c_custkey
03)----Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice)
03)----LeftSemi Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice)
04)------TableScan: customer projection=[c_custkey, c_acctbal]
05)------SubqueryAlias: __scalar_sq_2
06)--------Projection: sum(orders.o_totalprice), orders.o_custkey
Expand Down Expand Up @@ -1746,7 +1746,7 @@ WHERE e1.salary > (
----
logical_plan
01)Projection: e1.employee_name, e1.salary
02)--Inner Join: e1.dept_id = __scalar_sq_1.dept_id Filter: CAST(e1.salary AS Decimal128(38, 14)) > __scalar_sq_1.avg(e2.salary)
02)--LeftSemi Join: e1.dept_id = __scalar_sq_1.dept_id Filter: CAST(e1.salary AS Decimal128(38, 14)) > __scalar_sq_1.avg(e2.salary)
03)----SubqueryAlias: e1
04)------TableScan: employees projection=[employee_name, dept_id, salary]
05)----SubqueryAlias: __scalar_sq_1
Expand Down
8 changes: 4 additions & 4 deletions datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ logical_plan
05)--------Projection: CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15))
06)----------Aggregate: groupBy=[[]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]]
07)------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost
08)--------------Inner Join: supplier.s_nationkey = nation.n_nationkey
08)--------------LeftSemi Join: supplier.s_nationkey = nation.n_nationkey
09)----------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey
10)------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
11)--------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost]
Expand All @@ -64,7 +64,7 @@ logical_plan
15)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")]
16)------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]]
17)--------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost
18)----------Inner Join: supplier.s_nationkey = nation.n_nationkey
18)----------LeftSemi Join: supplier.s_nationkey = nation.n_nationkey
19)------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey
20)--------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
21)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost]
Expand All @@ -81,7 +81,7 @@ physical_plan
06)----------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)]
07)------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4
08)--------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)]
09)----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[ps_partkey@0, ps_availqty@1, ps_supplycost@2]
09)----------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(s_nationkey@3, n_nationkey@0)], projection=[ps_partkey@0, ps_availqty@1, ps_supplycost@2]
10)------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4
11)--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)], projection=[ps_partkey@0, ps_availqty@2, ps_supplycost@3, s_nationkey@5]
12)----------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4
Expand All @@ -96,7 +96,7 @@ physical_plan
21)----AggregateExec: mode=Final, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)]
22)------CoalescePartitionsExec
23)--------AggregateExec: mode=Partial, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)]
24)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_availqty@0, ps_supplycost@1]
24)----------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_availqty@0, ps_supplycost@1]
25)------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4
26)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@0, s_suppkey@0)], projection=[ps_availqty@1, ps_supplycost@2, s_nationkey@4]
27)----------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4
Expand Down
Loading
Loading