diff --git a/java/lance-jni/src/mem_wal.rs b/java/lance-jni/src/mem_wal.rs index 9ba3fdd7440..b8d3d0ab49f 100644 --- a/java/lance-jni/src/mem_wal.rs +++ b/java/lance-jni/src/mem_wal.rs @@ -1075,6 +1075,7 @@ fn shard_snapshot_from_manifest(manifest: ShardManifest) -> ShardSnapshot { path: generation.path, }) .collect(), + shard_field_values: Default::default(), } } diff --git a/python/src/mem_wal.rs b/python/src/mem_wal.rs index 25127c95ea4..fdb1dd2fc27 100644 --- a/python/src/mem_wal.rs +++ b/python/src/mem_wal.rs @@ -957,6 +957,7 @@ fn shard_snapshot_from_manifest(manifest: lance_index::mem_wal::ShardManifest) - path: generation.path, }) .collect(), + shard_field_values: Default::default(), } } diff --git a/rust/lance/src/dataset/mem_wal/scanner.rs b/rust/lance/src/dataset/mem_wal/scanner.rs index b1766f8525f..ebcf98c5197 100644 --- a/rust/lance/src/dataset/mem_wal/scanner.rs +++ b/rust/lance/src/dataset/mem_wal/scanner.rs @@ -41,6 +41,7 @@ mod fts_search; mod planner; mod point_lookup; mod projection; +pub(crate) mod shard_pruning; mod vector_search; pub use builder::LsmScanner; diff --git a/rust/lance/src/dataset/mem_wal/scanner/builder.rs b/rust/lance/src/dataset/mem_wal/scanner/builder.rs index ade4164d485..c5fab65f644 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/builder.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/builder.rs @@ -17,6 +17,7 @@ use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; use datafusion::prelude::{Expr, SessionContext}; use futures::TryStreamExt; use lance_core::{Error, Result, is_system_column}; +use lance_index::mem_wal::ShardingSpec; use uuid::Uuid; use super::collector::{InMemoryMemTableRef, InMemoryMemTables, LsmDataSourceCollector}; @@ -125,6 +126,10 @@ pub struct LsmScanner { /// Cache of opened flushed-generation datasets. When set, repeated /// queries against the same generation skip the manifest read entirely. flushed_cache: Option>, + /// Optional sharding spec for read-path shard pruning. + sharding_spec: Option, + /// Mapping from source field id to column name, for sharding evaluation. + source_id_to_column: HashMap, } impl LsmScanner { @@ -160,6 +165,8 @@ impl LsmScanner { pk_columns, session, flushed_cache: None, + sharding_spec: None, + source_id_to_column: HashMap::new(), } } @@ -198,6 +205,8 @@ impl LsmScanner { pk_columns, session: None, flushed_cache: None, + sharding_spec: None, + source_id_to_column: HashMap::new(), } } @@ -253,6 +262,19 @@ impl LsmScanner { self } + /// Set the sharding spec and source-column mapping for read-path shard + /// pruning. When set, the scan planner can skip shards whose field values + /// do not match the query filter. + pub fn with_sharding_spec( + mut self, + spec: ShardingSpec, + source_id_to_column: HashMap, + ) -> Self { + self.sharding_spec = Some(spec); + self.source_id_to_column = source_id_to_column; + self + } + /// Project specific columns. /// /// If not called, all columns from the base schema are included. @@ -490,6 +512,12 @@ impl LsmScanner { collector = collector.with_in_memory_memtables(*shard_id, mems.clone()); } + if let Some(spec) = &self.sharding_spec { + collector = collector + .with_sharding_spec(spec.clone(), self.source_id_to_column.clone()) + .with_base_schema(self.schema.clone()); + } + collector } } diff --git a/rust/lance/src/dataset/mem_wal/scanner/collector.rs b/rust/lance/src/dataset/mem_wal/scanner/collector.rs index 2db4b4f277d..f80b2ae315c 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/collector.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/collector.rs @@ -8,7 +8,9 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use arrow_schema::SchemaRef; +use datafusion::prelude::Expr; use lance_core::Result; +use lance_index::mem_wal::ShardingSpec; use uuid::Uuid; use super::data_source::{LsmDataSource, LsmGeneration, ShardSnapshot}; @@ -66,6 +68,12 @@ pub struct LsmDataSourceCollector { shard_snapshots: Vec, /// In-memory memtables by shard (active + frozen-awaiting-flush). in_memory_memtables: HashMap, + /// Optional sharding spec for read-path shard pruning. + sharding_spec: Option, + /// Mapping from source field id to column name, for sharding evaluation. + source_id_to_column: HashMap, + /// Base schema for type coercion during shard pruning. + base_schema: Option, } impl LsmDataSourceCollector { @@ -84,6 +92,9 @@ impl LsmDataSourceCollector { base_path, shard_snapshots, in_memory_memtables: HashMap::new(), + sharding_spec: None, + source_id_to_column: HashMap::new(), + base_schema: None, } } @@ -101,6 +112,9 @@ impl LsmDataSourceCollector { base_path: base_path.into().trim_end_matches('/').to_string(), shard_snapshots, in_memory_memtables: HashMap::new(), + sharding_spec: None, + source_id_to_column: HashMap::new(), + base_schema: None, } } @@ -132,6 +146,25 @@ impl LsmDataSourceCollector { self } + /// Set the sharding spec and source-column mapping for read-path shard + /// pruning. When set, [`Self::collect_pruned`] can skip shards whose + /// field values do not match the query filter. + pub fn with_sharding_spec( + mut self, + spec: ShardingSpec, + source_id_to_column: HashMap, + ) -> Self { + self.sharding_spec = Some(spec); + self.source_id_to_column = source_id_to_column; + self + } + + /// Set the base schema used for type coercion during shard pruning. + pub fn with_base_schema(mut self, schema: SchemaRef) -> Self { + self.base_schema = Some(schema); + self + } + /// Get the base table, if any. pub fn base_table(&self) -> Option<&Arc> { self.base_table.as_ref() @@ -304,6 +337,33 @@ impl LsmDataSourceCollector { Ok(sources) } + /// Collect data sources, pruning shards when the filter references the + /// sharding column and a [`ShardingSpec`] has been configured via + /// [`Self::with_sharding_spec`]. + /// + /// Falls back to [`Self::collect`] when pruning is not possible (no spec, + /// no filter, or the filter does not match the sharding column). + pub fn collect_pruned(&self, filter: Option<&Expr>) -> Result> { + if let Some(spec) = &self.sharding_spec + && let Some(filter) = filter + && let Some(shard_ids) = super::shard_pruning::prune_shards( + filter, + spec, + &self.shard_snapshots, + &self.source_id_to_column, + self.base_schema.as_ref(), + ) + { + tracing::debug!( + pruned_to = shard_ids.len(), + total = self.shard_snapshots.len(), + "shard pruning applied" + ); + return self.collect_for_shards(&shard_ids); + } + self.collect() + } + /// Get the total number of data sources. pub fn num_sources(&self) -> usize { let flushed_count: usize = self @@ -353,6 +413,7 @@ mod tests { path: "def_gen_2".to_string(), }, ], + shard_field_values: HashMap::new(), }, ShardSnapshot { shard_id: shard_b, @@ -362,6 +423,7 @@ mod tests { generation: 1, path: "xyz_gen_1".to_string(), }], + shard_field_values: HashMap::new(), }, ] } diff --git a/rust/lance/src/dataset/mem_wal/scanner/data_source.rs b/rust/lance/src/dataset/mem_wal/scanner/data_source.rs index 1a6207f27e3..b70276c8287 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/data_source.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/data_source.rs @@ -3,6 +3,7 @@ //! Data source types for LSM scanner. +use std::collections::HashMap; use std::sync::Arc; use arrow_schema::SchemaRef; @@ -93,6 +94,9 @@ pub struct ShardSnapshot { pub current_generation: u64, /// List of flushed generations and their paths. pub flushed_generations: Vec, + /// Computed shard field values, keyed by field id (e.g. bucket id). + /// Used by shard pruning to skip shards that cannot contain matching rows. + pub shard_field_values: HashMap>, } impl ShardSnapshot { @@ -103,6 +107,7 @@ impl ShardSnapshot { spec_id: 0, current_generation: 1, flushed_generations: Vec::new(), + shard_field_values: HashMap::new(), } } @@ -124,6 +129,16 @@ impl ShardSnapshot { .push(FlushedGeneration { generation, path }); self } + + /// Set the shard field values for this snapshot. + /// + /// These are the computed sharding-field values for the shard, keyed by + /// field id (e.g. the bucket number). Used by the read-path shard pruning + /// to skip shards whose field values do not match the query filter. + pub fn with_shard_field_values(mut self, values: HashMap>) -> Self { + self.shard_field_values = values; + self + } } /// A data source in the LSM tree that can be scanned. diff --git a/rust/lance/src/dataset/mem_wal/scanner/planner.rs b/rust/lance/src/dataset/mem_wal/scanner/planner.rs index f3f15e2e680..7ecb6da8177 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/planner.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/planner.rs @@ -106,8 +106,8 @@ impl LsmScanPlanner { .unwrap_or(false); let keep_row_address = keep_row_address || user_wants_rowaddr; - // 1. Collect all data sources - let sources = self.collector.collect()?; + // 1. Collect all data sources (with shard pruning when possible) + let sources = self.collector.collect_pruned(filter)?; if sources.is_empty() { // Return empty plan @@ -1806,4 +1806,302 @@ mod integration_tests { "rejection message should mention the offending column, got: {msg}", ); } + + /// Shard pruning: a query with a filter matching the sharding column reads + /// only the targeted shard, skipping data in other shards. + /// + /// Layout: + /// - Two shards (A and B), bucketed on `id` with 4 buckets. + /// - Shard A's bucket contains id=1, shard B's bucket contains id=2. + /// - Filter `id = 1` prunes shard B. + #[tokio::test] + async fn test_shard_pruning_skips_non_matching_shards() { + use crate::dataset::mem_wal::sharding::hash_scalar_to_bucket; + use datafusion::common::ScalarValue; + use lance_index::mem_wal::{ShardingField, ShardingSpec}; + + let schema = create_pk_schema(); + let temp_dir = tempfile::tempdir().unwrap(); + let base_path = temp_dir.path().to_str().unwrap(); + let base_uri = format!("{}/base", base_path); + + let shard_a = Uuid::new_v4(); + let shard_b = Uuid::new_v4(); + + // Compute bucket ids for id=1 and id=2 with num_buckets=4. + let bucket_a = hash_scalar_to_bucket(&ScalarValue::Int32(Some(1)), 4).unwrap(); + let bucket_b = hash_scalar_to_bucket(&ScalarValue::Int32(Some(2)), 4).unwrap(); + // Ensure distinct buckets to have a meaningful test. + assert_ne!( + bucket_a, bucket_b, + "id=1 and id=2 must hash to different buckets for this test" + ); + + // Shard A: active memtable with id=1, name="shard_a_1" + let store_a = Arc::new(BatchStore::with_capacity(16)); + store_a + .append(create_test_batch(&schema, &[1], "shard_a")) + .unwrap(); + let in_memory_a = InMemoryMemTables { + active: InMemoryMemTableRef { + batch_store: store_a, + index_store: Arc::new(IndexStore::new()), + schema: schema.clone(), + generation: 1, + }, + frozen: vec![], + }; + + // Shard B: active memtable with id=2, name="shard_b_2" + let store_b = Arc::new(BatchStore::with_capacity(16)); + store_b + .append(create_test_batch(&schema, &[2], "shard_b")) + .unwrap(); + let in_memory_b = InMemoryMemTables { + active: InMemoryMemTableRef { + batch_store: store_b, + index_store: Arc::new(IndexStore::new()), + schema: schema.clone(), + generation: 1, + }, + frozen: vec![], + }; + + // Shard snapshots with shard_field_values populated. + let snapshot_a = ShardSnapshot::new(shard_a) + .with_current_generation(1) + .with_shard_field_values(std::collections::HashMap::from([( + "bucket".to_string(), + bucket_a.to_le_bytes().to_vec(), + )])); + let snapshot_b = ShardSnapshot::new(shard_b) + .with_current_generation(1) + .with_shard_field_values(std::collections::HashMap::from([( + "bucket".to_string(), + bucket_b.to_le_bytes().to_vec(), + )])); + + let sharding_spec = ShardingSpec { + spec_id: 1, + fields: vec![ShardingField { + field_id: "bucket".to_string(), + source_ids: vec![], + transform: Some("bucket".to_string()), + expression: None, + result_type: "int32".to_string(), + parameters: std::collections::HashMap::from([ + ("num_buckets".to_string(), "4".to_string()), + ("column".to_string(), "id".to_string()), + ]), + }], + }; + + // Build scanner with sharding spec -- filter `id = 1` + let scanner = LsmScanner::without_base_table( + schema.clone(), + base_uri.clone(), + vec![snapshot_a.clone(), snapshot_b.clone()], + vec!["id".to_string()], + ) + .with_in_memory_memtables(shard_a, in_memory_a.clone()) + .with_in_memory_memtables(shard_b, in_memory_b.clone()) + .with_sharding_spec(sharding_spec.clone(), std::collections::HashMap::new()) + .filter("id = 1") + .unwrap(); + + let batches: Vec = scanner + .try_into_stream() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let mut results: HashMap = HashMap::new(); + for batch in &batches { + let ids = batch + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let names = batch + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + results.insert(ids.value(i), names.value(i).to_string()); + } + } + + // Only id=1 from shard_a should be returned. + assert_eq!(results.len(), 1, "expected 1 result, got {:?}", results); + assert_eq!( + results.get(&1), + Some(&"shard_a_1".to_string()), + "shard_a row must survive" + ); + assert!(!results.contains_key(&2), "shard_b row must be pruned"); + + // Verify that WITHOUT the sharding spec, both rows appear (no pruning). + let scanner_no_pruning = LsmScanner::without_base_table( + schema.clone(), + base_uri.clone(), + vec![snapshot_a, snapshot_b], + vec!["id".to_string()], + ) + .with_in_memory_memtables(shard_a, in_memory_a) + .with_in_memory_memtables(shard_b, in_memory_b) + .filter("id = 1") + .unwrap(); + + let batches_no_pruning: Vec = scanner_no_pruning + .try_into_stream() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + // Without pruning, the filter still returns only id=1, but both + // shards are scanned (one just happens to have no matching rows). + // The end result is the same -- but the key difference is in the + // number of data sources collected. We verify the collector behavior + // separately below. + let mut results_no_pruning: HashMap = HashMap::new(); + for batch in &batches_no_pruning { + let ids = batch + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let names = batch + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + results_no_pruning.insert(ids.value(i), names.value(i).to_string()); + } + } + assert_eq!( + results_no_pruning.len(), + 1, + "filter still returns 1 row without pruning" + ); + + // Direct collector-level verification is done in the separate + // test_collect_pruned_reduces_sources test below. + } + + /// Verify that `collect_pruned` actually reduces the number of + /// in-memory sources when sharding spec is configured and filter matches. + #[tokio::test] + async fn test_collect_pruned_reduces_sources() { + use crate::dataset::mem_wal::scanner::collector::LsmDataSourceCollector; + use crate::dataset::mem_wal::sharding::hash_scalar_to_bucket; + use datafusion::common::ScalarValue; + use lance_index::mem_wal::{ShardingField, ShardingSpec}; + + let schema = create_pk_schema(); + + let shard_a = Uuid::new_v4(); + let shard_b = Uuid::new_v4(); + + let bucket_a = hash_scalar_to_bucket(&ScalarValue::Int32(Some(1)), 4).unwrap(); + let bucket_b = hash_scalar_to_bucket(&ScalarValue::Int32(Some(2)), 4).unwrap(); + assert_ne!(bucket_a, bucket_b); + + let mk_memtable = |ids: &[i32], generation: u64| { + let store = Arc::new(BatchStore::with_capacity(16)); + store.append(create_test_batch(&schema, ids, "v")).unwrap(); + InMemoryMemTables { + active: InMemoryMemTableRef { + batch_store: store, + index_store: Arc::new(IndexStore::new()), + schema: schema.clone(), + generation, + }, + frozen: vec![], + } + }; + + let sharding_spec = ShardingSpec { + spec_id: 1, + fields: vec![ShardingField { + field_id: "bucket".to_string(), + source_ids: vec![], + transform: Some("bucket".to_string()), + expression: None, + result_type: "int32".to_string(), + parameters: std::collections::HashMap::from([ + ("num_buckets".to_string(), "4".to_string()), + ("column".to_string(), "id".to_string()), + ]), + }], + }; + + let collector = LsmDataSourceCollector::without_base_table( + "memory:///test", + vec![ + ShardSnapshot::new(shard_a) + .with_current_generation(1) + .with_shard_field_values(std::collections::HashMap::from([( + "bucket".to_string(), + bucket_a.to_le_bytes().to_vec(), + )])), + ShardSnapshot::new(shard_b) + .with_current_generation(1) + .with_shard_field_values(std::collections::HashMap::from([( + "bucket".to_string(), + bucket_b.to_le_bytes().to_vec(), + )])), + ], + ) + .with_in_memory_memtables(shard_a, mk_memtable(&[1], 1)) + .with_in_memory_memtables(shard_b, mk_memtable(&[2], 1)) + .with_sharding_spec(sharding_spec, std::collections::HashMap::new()) + .with_base_schema(schema.clone()); + + // Without filter: both shards' memtables are collected. + let all = collector.collect().unwrap(); + assert_eq!(all.len(), 2, "both shards' memtables without pruning"); + + // With a filter on the sharding column: only matching shard collected. + let filter_expr = { + use datafusion::common::ToDFSchema; + let ctx = datafusion::prelude::SessionContext::new(); + let df_schema = schema.as_ref().clone().to_dfschema().unwrap(); + ctx.parse_sql_expr("id = 1", &df_schema).unwrap() + }; + let pruned = collector.collect_pruned(Some(&filter_expr)).unwrap(); + assert_eq!( + pruned.len(), + 1, + "only shard_a's memtable after pruning on id=1" + ); + assert_eq!( + pruned[0].shard_id(), + Some(shard_a), + "pruned source must be shard_a" + ); + + // Non-prunable filter (range predicate): falls back to all sources. + let range_filter = { + use datafusion::common::ToDFSchema; + let ctx = datafusion::prelude::SessionContext::new(); + let df_schema = schema.as_ref().clone().to_dfschema().unwrap(); + ctx.parse_sql_expr("id > 0", &df_schema).unwrap() + }; + let not_pruned = collector.collect_pruned(Some(&range_filter)).unwrap(); + assert_eq!( + not_pruned.len(), + 2, + "range filter cannot prune; both shards collected" + ); + } } diff --git a/rust/lance/src/dataset/mem_wal/scanner/shard_pruning.rs b/rust/lance/src/dataset/mem_wal/scanner/shard_pruning.rs new file mode 100644 index 00000000000..173a17e827a --- /dev/null +++ b/rust/lance/src/dataset/mem_wal/scanner/shard_pruning.rs @@ -0,0 +1,539 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Read-path shard pruning for MemWAL queries. +//! +//! Given a query filter and a [`ShardingSpec`], this module determines which +//! shards can be skipped because their field values cannot match the filter. +//! When the filter contains an equality (`col = lit`) or `IN` predicate on +//! the sharding column, each literal value is evaluated through the sharding +//! transform (bucket / identity / unsharded) and the resulting set of target +//! shard IDs is intersected with the available shards. + +use std::collections::{HashMap, HashSet}; + +use arrow_schema::SchemaRef; +use datafusion::common::ScalarValue; +use datafusion::logical_expr::Operator; +use datafusion::prelude::Expr; +use lance_index::mem_wal::ShardingSpec; +use uuid::Uuid; + +use super::data_source::ShardSnapshot; +use crate::dataset::mem_wal::sharding::{hash_scalar_to_bucket, source_column_for_field}; + +/// Attempt to prune shards based on a query filter and the sharding spec. +/// +/// Returns `Some(shard_ids)` when the filter contains an equality or `IN` +/// predicate on the sharding column; only shards whose computed field values +/// match will be in the returned set. Returns `None` when pruning is not +/// possible (e.g. the filter does not reference the sharding column, or the +/// spec is unsharded). +/// +/// `base_schema` is used to coerce filter literals (e.g. SQL `Int64`) to the +/// column's actual Arrow type before hashing, so the bucket id matches the +/// one stored in the shard manifest. +pub fn prune_shards( + filter: &Expr, + spec: &ShardingSpec, + snapshots: &[ShardSnapshot], + source_id_to_column: &HashMap, + base_schema: Option<&SchemaRef>, +) -> Option> { + // We only prune single-field bucket/identity specs today. + let field = spec.fields.first()?; + let transform = field.transform.as_deref()?; + if transform == "unsharded" { + return None; // All rows go to shard 0; nothing to prune. + } + + let column_name = source_column_for_field(field, source_id_to_column).ok()?; + + // Extract literal values from the filter for the sharding column. + let literals = extract_column_literals(filter, &column_name)?; + if literals.is_empty() { + return None; + } + + // Coerce literals to the column's Arrow type so the hash matches what + // the write path stored. SQL parsing often produces Int64 for integer + // literals even when the column is Int32. + let coerced = coerce_literals(&literals, &column_name, base_schema); + + match transform { + "bucket" => { + let num_buckets: i32 = field.parameters.get("num_buckets")?.parse().ok()?; + if num_buckets <= 0 { + return None; + } + + // Compute bucket id for each literal value. + let mut target_bucket_bytes: HashSet> = HashSet::new(); + for lit in &coerced { + if let Some(bucket) = hash_scalar_to_bucket(lit, num_buckets) { + target_bucket_bytes.insert(bucket.to_le_bytes().to_vec()); + } + } + + let field_id = &field.field_id; + let matching: HashSet = snapshots + .iter() + .filter(|s| { + s.shard_field_values + .get(field_id) + .map(|v| target_bucket_bytes.contains(v)) + .unwrap_or(true) // If no field value recorded, don't prune. + }) + .map(|s| s.shard_id) + .collect(); + Some(matching) + } + "identity" => { + // Identity sharding: the shard field value IS the column value. + let mut target_values: HashSet> = HashSet::new(); + for lit in &coerced { + if let Some(bytes) = scalar_to_identity_bytes(lit) { + target_values.insert(bytes); + } + } + + let field_id = &field.field_id; + let matching: HashSet = snapshots + .iter() + .filter(|s| { + s.shard_field_values + .get(field_id) + .map(|v| target_values.contains(v)) + .unwrap_or(true) + }) + .map(|s| s.shard_id) + .collect(); + Some(matching) + } + _ => None, + } +} + +/// Extract literal values from equality (`col = lit`) or `IN (lit, ...)` +/// predicates on `column_name`. Returns `None` if the filter does not +/// reference the column in a prunable shape. +fn extract_column_literals(filter: &Expr, column_name: &str) -> Option> { + match filter { + // col = lit or lit = col + Expr::BinaryExpr(b) if matches!(b.op, Operator::Eq) => { + match (b.left.as_ref(), b.right.as_ref()) { + (Expr::Column(c), Expr::Literal(lit, _)) + | (Expr::Literal(lit, _), Expr::Column(c)) + if c.name == column_name => + { + Some(vec![lit.clone()]) + } + _ => None, + } + } + // col IN (lit, lit, ...) + Expr::InList(in_list) if !in_list.negated => { + let Expr::Column(c) = in_list.expr.as_ref() else { + return None; + }; + if c.name != column_name { + return None; + } + let mut vals = Vec::with_capacity(in_list.list.len()); + for e in &in_list.list { + let Expr::Literal(lit, _) = e else { + return None; + }; + vals.push(lit.clone()); + } + (!vals.is_empty()).then_some(vals) + } + // AND: recurse into both sides and union the results. + Expr::BinaryExpr(b) if matches!(b.op, Operator::And) => { + let left = extract_column_literals(&b.left, column_name); + let right = extract_column_literals(&b.right, column_name); + // For AND, if either side constrains the column, that's our match. + // If both do, use the intersection (fewer target shards). In + // practice, only one side of an AND constrains the same column. + left.or(right) + } + _ => None, + } +} + +/// Convert a [`ScalarValue`] to its Arrow little-endian byte representation +/// for identity sharding comparison. +fn scalar_to_identity_bytes(scalar: &ScalarValue) -> Option> { + match scalar { + ScalarValue::Int8(Some(v)) => Some(v.to_le_bytes().to_vec()), + ScalarValue::Int16(Some(v)) => Some(v.to_le_bytes().to_vec()), + ScalarValue::Int32(Some(v)) => Some(v.to_le_bytes().to_vec()), + ScalarValue::Int64(Some(v)) => Some(v.to_le_bytes().to_vec()), + ScalarValue::UInt8(Some(v)) => Some(v.to_le_bytes().to_vec()), + ScalarValue::UInt16(Some(v)) => Some(v.to_le_bytes().to_vec()), + ScalarValue::UInt32(Some(v)) => Some(v.to_le_bytes().to_vec()), + ScalarValue::UInt64(Some(v)) => Some(v.to_le_bytes().to_vec()), + ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => Some(v.as_bytes().to_vec()), + _ => None, + } +} + +/// Coerce filter literals to match the column's Arrow type. SQL parsing +/// often produces `Int64` for integer literals even when the column is +/// `Int32`. Without coercion the Murmur3 hash would differ (Int32 vs Int64 +/// code paths) and shard pruning would silently fail to match. +fn coerce_literals( + literals: &[ScalarValue], + column_name: &str, + base_schema: Option<&SchemaRef>, +) -> Vec { + let Some(schema) = base_schema else { + return literals.to_vec(); + }; + let Ok(field) = schema.field_with_name(column_name) else { + return literals.to_vec(); + }; + let target_type = field.data_type(); + literals + .iter() + .map(|lit| { + if &lit.data_type() == target_type { + lit.clone() + } else { + lit.cast_to(target_type).unwrap_or_else(|_| lit.clone()) + } + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::prelude::{col, lit}; + use lance_index::mem_wal::ShardingField; + + fn bucket_spec(column: &str, num_buckets: i32) -> ShardingSpec { + ShardingSpec { + spec_id: 1, + fields: vec![ShardingField { + field_id: "bucket".to_string(), + source_ids: vec![], + transform: Some("bucket".to_string()), + expression: None, + result_type: "int32".to_string(), + parameters: HashMap::from([ + ("num_buckets".to_string(), num_buckets.to_string()), + ("column".to_string(), column.to_string()), + ]), + }], + } + } + + fn identity_spec(column: &str) -> ShardingSpec { + ShardingSpec { + spec_id: 1, + fields: vec![ShardingField { + field_id: "ident".to_string(), + source_ids: vec![], + transform: Some("identity".to_string()), + expression: None, + result_type: "utf8".to_string(), + parameters: HashMap::from([("column".to_string(), column.to_string())]), + }], + } + } + + fn snapshot_with_bucket(shard_id: Uuid, field_id: &str, bucket: i32) -> ShardSnapshot { + ShardSnapshot::new(shard_id).with_shard_field_values(HashMap::from([( + field_id.to_string(), + bucket.to_le_bytes().to_vec(), + )])) + } + + fn snapshot_with_identity(shard_id: Uuid, field_id: &str, value: &str) -> ShardSnapshot { + ShardSnapshot::new(shard_id).with_shard_field_values(HashMap::from([( + field_id.to_string(), + value.as_bytes().to_vec(), + )])) + } + + #[test] + fn test_bucket_pruning_equality() { + let spec = bucket_spec("region", 4); + let shard_a = Uuid::new_v4(); + let shard_b = Uuid::new_v4(); + let shard_c = Uuid::new_v4(); + + // Compute the bucket for "us-east" with 4 buckets. + let target_bucket = + hash_scalar_to_bucket(&ScalarValue::Utf8(Some("us-east".to_string())), 4).unwrap(); + + let snapshots = vec![ + snapshot_with_bucket(shard_a, "bucket", target_bucket), + snapshot_with_bucket(shard_b, "bucket", (target_bucket + 1) % 4), + snapshot_with_bucket(shard_c, "bucket", target_bucket), // same bucket as a + ]; + + let filter = col("region").eq(lit("us-east")); + let result = prune_shards(&filter, &spec, &snapshots, &HashMap::new(), None); + + let pruned = result.expect("should prune"); + assert!(pruned.contains(&shard_a)); + assert!(!pruned.contains(&shard_b)); + assert!(pruned.contains(&shard_c)); + } + + #[test] + fn test_bucket_pruning_in_list() { + let spec = bucket_spec("id", 8); + let shard_a = Uuid::new_v4(); + let shard_b = Uuid::new_v4(); + + let bucket_for_1 = hash_scalar_to_bucket(&ScalarValue::Int32(Some(1)), 8).unwrap(); + let bucket_for_2 = hash_scalar_to_bucket(&ScalarValue::Int32(Some(2)), 8).unwrap(); + + // Make shard_a match bucket_for_1, shard_b a different bucket. + let other_bucket = (0..8) + .find(|b| *b != bucket_for_1 && *b != bucket_for_2) + .unwrap(); + let snapshots = vec![ + snapshot_with_bucket(shard_a, "bucket", bucket_for_1), + snapshot_with_bucket(shard_b, "bucket", other_bucket), + ]; + + let filter = col("id").in_list(vec![lit(1i32), lit(2i32)], false); + let result = prune_shards(&filter, &spec, &snapshots, &HashMap::new(), None); + + let pruned = result.expect("should prune"); + assert!(pruned.contains(&shard_a)); + // shard_b has a bucket that matches neither 1 nor 2 + assert!(!pruned.contains(&shard_b)); + } + + #[test] + fn test_identity_pruning() { + let spec = identity_spec("tenant"); + let shard_a = Uuid::new_v4(); + let shard_b = Uuid::new_v4(); + + let snapshots = vec![ + snapshot_with_identity(shard_a, "ident", "acme"), + snapshot_with_identity(shard_b, "ident", "globex"), + ]; + + let filter = col("tenant").eq(lit("acme")); + let result = prune_shards(&filter, &spec, &snapshots, &HashMap::new(), None); + let pruned = result.expect("should prune"); + assert!(pruned.contains(&shard_a)); + assert!(!pruned.contains(&shard_b)); + } + + #[test] + fn test_no_pruning_for_unsharded() { + let spec = ShardingSpec { + spec_id: 1, + fields: vec![ShardingField { + field_id: "u".to_string(), + source_ids: vec![], + transform: Some("unsharded".to_string()), + expression: None, + result_type: "int32".to_string(), + parameters: HashMap::new(), + }], + }; + let filter = col("x").eq(lit(1i32)); + assert!(prune_shards(&filter, &spec, &[], &HashMap::new(), None).is_none()); + } + + #[test] + fn test_no_pruning_for_non_sharding_column() { + let spec = bucket_spec("region", 4); + // Filter is on "name", not "region". + let filter = col("name").eq(lit("foo")); + assert!(prune_shards(&filter, &spec, &[], &HashMap::new(), None).is_none()); + } + + #[test] + fn test_snapshot_without_field_values_not_pruned() { + let spec = bucket_spec("id", 4); + let shard_a = Uuid::new_v4(); + let shard_b = Uuid::new_v4(); + + let bucket_for_1 = hash_scalar_to_bucket(&ScalarValue::Int32(Some(1)), 4).unwrap(); + let snapshots = vec![ + snapshot_with_bucket(shard_a, "bucket", bucket_for_1), + ShardSnapshot::new(shard_b), // no field values -- must NOT be pruned + ]; + + let filter = col("id").eq(lit(1i32)); + let result = prune_shards(&filter, &spec, &snapshots, &HashMap::new(), None); + let pruned = result.expect("should prune"); + assert!(pruned.contains(&shard_a)); + assert!(pruned.contains(&shard_b)); // kept because no field values + } + + #[test] + fn test_and_conjunction_extracts_sharding_column() { + let spec = bucket_spec("id", 8); + let shard_a = Uuid::new_v4(); + let shard_b = Uuid::new_v4(); + + let bucket_for_1 = hash_scalar_to_bucket(&ScalarValue::Int32(Some(1)), 8).unwrap(); + let other_bucket = (0..8).find(|b| *b != bucket_for_1).unwrap(); + + let snapshots = vec![ + snapshot_with_bucket(shard_a, "bucket", bucket_for_1), + snapshot_with_bucket(shard_b, "bucket", other_bucket), + ]; + + // Filter: id = 1 AND name = "foo" + // Only id = 1 should be extracted; name is not the sharding column. + let filter = col("id").eq(lit(1i32)).and(col("name").eq(lit("foo"))); + let result = prune_shards(&filter, &spec, &snapshots, &HashMap::new(), None); + + let pruned = result.expect("should prune using id = 1 from AND"); + assert!(pruned.contains(&shard_a)); + assert!(!pruned.contains(&shard_b)); + } + + #[test] + fn test_or_disjunction_returns_none() { + let spec = bucket_spec("id", 8); + let shard_a = Uuid::new_v4(); + + let bucket_for_1 = hash_scalar_to_bucket(&ScalarValue::Int32(Some(1)), 8).unwrap(); + let snapshots = vec![snapshot_with_bucket(shard_a, "bucket", bucket_for_1)]; + + // Filter: id = 1 OR id = 2 -- OR is not handled, should return None. + let filter = col("id").eq(lit(1i32)).or(col("id").eq(lit(2i32))); + let result = prune_shards(&filter, &spec, &snapshots, &HashMap::new(), None); + + assert!(result.is_none(), "OR filters should not be prunable"); + } + + #[test] + fn test_not_in_returns_none() { + let spec = bucket_spec("id", 8); + let shard_a = Uuid::new_v4(); + + let bucket_for_1 = hash_scalar_to_bucket(&ScalarValue::Int32(Some(1)), 8).unwrap(); + let snapshots = vec![snapshot_with_bucket(shard_a, "bucket", bucket_for_1)]; + + // Filter: id NOT IN (1, 2) -- negated InList should return None. + let filter = col("id").in_list(vec![lit(1i32), lit(2i32)], true); + let result = prune_shards(&filter, &spec, &snapshots, &HashMap::new(), None); + + assert!(result.is_none(), "NOT IN should not be prunable"); + } + + #[test] + fn test_type_coercion_int64_to_int32() { + use std::sync::Arc; + + use arrow_schema::{DataType, Field, Schema as ArrowSchema}; + + let num_buckets = 8; + let spec = bucket_spec("id", num_buckets); + let shard_a = Uuid::new_v4(); + let shard_b = Uuid::new_v4(); + + // Snapshots store bucket values computed from Int32(1) -- matches write path. + let bucket_for_int32 = + hash_scalar_to_bucket(&ScalarValue::Int32(Some(1)), num_buckets).unwrap(); + let bucket_for_int64 = + hash_scalar_to_bucket(&ScalarValue::Int64(Some(1)), num_buckets).unwrap(); + // Precondition: Int32 and Int64 must hash to different buckets for this + // test to be meaningful. Murmur3 uses hash_int vs hash_long code paths. + assert_ne!( + bucket_for_int32, bucket_for_int64, + "Int32 and Int64 should hash to different buckets for value 1" + ); + + let snapshots = vec![ + snapshot_with_bucket(shard_a, "bucket", bucket_for_int32), + snapshot_with_bucket(shard_b, "bucket", bucket_for_int64), + ]; + + // Filter uses Int64 literal (as SQL parsing typically produces). + let filter = col("id").eq(lit(1i64)); + + // WITH base_schema: Int64 coerced to Int32 before hashing -> matches shard_a. + let schema: SchemaRef = Arc::new(ArrowSchema::new(vec![Field::new( + "id", + DataType::Int32, + false, + )])); + let with_schema = prune_shards(&filter, &spec, &snapshots, &HashMap::new(), Some(&schema)); + let pruned = with_schema.expect("coercion should enable pruning"); + assert!( + pruned.contains(&shard_a), + "with coercion, Int64 -> Int32 should match the Int32-hashed shard" + ); + assert!( + !pruned.contains(&shard_b), + "with coercion, should not match the Int64-hashed shard" + ); + + // WITHOUT base_schema: Int64 stays Int64, hashes differently -> misses shard_a. + let without_schema = prune_shards(&filter, &spec, &snapshots, &HashMap::new(), None); + let pruned_raw = without_schema.expect("should still return Some (literals extracted)"); + assert!( + !pruned_raw.contains(&shard_a), + "without coercion, Int64 hash differs from Int32 -- should miss shard_a" + ); + assert!( + pruned_raw.contains(&shard_b), + "without coercion, Int64 hash matches shard_b (stored with Int64 bucket)" + ); + } + + #[test] + fn test_multi_field_spec_uses_first_field() { + let spec = ShardingSpec { + spec_id: 1, + fields: vec![ + ShardingField { + field_id: "bucket".to_string(), + source_ids: vec![], + transform: Some("bucket".to_string()), + expression: None, + result_type: "int32".to_string(), + parameters: HashMap::from([ + ("num_buckets".to_string(), "4".to_string()), + ("column".to_string(), "id".to_string()), + ]), + }, + ShardingField { + field_id: "second_field".to_string(), + source_ids: vec![], + transform: Some("bucket".to_string()), + expression: None, + result_type: "int32".to_string(), + parameters: HashMap::from([ + ("num_buckets".to_string(), "4".to_string()), + ("column".to_string(), "region".to_string()), + ]), + }, + ], + }; + + let shard_a = Uuid::new_v4(); + let bucket_for_1 = hash_scalar_to_bucket(&ScalarValue::Int32(Some(1)), 4).unwrap(); + let snapshots = vec![snapshot_with_bucket(shard_a, "bucket", bucket_for_1)]; + + // Filter on the first field's column -- should work without panic. + let filter = col("id").eq(lit(1i32)); + let result = prune_shards(&filter, &spec, &snapshots, &HashMap::new(), None); + let pruned = result.expect("should prune using first field"); + assert!(pruned.contains(&shard_a)); + } + + #[test] + fn test_empty_snapshots_returns_empty_set() { + let spec = bucket_spec("id", 4); + let filter = col("id").eq(lit(1i32)); + let result = prune_shards(&filter, &spec, &[], &HashMap::new(), None); + let pruned = result.expect("should return Some even with empty snapshots"); + assert!(pruned.is_empty(), "empty snapshots should yield empty set"); + } +} diff --git a/rust/lance/src/dataset/mem_wal/sharding.rs b/rust/lance/src/dataset/mem_wal/sharding.rs index 5982ce99ee9..2a2ba32d9eb 100644 --- a/rust/lance/src/dataset/mem_wal/sharding.rs +++ b/rust/lance/src/dataset/mem_wal/sharding.rs @@ -199,6 +199,46 @@ fn source_column_name( }) } +/// Resolve a sharding field's source column name from the field-id-to-column mapping. +pub fn source_column_for_field( + field: &ShardingField, + source_id_to_column: &HashMap, +) -> Result { + source_column_name(field, source_id_to_column) +} + +/// Compute the bucket id for a single scalar value using the same Murmur3 +/// hash that the bucket sharding transform applies per-row. Returns `None` if +/// the scalar type is unsupported or `num_buckets` is non-positive. +pub fn hash_scalar_to_bucket( + scalar: &datafusion::common::ScalarValue, + num_buckets: i32, +) -> Option { + use datafusion::common::ScalarValue; + if num_buckets <= 0 { + return None; + } + let hash = match scalar { + ScalarValue::Boolean(Some(v)) => hash_int(if *v { 1 } else { 0 }, MURMUR3_SEED), + ScalarValue::Int8(Some(v)) => hash_int(*v as i32, MURMUR3_SEED), + ScalarValue::Int16(Some(v)) => hash_int(*v as i32, MURMUR3_SEED), + ScalarValue::Int32(Some(v)) => hash_int(*v, MURMUR3_SEED), + ScalarValue::Int64(Some(v)) => hash_long(*v, MURMUR3_SEED), + ScalarValue::UInt8(Some(v)) => hash_int(*v as i32, MURMUR3_SEED), + ScalarValue::UInt16(Some(v)) => hash_int(*v as i32, MURMUR3_SEED), + ScalarValue::UInt32(Some(v)) => hash_int(*v as i32, MURMUR3_SEED), + ScalarValue::UInt64(Some(v)) => hash_long(*v as i64, MURMUR3_SEED), + ScalarValue::Float32(Some(v)) => hash_int(canonical_f32_bits(*v) as i32, MURMUR3_SEED), + ScalarValue::Float64(Some(v)) => hash_long(canonical_f64_bits(*v) as i64, MURMUR3_SEED), + ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => { + hash_bytes(v.as_bytes(), MURMUR3_SEED) + } + ScalarValue::Date32(Some(v)) => hash_int(*v, MURMUR3_SEED), + _ => return None, // NULL or unsupported type + }; + Some((hash & i32::MAX) % num_buckets) +} + fn hash_array_value(array: &dyn Array, row_idx: usize, seed: i32) -> Result { if array.is_null(row_idx) { return Ok(seed);