diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 1bfecd06c2228..341e2d0eaf396 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -705,20 +705,36 @@ impl LogicalPlan {
                 }))
             }
             LogicalPlan::Union(Union { inputs, schema }) => {
+                // Fast path: check structural compatibility against all inputs.
+                //
+                // For position-based Union (try_new), schema names come exclusively
+                // from inputs[0], so we check names/qualifiers only against the first
+                // input. Data types and nullability must match across every input.
                 let first_input_schema = inputs[0].schema();
-                if schema.fields().len() == first_input_schema.fields().len() {
-                    // If inputs are not pruned do not change schema
-                    Ok(LogicalPlan::Union(Union { inputs, schema }))
-                } else {
-                    // A note on `Union`s constructed via `try_new_by_name`:
-                    //
-                    // At this point, the schema for each input should have
-                    // the same width. Thus, we do not need to save whether a
-                    // `Union` was created `BY NAME`, and can safely rely on the
-                    // `try_new` initializer to derive the new schema based on
-                    // column positions.
-                    Ok(LogicalPlan::Union(Union::try_new(inputs)?))
+                let names_match =
+                    schema.fields().len() == first_input_schema.fields().len()
+                        && schema.iter().zip(first_input_schema.iter()).all(
+                            |((q1, f1), (q2, f2))| q1 == q2 && f1.name() == f2.name(),
+                        );
+
+                let types_match = names_match
+                    && inputs.iter().all(|input| {
+                        let input_schema = input.schema();
+                        schema.fields().len() == input_schema.fields().len()
+                            && schema.iter().zip(input_schema.iter()).all(
+                                |((_, f1), (_, f2))| {
+                                    f1.data_type() == f2.data_type()
+                                        && f1.is_nullable() == f2.is_nullable()
+                                },
+                            )
+                    });
+
+                if types_match {
+                    return Ok(LogicalPlan::Union(Union { inputs, schema }));
                 }
+
+                // Slow path: recompute schema with metadata preservation.
+                Ok(LogicalPlan::Union(Union::try_new_with_metadata(inputs)?))
             }
             LogicalPlan::Distinct(distinct) => {
                 let distinct = match distinct {
@@ -3156,6 +3172,52 @@ impl Union {
         Ok(Union { inputs, schema })
     }
 
+    /// Constructs a new Union from inputs, deriving the schema by position
+    /// (like `try_new`) but preserving schema-level and field-level metadata
+    /// using "later takes precedence" (extend) semantics — matching the
+    /// behavior of `coerce_union_schema_with_schema`.
+    pub fn try_new_with_metadata(inputs: Vec<Arc<LogicalPlan>>) -> Result<Self> {
+        let mut union = Self::try_new(inputs)?;
+
+        // Merge schema-level metadata: later inputs take precedence.
+        let mut merged_schema_meta = union.inputs[0].schema().metadata().clone();
+        for input in union.inputs.iter().skip(1) {
+            merged_schema_meta.extend(input.schema().metadata().clone());
+        }
+
+        // Merge field-level metadata: later inputs take precedence per field.
+        let mut merged_field_meta: Vec<_> = union.inputs[0]
+            .schema()
+            .fields()
+            .iter()
+            .map(|f| f.metadata().clone())
+            .collect();
+        for input in union.inputs.iter().skip(1) {
+            for (field_meta, input_field) in
+                merged_field_meta.iter_mut().zip(input.schema().fields())
+            {
+                field_meta.extend(input_field.metadata().clone());
+            }
+        }
+
+        // Rebuild schema with merged metadata applied to each field.
+        let new_fields = union
+            .schema
+            .iter()
+            .zip(merged_field_meta)
+            .map(|((qualifier, field), meta)| {
+                let mut field = field.as_ref().clone();
+                field.set_metadata(meta);
+                (qualifier.cloned(), Arc::new(field))
+            })
+            .collect::<Vec<_>>();
+
+        union.schema =
+            Arc::new(DFSchema::new_with_metadata(new_fields, merged_schema_meta)?);
+
+        Ok(union)
+    }
+
     /// When constructing a `UNION BY NAME`, we need to wrap inputs
     /// in an additional `Projection` to account for absence of columns
     /// in input schemas or differing projection orders.
@@ -6285,4 +6347,226 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_recompute_schema_union_type_mismatch() -> Result<()> {
+        use arrow::datatypes::{DataType, Field, Schema};
+
+        let schema_i32 = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        let schema_i64 = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
+
+        // Build a Union whose schema starts out as Int32 (matching its inputs).
+        let original = Union::try_new(vec![
+            Arc::new(table_scan(Some("t1"), &schema_i32, None)?.build()?),
+            Arc::new(table_scan(Some("t2"), &schema_i32, None)?.build()?),
+        ])?;
+        assert_eq!(
+            original.schema.field(0).data_type(),
+            &DataType::Int32,
+            "sanity: starting schema is Int32"
+        );
+
+        // Simulate a rewrite pass (e.g. type-coercion) that replaced the inputs
+        // with Int64-typed versions while leaving the Union's cached schema stale.
+        // Same width, different types — this is exactly the bug scenario.
+        let stale = LogicalPlan::Union(Union {
+            inputs: vec![
+                Arc::new(table_scan(Some("t1"), &schema_i64, None)?.build()?),
+                Arc::new(table_scan(Some("t2"), &schema_i64, None)?.build()?),
+            ],
+            schema: Arc::clone(&original.schema),
+        });
+
+        let recomputed = stale.recompute_schema()?;
+
+        assert_eq!(
+            recomputed.schema().field(0).data_type(),
+            &DataType::Int64,
+            "Union schema should track the new Int64 input types after \
+             recompute_schema(), but the width-only check left it stale"
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_recompute_schema_union_name_mismatch() -> Result<()> {
+        use arrow::datatypes::{DataType, Field, Schema};
+
+        let schema_a = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        let schema_b = Schema::new(vec![Field::new("b", DataType::Int32, false)]);
+
+        // Build a Union whose schema starts out with column "a".
+        let original = Union::try_new(vec![
+            Arc::new(table_scan(Some("t1"), &schema_a, None)?.build()?),
+            Arc::new(table_scan(Some("t2"), &schema_a, None)?.build()?),
+        ])?;
+        assert_eq!(
+            original.schema.field(0).name(),
+            "a",
+            "sanity: starting schema has column name 'a'"
+        );
+
+        // Simulate a rewrite pass that renamed the columns but left
+        // the cached schema stale. Same width and type, different name.
+        let stale = LogicalPlan::Union(Union {
+            inputs: vec![
+                Arc::new(table_scan(Some("t1"), &schema_b, None)?.build()?),
+                Arc::new(table_scan(Some("t2"), &schema_b, None)?.build()?),
+            ],
+            schema: Arc::clone(&original.schema),
+        });
+
+        let recomputed = stale.recompute_schema()?;
+
+        assert_eq!(
+            recomputed.schema().field(0).name(),
+            "b",
+            "Union schema should reflect the renamed column after \
+             recompute_schema(), but the width-only check left it stale"
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_recompute_schema_union_nullability_mismatch() -> Result<()> {
+        use arrow::datatypes::{DataType, Field, Schema};
+
+        // nullable: false
+        let schema_not_null = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        // nullable: true
+        let schema_nullable = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
+
+        // Build Union starting with NOT NULL inputs.
+        let original = Union::try_new(vec![
+            Arc::new(table_scan(Some("t1"), &schema_not_null, None)?.build()?),
+            Arc::new(table_scan(Some("t2"), &schema_not_null, None)?.build()?),
+        ])?;
+        assert!(
+            !original.schema.field(0).is_nullable(),
+            "sanity: starting schema field is NOT NULL"
+        );
+
+        // Simulate a rewrite that made the inputs nullable while leaving
+        // the Union's cached schema stale.
+        let stale = LogicalPlan::Union(Union {
+            inputs: vec![
+                Arc::new(table_scan(Some("t1"), &schema_nullable, None)?.build()?),
+                Arc::new(table_scan(Some("t2"), &schema_nullable, None)?.build()?),
+            ],
+            schema: Arc::clone(&original.schema),
+        });
+
+        let recomputed = stale.recompute_schema()?;
+
+        assert!(
+            recomputed.schema().field(0).is_nullable(),
+            "Union schema should reflect the new nullable inputs after \
+             recompute_schema(), but the stale NOT NULL schema was kept"
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_recompute_schema_union_metadata_preservation() -> Result<()> {
+        use arrow::datatypes::{DataType, Field, Schema};
+        use std::collections::HashMap;
+
+        let mut meta1 = HashMap::new();
+        meta1.insert("k1".to_string(), "v1".to_string());
+        let mut meta2 = HashMap::new();
+        meta2.insert("k1".to_string(), "v2".to_string()); // duplicate key, different value
+        meta2.insert("k2".to_string(), "v2".to_string());
+
+        let schema1 = Schema::new_with_metadata(
+            vec![Field::new("a", DataType::Int32, false)],
+            meta1.clone(),
+        );
+        let schema2 = Schema::new_with_metadata(
+            vec![Field::new("a", DataType::Int32, false)],
+            meta2.clone(),
+        );
+
+        // Build a Union. Its initial schema will have intersected metadata.
+        let original = Union::try_new(vec![
+            Arc::new(table_scan(Some("t1"), &schema1, None)?.build()?),
+            Arc::new(table_scan(Some("t2"), &schema2, None)?.build()?),
+        ])?;
+
+        // Union::try_new uses intersection, so k1 should be missing (v1 != v2)
+        // and k2 should be missing (not in meta1).
+        assert!(original.schema.metadata().is_empty());
+
+        // Recompute from a stale schema. The slow path is expected to merge metadata
+        // with EXTEND semantics: values from later inputs override earlier ones.
+        let stale = LogicalPlan::Union(Union {
+            inputs: vec![
+                Arc::new(table_scan(Some("t1"), &schema1, None)?.build()?),
+                Arc::new(table_scan(Some("t2"), &schema2, None)?.build()?),
+            ],
+            // A deliberately mismatched field name forces the slow (recompute) path.
+            schema: Arc::new(DFSchema::try_from(Schema::new(vec![Field::new(
+                "wrong_name",
+                DataType::Int32,
+                false,
+            )]))?),
+        });
+
+        let recomputed = stale.recompute_schema()?;
+
+        // Metadata should now be {k1: v2, k2: v2} because meta2 was the last input.
+        assert_eq!(recomputed.schema().metadata().get("k1").unwrap(), "v2");
+        assert_eq!(recomputed.schema().metadata().get("k2").unwrap(), "v2");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_recompute_schema_union_after_input_rewrite() -> Result<()> {
+        use arrow::datatypes::{DataType, Field, Schema};
+        use datafusion_common::tree_node::{Transformed, TreeNode};
+
+        // Build a Union over two Int32 table scans.
+        let schema_i32 = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        let schema_i64 = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
+
+        let union_plan = LogicalPlan::Union(Union::try_new(vec![
+            Arc::new(table_scan(Some("t1"), &schema_i32, None)?.build()?),
+            Arc::new(table_scan(Some("t2"), &schema_i32, None)?.build()?),
+        ])?);
+
+        // Sanity check: the Union schema starts as Int32.
+        assert_eq!(union_plan.schema().field(0).data_type(), &DataType::Int32);
+
+        // Simulate what an optimizer pass does: rewrite all leaf nodes
+        // (TableScan) to use Int64, then call recompute_schema() on the way up.
+        // This is the pattern used by type_coercion and optimize_projections.
+        let rewritten = union_plan
+            .transform(|plan| match plan {
+                LogicalPlan::TableScan(ref scan)
+                    if scan.source.schema().field(0).data_type() == &DataType::Int32 =>
+                {
+                    let new_scan =
+                        table_scan(Some(scan.table_name.table()), &schema_i64, None)?
+                            .build()?;
+                    Ok(Transformed::yes(new_scan))
+                }
+                other => Ok(Transformed::no(other)),
+            })?
+            .data;
+
+        // After tree transformation, call recompute_schema() on the Union.
+        // Before this fix, the width-only check would leave the schema as Int32.
+        let fixed = rewritten.recompute_schema()?;
+
+        assert_eq!(
+            fixed.schema().field(0).data_type(),
+            &DataType::Int64,
+            "recompute_schema() must update Union schema after input types change"
+        );
+
+        Ok(())
+    }
 }