From 707da7e7de2224ad3f7cd8a51d2b2e429585bb98 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 4 Jun 2026 00:11:56 +0800 Subject: [PATCH 01/14] feat(index): support raw-query ivf rq search --- docs/src/format/index/vector/index.md | 8 +- python/python/tests/test_vector_index.py | 33 +- rust/lance-index/src/vector/bq.rs | 17 +- rust/lance-index/src/vector/bq/builder.rs | 110 ++-- rust/lance-index/src/vector/bq/storage.rs | 572 ++++++++++++++++-- rust/lance-index/src/vector/bq/transform.rs | 227 ++++--- .../src/vector/distributed/index_merger.rs | 4 +- rust/lance-index/src/vector/storage.rs | 16 +- rust/lance/src/index/vector.rs | 13 +- .../src/index/vector/ivf/partition_serde.rs | 11 +- rust/lance/src/index/vector/ivf/v2.rs | 227 +++++-- 11 files changed, 976 insertions(+), 262 deletions(-) diff --git a/docs/src/format/index/vector/index.md b/docs/src/format/index/vector/index.md index 3b209934f64..48bd27163a0 100644 --- a/docs/src/format/index/vector/index.md +++ b/docs/src/format/index/vector/index.md @@ -254,6 +254,7 @@ For **RabitQ (RQ)**: | `num_bits` | u8 | Number of bits per dimension, in the range 1..=9 | | `code_dim` | u32 | Rotated vector dimension for the 1-bit binary code | | `packed` | bool | Whether codes are packed for optimized computation | +| `query_estimator` | string | Distance estimator layout: `residual_query` or `raw_query`. Missing values are read as `residual_query` for compatibility with released 1-bit IVF_RQ indexes. | #### Lance File Global Buffer @@ -279,8 +280,9 @@ to rotate vectors before binary quantization: The rotation matrix has shape `[code_dim, code_dim]` where `code_dim` is the rotated vector dimension. IVF_RQ always stores the 1-bit binary sign code in `_rabit_codes`; for `num_bits > 1`, the remaining `num_bits - 1` ex-code bits are stored in `__ex_codes` instead of widening the -binary code path. 
`num_bits=1` indexes only store the binary-code factor columns; multi-bit indexes -also store separate ex-code additive and scale factors. +binary code path. New IVF_RQ indexes store raw-query estimator factors. `num_bits=1` indexes only +store the binary-code factor columns; multi-bit indexes also store separate ex-code additive and +scale factors. ## Appendices @@ -345,7 +347,7 @@ auxiliary schema also includes `__ex_codes`, `__add_factors_ex`, and `__scale_fa - Arrow Schema Metadata: - `"distance_type"` → `"l2"` - `"lance:ivf"` → tracks per-partition `offsets` and `lengths` (no centroids here) - - `"lance:rabit"` → `"{"rotate_mat_position":1,"num_bits":1,"packed":true}"` + - `"lance:rabit"` → `"{"rotate_mat_position":1,"num_bits":1,"packed":true,"query_estimator":"raw_query"}"` - Lance File Global buffer: - `Tensor` rotation matrix with shape `[code_dim, code_dim]` = `[128, 128]` (float32) - Rows with Arrow schema: diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 505d20798bb..2760d55e842 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -1067,13 +1067,7 @@ def test_create_ivf_rq_skip_transpose(): assert stats["indices"][0]["sub_index"]["packed"] is False -@pytest.mark.skip( - reason=( - "IVF_RQ num_bits>1 creation is gated until split-code search support " - "is implemented" - ) -) -def test_create_ivf_rq_multi_bit_gates_search(): +def test_create_ivf_rq_multi_bit_searches_l2_and_gates_cosine(): ds = lance.write_dataset(create_table(), "memory://") ds = ds.create_index( @@ -1084,14 +1078,25 @@ def test_create_ivf_rq_multi_bit_gates_search(): ) stats = ds.stats.index_stats("vector_idx") assert stats["indices"][0]["sub_index"]["num_bits"] == 9 + assert stats["indices"][0]["sub_index"]["query_estimator"] == "raw_query" - with pytest.raises(pa.ArrowInvalid, match="num_bits>1 search is not supported"): - ds.to_table( - nearest={ - "column": "vector", - "q": 
np.random.randn(128).astype(np.float32), - "k": 10, - } + result = ds.to_table( + nearest={ + "column": "vector", + "q": np.random.randn(128).astype(np.float32), + "k": 10, + } + ) + assert result.num_rows == 10 + + cosine_ds = lance.write_dataset(create_table(), "memory://") + with pytest.raises(NotImplementedError, match="num_bits>1 cosine index creation"): + cosine_ds.create_index( + "vector", + index_type="IVF_RQ", + metric="cosine", + num_partitions=4, + num_bits=9, ) diff --git a/rust/lance-index/src/vector/bq.rs b/rust/lance-index/src/vector/bq.rs index 8a347f48817..0fdd918edab 100644 --- a/rust/lance-index/src/vector/bq.rs +++ b/rust/lance-index/src/vector/bq.rs @@ -128,14 +128,7 @@ pub fn validate_rq_num_bits(num_bits: u8) -> Result<()> { } pub fn validate_supported_rq_num_bits(num_bits: u8) -> Result<()> { - validate_rq_num_bits(num_bits)?; - if num_bits != RABIT_BINARY_NUM_BITS { - return Err(Error::not_supported(format!( - "IVF_RQ num_bits={} index creation is not supported until split-code search support is implemented", - num_bits - ))); - } - Ok(()) + validate_rq_num_bits(num_bits) } pub fn rabit_ex_bits(num_bits: u8) -> Result { @@ -261,13 +254,7 @@ mod tests { ); validate_supported_rq_num_bits(1).unwrap(); - let err = validate_supported_rq_num_bits(9).unwrap_err(); - assert!( - err.to_string() - .contains("num_bits=9 index creation is not supported"), - "{}", - err - ); + validate_supported_rq_num_bits(9).unwrap(); } #[test] diff --git a/rust/lance-index/src/vector/bq/builder.rs b/rust/lance-index/src/vector/bq/builder.rs index df6e6591299..8f4f3bc03bd 100644 --- a/rust/lance-index/src/vector/bq/builder.rs +++ b/rust/lance-index/src/vector/bq/builder.rs @@ -18,7 +18,7 @@ use rayon::prelude::*; use crate::vector::bq::storage::{ RABIT_CODE_COLUMN, RABIT_METADATA_KEY, RabitQuantizationMetadata, RabitQuantizationStorage, - rabit_binary_code_field, rabit_ex_code_field, + RabitQueryEstimator, rabit_binary_code_field, rabit_ex_code_field, }; use 
crate::vector::bq::transform::{ ADD_FACTORS_FIELD, EX_ADD_FACTORS_FIELD, EX_SCALE_FACTORS_FIELD, SCALE_FACTORS_FIELD, @@ -236,6 +236,7 @@ impl RabitQuantizer { code_dim: code_dim as u32, num_bits, packed: false, + query_estimator: RabitQueryEstimator::RawQuery, } } RQRotationType::Fast => RabitQuantizationMetadata { @@ -246,6 +247,7 @@ impl RabitQuantizer { code_dim: code_dim as u32, num_bits, packed: false, + query_estimator: RabitQueryEstimator::RawQuery, }, }; Self { metadata } @@ -259,6 +261,10 @@ impl RabitQuantizer { self.metadata.rotation_type } + pub fn metadata_ref(&self) -> &RabitQuantizationMetadata { + &self.metadata + } + #[inline] fn fast_rotation_signs(&self) -> &[u8] { self.metadata @@ -324,7 +330,7 @@ impl RabitQuantizer { } } - pub(crate) fn rotate_fsl_to_f32(&self, vectors: &FixedSizeListArray) -> Result> { + pub fn rotate_fsl_to_f32(&self, vectors: &FixedSizeListArray) -> Result> { match vectors.value_type() { DataType::Float16 => self.rotate_fsl_to_f32_typed::(vectors), DataType::Float32 => self.rotate_fsl_to_f32_typed::(vectors), @@ -522,16 +528,6 @@ impl RabitQuantizer { T::Native: AsPrimitive + Sync, { let ex_bits = rabit_ex_bits(self.metadata.num_bits)?; - if ex_bits == 0 { - return Ok(RabitQuantizedBatch { - binary_codes: self.transform::(residual_vectors)?, - ex_codes: None, - ex_res_dot_dists: None, - rotated_residuals: None, - ex_code_values: None, - }); - } - let n = residual_vectors.len(); let dim = self.dim(); debug_assert_eq!(residual_vectors.values().len(), n * dim); @@ -546,10 +542,10 @@ impl RabitQuantizer { let ex_code_bytes = rabit_ex_code_bytes(code_dim, ex_bits)?; let mut encoded_codes = vec![0u8; n * code_bytes]; - let mut encoded_ex_codes = vec![0u8; n * ex_code_bytes]; - let mut ex_res_dot_dists = vec![0.0f32; n]; + let mut encoded_ex_codes = (ex_bits != 0).then(|| vec![0u8; n * ex_code_bytes]); + let mut ex_res_dot_dists = (ex_bits != 0).then(|| vec![0.0f32; n]); let mut rotated_residuals = vec![0.0f32; n * code_dim]; - 
let mut ex_code_values = vec![0u8; n * code_dim]; + let mut ex_code_values = (ex_bits != 0).then(|| vec![0u8; n * code_dim]); match self.rotation_type() { RQRotationType::Matrix => { @@ -560,67 +556,67 @@ impl RabitQuantizer { encoded_codes .chunks_mut(code_bytes) - .zip(encoded_ex_codes.chunks_mut(ex_code_bytes)) .zip(rotated_residuals.chunks_mut(code_dim)) - .zip(ex_code_values.chunks_mut(code_dim)) - .zip(ex_res_dot_dists.iter_mut()) .enumerate() - .for_each( - |( - row_idx, - ((((code_dst, ex_dst), rotated_dst), ex_values_dst), ex_dot_dst), - )| { - for (dst, value) in rotated_dst - .iter_mut() - .zip(rotated_vectors.column(row_idx).iter()) - { - *dst = *value; - } - pack_sign_bits(code_dst, rotated_dst); - *ex_dot_dst = - quantize_ex_code(rotated_dst, ex_bits, ex_dst, ex_values_dst); - }, - ); + .for_each(|(row_idx, (code_dst, rotated_dst))| { + for (dst, value) in rotated_dst + .iter_mut() + .zip(rotated_vectors.column(row_idx).iter()) + { + *dst = *value; + } + pack_sign_bits(code_dst, rotated_dst); + }); } RQRotationType::Fast => { let signs = self.fast_rotation_signs(); encoded_codes .par_chunks_mut(code_bytes) - .zip(encoded_ex_codes.par_chunks_mut(ex_code_bytes)) .zip(rotated_residuals.par_chunks_mut(code_dim)) - .zip(ex_code_values.par_chunks_mut(code_dim)) - .zip(ex_res_dot_dists.par_iter_mut()) .zip(values.par_chunks_exact(dim)) - .for_each_init( - || (), - |_, - ( - ((((code_dst, ex_dst), rotated_dst), ex_values_dst), ex_dot_dst), - input, - )| { - apply_fast_rotation(input, rotated_dst, signs); - pack_sign_bits(code_dst, rotated_dst); - *ex_dot_dst = - quantize_ex_code(rotated_dst, ex_bits, ex_dst, ex_values_dst); - }, - ); + .for_each(|((code_dst, rotated_dst), input)| { + apply_fast_rotation(input, rotated_dst, signs); + pack_sign_bits(code_dst, rotated_dst); + }); } } + if ex_bits != 0 { + let encoded_ex_codes = encoded_ex_codes + .as_mut() + .expect("ex-code buffer should exist for multi-bit RQ"); + let ex_res_dot_dists = ex_res_dot_dists + 
.as_mut() + .expect("ex dot buffer should exist for multi-bit RQ"); + let ex_code_values = ex_code_values + .as_mut() + .expect("ex-code value buffer should exist for multi-bit RQ"); + encoded_ex_codes + .par_chunks_mut(ex_code_bytes) + .zip(ex_code_values.par_chunks_mut(code_dim)) + .zip(ex_res_dot_dists.par_iter_mut()) + .zip(rotated_residuals.par_chunks(code_dim)) + .for_each(|(((ex_dst, ex_values_dst), ex_dot_dst), rotated)| { + *ex_dot_dst = quantize_ex_code(rotated, ex_bits, ex_dst, ex_values_dst); + }); + } + let binary_codes = UInt8Array::from(encoded_codes); - let ex_codes = UInt8Array::from(encoded_ex_codes); + let ex_codes = encoded_ex_codes.map(UInt8Array::from); Ok(RabitQuantizedBatch { binary_codes: Arc::new(FixedSizeListArray::try_new_from_values( binary_codes, code_bytes as i32, )?), - ex_codes: Some(Arc::new(FixedSizeListArray::try_new_from_values( - ex_codes, - ex_code_bytes as i32, - )?)), - ex_res_dot_dists: Some(ex_res_dot_dists), + ex_codes: ex_codes + .map(|ex_codes| { + FixedSizeListArray::try_new_from_values(ex_codes, ex_code_bytes as i32) + .map(|array| Arc::new(array) as ArrayRef) + }) + .transpose()?, + ex_res_dot_dists, rotated_residuals: Some(rotated_residuals), - ex_code_values: Some(ex_code_values), + ex_code_values, }) } } diff --git a/rust/lance-index/src/vector/bq/storage.rs b/rust/lance-index/src/vector/bq/storage.rs index 4c2aeb7363e..e723b1abfe9 100644 --- a/rust/lance-index/src/vector/bq/storage.rs +++ b/rust/lance-index/src/vector/bq/storage.rs @@ -19,7 +19,7 @@ use itertools::Itertools; use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, RecordBatchExt}; use lance_core::{Error, ROW_ID, Result}; use lance_file::previous::reader::FileReader as PreviousFileReader; -use lance_linalg::distance::{DistanceType, Dot}; +use lance_linalg::distance::{DistanceType, Dot, dot}; use lance_linalg::simd::{ self, dist_table::{BATCH_SIZE, PERM0, PERM0_INVERSE}, @@ -47,7 +47,7 @@ use crate::vector::bq::{ }; use 
crate::vector::pq::storage::transpose; use crate::vector::quantizer::{QuantizerMetadata, QuantizerStorage}; -use crate::vector::storage::{DistCalculator, QueryResidual, VectorStore}; +use crate::vector::storage::{DistCalculator, QueryResidual, RabitRawQueryContext, VectorStore}; pub const RABIT_METADATA_KEY: &str = "lance:rabit"; pub const RABIT_CODE_COLUMN: &str = "_rabit_codes"; @@ -55,6 +55,13 @@ pub const RABIT_EX_CODE_COLUMN: &str = "__ex_codes"; pub const SEGMENT_LENGTH: usize = 4; pub const SEGMENT_NUM_CODES: usize = 1 << SEGMENT_LENGTH; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum RabitQueryEstimator { + ResidualQuery, + RawQuery, +} + pub fn rabit_binary_code_field(rotated_dim: usize) -> Field { Field::new( RABIT_CODE_COLUMN, @@ -98,6 +105,8 @@ pub struct RabitQuantizationMetadata { pub code_dim: u32, pub num_bits: u8, pub packed: bool, + #[serde(default = "default_query_estimator_compat")] + pub query_estimator: RabitQueryEstimator, } impl RabitQuantizationMetadata { @@ -122,6 +131,11 @@ fn default_rotation_type_compat() -> RQRotationType { RQRotationType::Matrix } +fn default_query_estimator_compat() -> RabitQueryEstimator { + // Released IVF_RQ indexes predate this marker and used residual queries. 
+ RabitQueryEstimator::ResidualQuery +} + impl RabitQuantizationMetadata { fn code_dim(&self) -> usize { self.rotated_dim() @@ -198,6 +212,37 @@ impl RabitQuantizationMetadata { } } } + + pub fn prepare_raw_query_context(&self, query: &dyn Array) -> Result { + validate_rq_num_bits(self.num_bits)?; + let code_dim = self.code_dim(); + let ex_bits = rabit_ex_bits(self.num_bits)?; + let dist_table_len = code_dim * 4; + let ex_dist_table_len = if ex_bits == 0 { + 0 + } else { + code_dim * (1usize << ex_bits) + }; + + let mut rotated_query = vec![0.0; code_dim]; + self.rotate_vector_with_residual_into(query, None, &mut rotated_query); + + let mut dist_table = vec![0.0; dist_table_len]; + build_dist_table_direct_into::(&rotated_query, &mut dist_table); + + let mut ex_dist_table = vec![0.0; ex_dist_table_len]; + build_ex_dist_table_direct_into(&rotated_query, ex_bits, &mut ex_dist_table); + + let sum_q = rotated_query.iter().copied().sum(); + Ok(RabitRawQueryContext { + code_dim, + ex_bits, + rotated_query, + dist_table, + ex_dist_table, + sum_q, + }) + } } impl DeepSizeOf for RabitQuantizationMetadata { @@ -285,6 +330,9 @@ pub struct RabitQuantizationStorage { codes: FixedSizeListArray, add_factors: Float32Array, scale_factors: Float32Array, + ex_codes: Option, + ex_add_factors: Option, + ex_scale_factors: Option, } impl DeepSizeOf for RabitQuantizationStorage { @@ -298,7 +346,7 @@ impl RabitQuantizationStorage { self.metadata.code_dim() } - fn query_factor(&self, dist_q_c: f32) -> f32 { + fn residual_query_factor(&self, dist_q_c: f32) -> f32 { match self.distance_type { DistanceType::L2 => dist_q_c, DistanceType::Cosine | DistanceType::Dot => dist_q_c - 1.0, @@ -309,22 +357,55 @@ impl RabitQuantizationStorage { } } + fn raw_query_factor( + &self, + dist_q_c: f32, + rotated_query: &[f32], + rotated_centroid: Option<&[f32]>, + ) -> f32 { + match self.distance_type { + DistanceType::L2 => dist_q_c, + DistanceType::Dot => rotated_centroid + .map(|centroid| 
-dot(rotated_query, centroid)) + .unwrap_or(dist_q_c - 1.0), + DistanceType::Cosine => dist_q_c - 1.0, + _ => unimplemented!( + "RabitQ does not support distance type: {}", + self.distance_type + ), + } + } + fn distance_calculator_from_parts<'a>( &'a self, dim: usize, - dist_q_c: f32, dist_table: Cow<'a, [f32]>, + ex_dist_table: Cow<'a, [f32]>, sum_q: f32, + query_factor: f32, ) -> RabitDistCalculator<'a> { + let ex_codes = self + .ex_codes + .as_ref() + .map(|codes| codes.values().as_primitive::().values().as_ref()); RabitDistCalculator::new( dim, self.metadata.num_bits, + self.metadata.query_estimator, dist_table, + ex_dist_table, sum_q, self.codes.values().as_primitive::().values(), + ex_codes, self.add_factors.values(), self.scale_factors.values(), - self.query_factor(dist_q_c), + self.ex_add_factors + .as_ref() + .map(|factors| factors.values().as_ref()), + self.ex_scale_factors + .as_ref() + .map(|factors| factors.values().as_ref()), + query_factor, ) } @@ -493,14 +574,20 @@ fn copy_subtract_f32(lhs: &[f32], rhs: &[f32], output: &mut [f32]) { pub struct RabitDistCalculator<'a> { dim: usize, + num_bits: u8, + query_estimator: RabitQueryEstimator, // n * d / 8 binary-code bytes codes: &'a [u8], + ex_codes: Option<&'a [u8]>, // this is a flattened 2D array of size d/4 * 16, // we split the query codes into d/4 chunks, each chunk is with 4 elements, // then dist_table[i][j] is the distance between the i-th query code and the code j dist_table: Cow<'a, [f32]>, + ex_dist_table: Cow<'a, [f32]>, add_factors: &'a [f32], scale_factors: &'a [f32], + ex_add_factors: Option<&'a [f32]>, + ex_scale_factors: Option<&'a [f32]>, query_factor: f32, sum_q: f32, @@ -512,19 +599,30 @@ impl<'a> RabitDistCalculator<'a> { pub fn new( dim: usize, num_bits: u8, + query_estimator: RabitQueryEstimator, dist_table: Cow<'a, [f32]>, + ex_dist_table: Cow<'a, [f32]>, sum_q: f32, codes: &'a [u8], + ex_codes: Option<&'a [u8]>, add_factors: &'a [f32], scale_factors: &'a [f32], + ex_add_factors: 
Option<&'a [f32]>, + ex_scale_factors: Option<&'a [f32]>, query_factor: f32, ) -> Self { Self { dim, + num_bits, + query_estimator, codes, + ex_codes, dist_table, + ex_dist_table, add_factors, scale_factors, + ex_add_factors, + ex_scale_factors, query_factor, sqrt_d: (dim as f32 * num_bits as f32).sqrt(), sum_q, @@ -550,6 +648,33 @@ where dist_table } +fn build_ex_dist_table_direct(rotated_query: &[f32], ex_bits: u8) -> Vec { + if ex_bits == 0 { + return Vec::new(); + } + let entries_per_dim = 1usize << ex_bits; + let mut dist_table = vec![0.0; rotated_query.len() * entries_per_dim]; + build_ex_dist_table_direct_into(rotated_query, ex_bits, &mut dist_table); + dist_table +} + +fn build_ex_dist_table_direct_into(rotated_query: &[f32], ex_bits: u8, dist_table: &mut [f32]) { + if ex_bits == 0 { + debug_assert!(dist_table.is_empty()); + return; + } + let entries_per_dim = 1usize << ex_bits; + debug_assert_eq!(dist_table.len(), rotated_query.len() * entries_per_dim); + for (query_value, table) in rotated_query + .iter() + .zip(dist_table.chunks_exact_mut(entries_per_dim)) + { + for (code, value) in table.iter_mut().enumerate() { + *value = *query_value * code as f32; + } + } +} + fn build_dist_table_direct_into(qc: &[T::Native], dist_table: &mut [f32]) where T::Native: AsPrimitive, @@ -616,6 +741,42 @@ fn quantize_dist_table_into(dist_table: &[f32], quantized_dist_table: &mut Vec u8 { + debug_assert!(ex_bits > 0); + let mut value = 0u8; + let bit_offset = dim_idx * ex_bits as usize; + for bit_idx in 0..ex_bits as usize { + let src_bit = bit_offset + bit_idx; + if (row_codes[src_bit / u8::BITS as usize] >> (src_bit % u8::BITS as usize)) & 1 != 0 { + value |= 1u8 << bit_idx; + } + } + value +} + +#[inline] +fn compute_single_rq_ex_distance( + ex_codes: &[u8], + id: usize, + ex_code_len: usize, + ex_bits: u8, + dim: usize, + ex_dist_table: &[f32], +) -> f32 { + if ex_bits == 0 { + return 0.0; + } + let entries_per_dim = 1usize << ex_bits; + let row_codes = &ex_codes[id * 
ex_code_len..(id + 1) * ex_code_len]; + (0..dim) + .map(|dim_idx| { + let code = packed_ex_code_value(row_codes, dim_idx, ex_bits) as usize; + ex_dist_table[dim_idx * entries_per_dim + code] + }) + .sum() +} + impl DistCalculator for RabitDistCalculator<'_> { #[inline(always)] fn distance(&self, id: u32) -> f32 { @@ -625,9 +786,45 @@ impl DistCalculator for RabitDistCalculator<'_> { let dist = compute_single_rq_distance(self.codes, id, num_vectors, code_len, &self.dist_table); - // distance between quantized vector and query vector - let dist_vq_qr = (2.0 * dist - self.sum_q) / self.sqrt_d; - dist_vq_qr * self.scale_factors[id] + self.add_factors[id] + self.query_factor + match self.query_estimator { + RabitQueryEstimator::ResidualQuery => { + // distance between quantized residual vector and residual query vector + let dist_vq_qr = (2.0 * dist - self.sum_q) / self.sqrt_d; + dist_vq_qr * self.scale_factors[id] + self.add_factors[id] + self.query_factor + } + RabitQueryEstimator::RawQuery => { + let ex_bits = self.num_bits - 1; + if ex_bits == 0 { + let binary_dot = dist - 0.5 * self.sum_q; + return binary_dot * self.scale_factors[id] + + self.add_factors[id] + + self.query_factor; + } + + let ex_codes = self + .ex_codes + .expect("raw-query multi-bit RQ requires ex codes"); + let ex_add_factors = self + .ex_add_factors + .expect("raw-query multi-bit RQ requires ex add factors"); + let ex_scale_factors = self + .ex_scale_factors + .expect("raw-query multi-bit RQ requires ex scale factors"); + let ex_code_len = rabit_ex_code_bytes(self.dim, ex_bits) + .expect("RabitQ num_bits should be validated"); + let ex_dist = compute_single_rq_ex_distance( + ex_codes, + id, + ex_code_len, + ex_bits, + self.dim, + &self.ex_dist_table, + ); + let code_bias = -((1u32 << ex_bits) as f32 - 0.5); + let full_dot = (1u32 << ex_bits) as f32 * dist + ex_dist + code_bias * self.sum_q; + full_dot * ex_scale_factors[id] + ex_add_factors[id] + self.query_factor + } + } } #[inline(always)] @@ 
-661,6 +858,17 @@ impl DistCalculator for RabitDistCalculator<'_> { return; } + if self.query_estimator == RabitQueryEstimator::RawQuery && self.num_bits > 1 { + dists.clear(); + dists.reserve(n); + for id in 0..n { + dists.push(self.distance(id as u32)); + } + quantized_dists.clear(); + quantized_dists_table.clear(); + return; + } + let (qmin, qmax) = quantize_dist_table_into(&self.dist_table, quantized_dists_table); let remainder = n % BATCH_SIZE; let simd_len = n - remainder; @@ -695,9 +903,20 @@ impl DistCalculator for RabitDistCalculator<'_> { .enumerate() .for_each(|(id, (dist, q_dist))| { let dist_vq = (*q_dist as f32) * range + sum_min; - let dist_vq_qr = (2.0 * dist_vq - self.sum_q) / self.sqrt_d; - *dist = - dist_vq_qr * self.scale_factors[id] + self.add_factors[id] + self.query_factor; + *dist = match self.query_estimator { + RabitQueryEstimator::ResidualQuery => { + let dist_vq_qr = (2.0 * dist_vq - self.sum_q) / self.sqrt_d; + dist_vq_qr * self.scale_factors[id] + + self.add_factors[id] + + self.query_factor + } + RabitQueryEstimator::RawQuery => { + let binary_dot = dist_vq - 0.5 * self.sum_q; + binary_dot * self.scale_factors[id] + + self.add_factors[id] + + self.query_factor + } + }; }); remainder_dists @@ -750,9 +969,21 @@ impl VectorStore for RabitQuantizationStorage { let code_dim = self.code_dim(); let rotated_qr = self.rotate_query_vector(code_dim, &qr); let dist_table = build_dist_table_direct::(&rotated_qr); + let ex_bits = self.metadata.num_bits - 1; + let ex_dist_table = build_ex_dist_table_direct(&rotated_qr, ex_bits); + let query_factor = match self.metadata.query_estimator { + RabitQueryEstimator::ResidualQuery => self.residual_query_factor(dist_q_c), + RabitQueryEstimator::RawQuery => self.raw_query_factor(dist_q_c, &rotated_qr, None), + }; let sum_q = rotated_qr.into_iter().sum(); - self.distance_calculator_from_parts(code_dim, dist_q_c, Cow::Owned(dist_table), sum_q) + self.distance_calculator_from_parts( + code_dim, + 
Cow::Owned(dist_table), + Cow::Owned(ex_dist_table), + sum_q, + query_factor, + ) } // qr = (q-c) @@ -761,15 +992,44 @@ impl VectorStore for RabitQuantizationStorage { &'a self, qr: Arc, dist_q_c: f32, - residual: Option>, + residual: Option>, f32_scratch: &'a mut Vec, ) -> Self::DistanceCalculator<'a> { let code_dim = self.code_dim(); + if let ( + RabitQueryEstimator::RawQuery, + Some(QueryResidual::RabitRawQuery { + rotated_centroid, + query: Some(raw_query), + }), + ) = (self.metadata.query_estimator, residual) + { + debug_assert_eq!(raw_query.code_dim, code_dim); + debug_assert_eq!(raw_query.ex_bits, self.metadata.num_bits - 1); + let query_factor = + self.raw_query_factor(dist_q_c, &raw_query.rotated_query, rotated_centroid); + return self.distance_calculator_from_parts( + code_dim, + Cow::Borrowed(&raw_query.dist_table), + Cow::Borrowed(&raw_query.ex_dist_table), + raw_query.sum_q, + query_factor, + ); + } + let dist_table_len = code_dim * 4; - f32_scratch.resize(code_dim + dist_table_len, 0.0); + let ex_bits = self.metadata.num_bits - 1; + let ex_dist_table_len = if ex_bits == 0 { + 0 + } else { + code_dim * (1usize << ex_bits) + }; + f32_scratch.resize(code_dim + dist_table_len + ex_dist_table_len, 0.0); + let query_factor; let sum_q = { - let (rotated_qr, dist_table) = f32_scratch.split_at_mut(code_dim); + let (rotated_qr, remaining) = f32_scratch.split_at_mut(code_dim); + let (dist_table, ex_dist_table) = remaining.split_at_mut(dist_table_len); match residual { Some(QueryResidual::Centroid(residual_centroid)) => { self.rotate_query_vector_into( @@ -779,19 +1039,36 @@ impl VectorStore for RabitQuantizationStorage { rotated_qr, ); } - None => { + Some(QueryResidual::RabitRawQuery { .. 
}) | None => { self.rotate_query_vector_into(code_dim, &qr, None, rotated_qr); } } + query_factor = match (self.metadata.query_estimator, residual) { + (RabitQueryEstimator::ResidualQuery, _) => self.residual_query_factor(dist_q_c), + ( + RabitQueryEstimator::RawQuery, + Some(QueryResidual::RabitRawQuery { + rotated_centroid, .. + }), + ) => self.raw_query_factor(dist_q_c, rotated_qr, rotated_centroid), + (RabitQueryEstimator::RawQuery, _) => { + self.raw_query_factor(dist_q_c, rotated_qr, None) + } + }; build_dist_table_direct_into::(rotated_qr, dist_table); + build_ex_dist_table_direct_into(rotated_qr, ex_bits, ex_dist_table); rotated_qr.iter().copied().sum() }; self.distance_calculator_from_parts( code_dim, - dist_q_c, Cow::Borrowed(&f32_scratch[code_dim..code_dim + dist_table_len]), + Cow::Borrowed( + &f32_scratch + [code_dim + dist_table_len..code_dim + dist_table_len + ex_dist_table_len], + ), sum_q, + query_factor, ) } @@ -981,8 +1258,11 @@ impl QuantizerStorage for RabitQuantizationStorage { .as_primitive::() .clone(); let ex_bits = rabit_ex_bits(metadata.num_bits)?; + let mut ex_codes = None; + let mut ex_add_factors = None; + let mut ex_scale_factors = None; if ex_bits != 0 { - let ex_codes = batch + let codes = batch .column_by_name(RABIT_EX_CODE_COLUMN) .ok_or_else(|| { Error::invalid_input(format!( @@ -993,34 +1273,56 @@ impl QuantizerStorage for RabitQuantizationStorage { .as_fixed_size_list() .clone(); let expected_ex_code_bytes = rabit_ex_code_bytes(metadata.rotated_dim(), ex_bits)?; - if ex_codes.value_length() as usize != expected_ex_code_bytes { + if codes.value_length() as usize != expected_ex_code_bytes { return Err(Error::invalid_input(format!( "RabitQ ex-code byte width mismatch: column {} has {} bytes, metadata rotated_dim={} ex_bits={} requires {} bytes", RABIT_EX_CODE_COLUMN, - ex_codes.value_length(), + codes.value_length(), metadata.rotated_dim(), ex_bits, expected_ex_code_bytes ))); } - batch - .column_by_name(EX_ADD_FACTORS_COLUMN) - 
.ok_or_else(|| { - Error::invalid_input(format!( - "RabitQ num_bits={} requires {} column", - metadata.num_bits, EX_ADD_FACTORS_COLUMN - )) - })? - .as_primitive::(); - batch - .column_by_name(EX_SCALE_FACTORS_COLUMN) - .ok_or_else(|| { - Error::invalid_input(format!( - "RabitQ num_bits={} requires {} column", - metadata.num_bits, EX_SCALE_FACTORS_COLUMN - )) - })? - .as_primitive::(); + ex_codes = Some(codes); + ex_add_factors = Some( + batch + .column_by_name(EX_ADD_FACTORS_COLUMN) + .ok_or_else(|| { + Error::invalid_input(format!( + "RabitQ num_bits={} requires {} column", + metadata.num_bits, EX_ADD_FACTORS_COLUMN + )) + })? + .as_primitive::() + .clone(), + ); + ex_scale_factors = Some( + batch + .column_by_name(EX_SCALE_FACTORS_COLUMN) + .ok_or_else(|| { + Error::invalid_input(format!( + "RabitQ num_bits={} requires {} column", + metadata.num_bits, EX_SCALE_FACTORS_COLUMN + )) + })? + .as_primitive::() + .clone(), + ); + } else if metadata.query_estimator == RabitQueryEstimator::RawQuery { + if batch.column_by_name(EX_ADD_FACTORS_COLUMN).is_some() + || batch.column_by_name(EX_SCALE_FACTORS_COLUMN).is_some() + || batch.column_by_name(RABIT_EX_CODE_COLUMN).is_some() + { + return Err(Error::invalid_input( + "RabitQ num_bits=1 raw-query indexes must not contain ex-code columns" + .to_string(), + )); + } + } else if batch.column_by_name(RABIT_EX_CODE_COLUMN).is_some() { + return Err(Error::invalid_input(format!( + "RabitQ num_bits={} does not support {} column", + metadata.num_bits, RABIT_EX_CODE_COLUMN + ))); } let (batch, codes) = if !metadata.packed { @@ -1043,6 +1345,9 @@ impl QuantizerStorage for RabitQuantizationStorage { codes, add_factors, scale_factors, + ex_codes, + ex_add_factors, + ex_scale_factors, }) } @@ -1108,6 +1413,15 @@ impl QuantizerStorage for RabitQuantizationStorage { let scale_factors = batch[SCALE_FACTORS_COLUMN] .as_primitive::() .clone(); + let ex_codes = batch + .column_by_name(RABIT_EX_CODE_COLUMN) + .map(|codes| 
codes.as_fixed_size_list().clone()); + let ex_add_factors = batch + .column_by_name(EX_ADD_FACTORS_COLUMN) + .map(|factors| factors.as_primitive::().clone()); + let ex_scale_factors = batch + .column_by_name(EX_SCALE_FACTORS_COLUMN) + .map(|factors| factors.as_primitive::().clone()); Ok(Self { metadata: self.metadata.clone(), @@ -1116,6 +1430,9 @@ impl QuantizerStorage for RabitQuantizationStorage { codes, add_factors, scale_factors, + ex_codes, + ex_add_factors, + ex_scale_factors, row_ids: new_row_ids, }) } @@ -1298,7 +1615,8 @@ mod tests { fn test_dist_calculator_with_scratch_applies_residual_centroid_without_residual_array() { let code_dim = 64usize; let original_codes = make_test_codes(50, code_dim as i32); - let metadata = make_test_metadata(original_codes.value_length() as usize * 8); + let mut metadata = make_test_metadata(original_codes.value_length() as usize * 8); + metadata.query_estimator = RabitQueryEstimator::ResidualQuery; let storage = RabitQuantizationStorage::try_from_batch( make_test_batch(original_codes), &metadata, @@ -1337,7 +1655,8 @@ mod tests { fn test_dist_calculator_with_scratch_applies_float64_residual_before_f32_cast() { let code_dim = 64usize; let original_codes = make_test_codes(50, code_dim as i32); - let metadata = make_test_metadata(original_codes.value_length() as usize * 8); + let mut metadata = make_test_metadata(original_codes.value_length() as usize * 8); + metadata.query_estimator = RabitQueryEstimator::ResidualQuery; let storage = RabitQuantizationStorage::try_from_batch( make_test_batch(original_codes), &metadata, @@ -1445,6 +1764,21 @@ mod tests { .metadata(None) } + #[test] + fn test_rabit_metadata_defaults_old_indexes_to_residual_query() { + let metadata: RabitQuantizationMetadata = serde_json::from_str( + r#"{"rotate_mat_position":0,"rotation_type":"matrix","code_dim":64,"num_bits":1,"packed":true}"#, + ) + .unwrap(); + assert_eq!(metadata.query_estimator, RabitQueryEstimator::ResidualQuery); + } + + #[test] + fn 
test_new_rabit_metadata_uses_raw_query_estimator() { + let metadata = make_test_metadata(64); + assert_eq!(metadata.query_estimator, RabitQueryEstimator::RawQuery); + } + fn make_test_batch(codes: FixedSizeListArray) -> RecordBatch { let num_rows = codes.len(); RecordBatch::try_from_iter(vec![ @@ -1528,6 +1862,164 @@ mod tests { ); } + #[test] + fn test_raw_query_multi_bit_distance_uses_ex_factors() { + let code_dim = 8usize; + let identity = Float32Array::from_iter_values( + (0..code_dim) + .flat_map(|row| (0..code_dim).map(move |col| if row == col { 1.0 } else { 0.0 })), + ); + let rotate_mat = + FixedSizeListArray::try_new_from_values(identity, code_dim as i32).unwrap(); + let metadata = RabitQuantizationMetadata { + rotate_mat: Some(rotate_mat), + rotate_mat_position: None, + fast_rotation_signs: None, + rotation_type: RQRotationType::Matrix, + code_dim: code_dim as u32, + num_bits: 2, + packed: false, + query_estimator: RabitQueryEstimator::RawQuery, + }; + let codes = + FixedSizeListArray::try_new_from_values(UInt8Array::from(vec![0xff, 0xff]), 1).unwrap(); + let ex_codes = + FixedSizeListArray::try_new_from_values(UInt8Array::from(vec![0x00, 0xff]), 1).unwrap(); + let batch = RecordBatch::try_from_iter(vec![ + (ROW_ID, Arc::new(UInt64Array::from(vec![0, 1])) as ArrayRef), + (RABIT_CODE_COLUMN, Arc::new(codes) as ArrayRef), + ( + ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0, 0.0])) as ArrayRef, + ), + ( + SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0, 0.0])) as ArrayRef, + ), + (RABIT_EX_CODE_COLUMN, Arc::new(ex_codes) as ArrayRef), + ( + EX_ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![100.0, 10.0])) as ArrayRef, + ), + ( + EX_SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![1.0, 1.0])) as ArrayRef, + ), + ]) + .unwrap(); + let storage = + RabitQuantizationStorage::try_from_batch(batch, &metadata, DistanceType::L2, None) + .unwrap(); + let query = Arc::new(Float32Array::from(vec![1.0; code_dim])) as ArrayRef; + 
let calc = storage.dist_calculator(query, 0.0); + + assert_eq!(calc.distance(0), 104.0); + assert_eq!(calc.distance(1), 22.0); + let mut distances = Vec::new(); + let mut u16_scratch = Vec::new(); + let mut u8_scratch = Vec::new(); + calc.distance_all_with_scratch(0, &mut distances, &mut u16_scratch, &mut u8_scratch); + assert_eq!(distances, vec![104.0, 22.0]); + } + + #[test] + fn test_raw_query_one_bit_distance_uses_binary_factors_without_ex_columns() { + let code_dim = 8usize; + let identity = Float32Array::from_iter_values( + (0..code_dim) + .flat_map(|row| (0..code_dim).map(move |col| if row == col { 1.0 } else { 0.0 })), + ); + let rotate_mat = + FixedSizeListArray::try_new_from_values(identity, code_dim as i32).unwrap(); + let metadata = RabitQuantizationMetadata { + rotate_mat: Some(rotate_mat), + rotate_mat_position: None, + fast_rotation_signs: None, + rotation_type: RQRotationType::Matrix, + code_dim: code_dim as u32, + num_bits: 1, + packed: false, + query_estimator: RabitQueryEstimator::RawQuery, + }; + let codes = + FixedSizeListArray::try_new_from_values(UInt8Array::from(vec![0xff, 0x00]), 1).unwrap(); + let storage = RabitQuantizationStorage::try_from_batch( + make_test_batch(codes), + &metadata, + DistanceType::L2, + None, + ) + .unwrap(); + let query = Arc::new(Float32Array::from(vec![1.0; code_dim])) as ArrayRef; + let calc = storage.dist_calculator(query, 3.0); + + assert_eq!(calc.distance_all(0), vec![5.0, -2.0]); + } + + #[test] + fn test_raw_query_context_matches_fallback_and_only_updates_partition_factor() { + let code_dim = 8usize; + let identity = Float32Array::from_iter_values( + (0..code_dim) + .flat_map(|row| (0..code_dim).map(move |col| if row == col { 1.0 } else { 0.0 })), + ); + let rotate_mat = + FixedSizeListArray::try_new_from_values(identity, code_dim as i32).unwrap(); + let metadata = RabitQuantizationMetadata { + rotate_mat: Some(rotate_mat), + rotate_mat_position: None, + fast_rotation_signs: None, + rotation_type: 
RQRotationType::Matrix, + code_dim: code_dim as u32, + num_bits: 2, + packed: false, + query_estimator: RabitQueryEstimator::RawQuery, + }; + let codes = + FixedSizeListArray::try_new_from_values(UInt8Array::from(vec![0xff, 0xff]), 1).unwrap(); + let ex_codes = + FixedSizeListArray::try_new_from_values(UInt8Array::from(vec![0x00, 0xff]), 1).unwrap(); + let storage = RabitQuantizationStorage::try_from_batch( + make_test_batch_with_ex(codes, ex_codes), + &metadata, + DistanceType::Dot, + None, + ) + .unwrap(); + let query = Arc::new(Float32Array::from(vec![1.0; code_dim])) as ArrayRef; + let rotated_centroid = vec![0.25; code_dim]; + let raw_query = metadata.prepare_raw_query_context(query.as_ref()).unwrap(); + + let mut fallback_scratch = Vec::new(); + let expected = storage + .dist_calculator_with_scratch( + query.clone(), + 123.0, + Some(QueryResidual::RabitRawQuery { + rotated_centroid: Some(&rotated_centroid), + query: None, + }), + &mut fallback_scratch, + ) + .distance_all(0); + + let mut prepared_scratch = Vec::new(); + let actual = storage + .dist_calculator_with_scratch( + query, + 456.0, + Some(QueryResidual::RabitRawQuery { + rotated_centroid: Some(&rotated_centroid), + query: Some(&raw_query), + }), + &mut prepared_scratch, + ) + .distance_all(0); + + assert_eq!(actual, expected); + assert!(prepared_scratch.is_empty()); + } + #[test] fn test_try_from_batch_canonicalizes_rq_codes_to_packed_layout() { let original_codes = make_test_codes(50, 64); diff --git a/rust/lance-index/src/vector/bq/transform.rs b/rust/lance-index/src/vector/bq/transform.rs index 391f6ab158f..c2fc0608102 100644 --- a/rust/lance-index/src/vector/bq/transform.rs +++ b/rust/lance-index/src/vector/bq/transform.rs @@ -17,7 +17,7 @@ use tracing::instrument; use crate::vector::bq::builder::RabitQuantizer; use crate::vector::bq::rabit_ex_bits; -use crate::vector::bq::storage::{RABIT_CODE_COLUMN, RABIT_EX_CODE_COLUMN}; +use crate::vector::bq::storage::{RABIT_CODE_COLUMN, 
RABIT_EX_CODE_COLUMN, RabitQueryEstimator}; use crate::vector::quantizer::Quantization; use crate::vector::transform::Transformer; use crate::vector::{CENTROID_DIST_COLUMN, PART_ID_COLUMN}; @@ -28,6 +28,9 @@ pub const ADD_FACTORS_COLUMN: &str = "__add_factors"; pub const SCALE_FACTORS_COLUMN: &str = "__scale_factors"; pub const EX_ADD_FACTORS_COLUMN: &str = "__add_factors_ex"; pub const EX_SCALE_FACTORS_COLUMN: &str = "__scale_factors_ex"; +pub const ERROR_FACTORS_COLUMN: &str = "__error_factors"; + +const RABIT_ERROR_EPSILON: f32 = 1.9; pub static ADD_FACTORS_FIELD: LazyLock = LazyLock::new(|| { arrow_schema::Field::new(ADD_FACTORS_COLUMN, arrow_schema::DataType::Float32, true) @@ -45,6 +48,9 @@ pub static EX_SCALE_FACTORS_FIELD: LazyLock = LazyLock::new true, ) }); +pub static ERROR_FACTORS_FIELD: LazyLock = LazyLock::new(|| { + arrow_schema::Field::new(ERROR_FACTORS_COLUMN, arrow_schema::DataType::Float32, true) +}); pub struct RQTransformer { rq: RabitQuantizer, @@ -64,7 +70,8 @@ impl RQTransformer { // for dot product, the add factor is `1 - v*c + |c|^2`, so we need to compute |c|^2 let centroids_norm_square = (distance_type == DistanceType::Dot) .then(|| Float32Array::from(norm_squared_fsl(&centroids))); - let rotated_centroids = (rq.num_bits() > 1) + let rotated_centroids = (rq.metadata_ref().query_estimator + == RabitQueryEstimator::RawQuery) .then(|| rq.rotate_fsl_to_f32(&centroids)) .transpose()?; @@ -81,8 +88,9 @@ impl RQTransformer { struct RabitRawQueryFactors { add_factors: Float32Array, scale_factors: Float32Array, - ex_add_factors: Float32Array, - ex_scale_factors: Float32Array, + error_factors: Float32Array, + ex_add_factors: Option, + ex_scale_factors: Option, } #[inline] @@ -103,6 +111,28 @@ fn binary_factor_value(rotated_residual: f32) -> f32 { } } +#[inline] +fn error_factor_value( + distance_type: DistanceType, + norm_square: f32, + binary_res_dot: f32, + code_dim: usize, +) -> f32 { + if code_dim <= 1 || norm_square <= 0.0 || binary_res_dot == 0.0 {
+ return 0.0; + } + + let code_norm_square = code_dim as f32 * 0.25; + let alignment = norm_square * code_norm_square / binary_res_dot.powi(2); + let angular_error = ((alignment - 1.0).max(0.0) / (code_dim as f32 - 1.0)).sqrt(); + let error = norm_square.sqrt() * RABIT_ERROR_EPSILON * angular_error; + match distance_type { + DistanceType::L2 => 2.0 * error, + DistanceType::Dot => error, + _ => unreachable!(), + } +} + #[allow(clippy::too_many_arguments)] fn compute_raw_query_factors( distance_type: DistanceType, @@ -110,8 +140,8 @@ fn compute_raw_query_factors( rotated_residuals: &[f32], rotated_centroids: &[f32], part_ids: &UInt32Array, - ex_code_values: &[u8], - ex_res_dot_dists: &[f32], + ex_code_values: Option<&[u8]>, + ex_res_dot_dists: Option<&[f32]>, ex_bits: u8, code_dim: usize, ) -> Result { @@ -124,21 +154,22 @@ fn compute_raw_query_factors( let num_rows = res_norm_square.len(); debug_assert_eq!(rotated_residuals.len(), num_rows * code_dim); - debug_assert_eq!(ex_code_values.len(), num_rows * code_dim); - debug_assert_eq!(ex_res_dot_dists.len(), num_rows); + if let Some(ex_code_values) = ex_code_values { + debug_assert_eq!(ex_code_values.len(), num_rows * code_dim); + } + if let Some(ex_res_dot_dists) = ex_res_dot_dists { + debug_assert_eq!(ex_res_dot_dists.len(), num_rows); + } + let has_ex_codes = ex_bits != 0; let ex_code_bias = -((1u32 << ex_bits) as f32 - 0.5); let mut add_factors = Vec::with_capacity(num_rows); let mut scale_factors = Vec::with_capacity(num_rows); - let mut ex_add_factors = Vec::with_capacity(num_rows); - let mut ex_scale_factors = Vec::with_capacity(num_rows); - - for (row_idx, (&norm_square, &ex_res_dot)) in res_norm_square - .values() - .iter() - .zip(ex_res_dot_dists.iter()) - .enumerate() - { + let mut error_factors = Vec::with_capacity(num_rows); + let mut ex_add_factors = has_ex_codes.then(|| Vec::with_capacity(num_rows)); + let mut ex_scale_factors = has_ex_codes.then(|| Vec::with_capacity(num_rows)); + + for (row_idx, 
&norm_square) in res_norm_square.values().iter().enumerate() { let part_id = part_ids.value(row_idx) as usize; let centroid_start = part_id.checked_mul(code_dim).ok_or_else(|| { Error::invalid_input(format!( @@ -164,14 +195,14 @@ fn compute_raw_query_factors( let row_end = row_start + code_dim; let residual = &rotated_residuals[row_start..row_end]; let centroid = &rotated_centroids[centroid_start..centroid_end]; - let ex_values = &ex_code_values[row_start..row_end]; + let ex_values = ex_code_values.map(|values| &values[row_start..row_end]); let mut binary_res_dot = 0.0f32; let mut binary_cent_dot = 0.0f32; let mut ex_cent_dot = 0.0f32; let mut residual_centroid_dot = 0.0f32; - for ((&residual_value, &centroid_value), &ex_code_value) in - residual.iter().zip(centroid.iter()).zip(ex_values.iter()) + for (dim_idx, (&residual_value, &centroid_value)) in + residual.iter().zip(centroid.iter()).enumerate() { let residual_value: f32 = residual_value; let centroid_value: f32 = centroid_value; @@ -181,30 +212,51 @@ fn compute_raw_query_factors( 0u32 }; let binary_factor = binary_factor_value(residual_value); - let ex_factor = ((binary_code << ex_bits) + ex_code_value as u32) as f32 + ex_code_bias; binary_res_dot += residual_value * binary_factor; binary_cent_dot += centroid_value * binary_factor; - ex_cent_dot += centroid_value * ex_factor; + if let Some(ex_values) = ex_values { + let ex_code_value = ex_values[dim_idx]; + let ex_factor = + ((binary_code << ex_bits) + ex_code_value as u32) as f32 + ex_code_bias; + ex_cent_dot += centroid_value * ex_factor; + } residual_centroid_dot += residual_value * centroid_value; } let binary_correction = factor_ratio(norm_square * binary_cent_dot, binary_res_dot); + let ex_res_dot = ex_res_dot_dists + .map(|values| values[row_idx]) + .unwrap_or_default(); let ex_correction = factor_ratio(norm_square * ex_cent_dot, ex_res_dot); + error_factors.push(error_factor_value( + distance_type, + norm_square, + binary_res_dot, + code_dim, + )); match
distance_type { DistanceType::L2 => { add_factors.push(norm_square + 2.0 * binary_correction); scale_factors.push(factor_ratio(-2.0 * norm_square, binary_res_dot)); - ex_add_factors.push(norm_square + 2.0 * ex_correction); - ex_scale_factors.push(factor_ratio(-2.0 * norm_square, ex_res_dot)); + if let Some(ex_add_factors) = ex_add_factors.as_mut() { + ex_add_factors.push(norm_square + 2.0 * ex_correction); + } + if let Some(ex_scale_factors) = ex_scale_factors.as_mut() { + ex_scale_factors.push(factor_ratio(-2.0 * norm_square, ex_res_dot)); + } } DistanceType::Dot => { let dot_base = 1.0 - residual_centroid_dot; add_factors.push(dot_base + binary_correction); scale_factors.push(factor_ratio(-norm_square, binary_res_dot)); - ex_add_factors.push(dot_base + ex_correction); - ex_scale_factors.push(factor_ratio(-norm_square, ex_res_dot)); + if let Some(ex_add_factors) = ex_add_factors.as_mut() { + ex_add_factors.push(dot_base + ex_correction); + } + if let Some(ex_scale_factors) = ex_scale_factors.as_mut() { + ex_scale_factors.push(factor_ratio(-norm_square, ex_res_dot)); + } } _ => unreachable!(), } @@ -213,8 +265,9 @@ fn compute_raw_query_factors( Ok(RabitRawQueryFactors { add_factors: Float32Array::from(add_factors), scale_factors: Float32Array::from(scale_factors), - ex_add_factors: Float32Array::from(ex_add_factors), - ex_scale_factors: Float32Array::from(ex_scale_factors), + error_factors: Float32Array::from(error_factors), + ex_add_factors: ex_add_factors.map(Float32Array::from), + ex_scale_factors: ex_scale_factors.map(Float32Array::from), }) } @@ -274,8 +327,8 @@ impl Transformer for RQTransformer { debug_assert_eq!(codes_fsl.len(), batch.num_rows()); let mut batch = batch.try_with_column(self.rq.field(), rq_codes.binary_codes)?; - if self.rq.num_bits() == 1 { - // Preserve the released 1-bit residual-query estimator and factor layout. 
+ if self.rq.metadata_ref().query_estimator == RabitQueryEstimator::ResidualQuery { + // Preserve the released residual-query estimator and factor layout. let ip_rq_res = match residual_vectors.value_type() { DataType::Float16 => Float32Array::from( self.rq @@ -356,40 +409,34 @@ impl Transformer for RQTransformer { .try_with_column(ADD_FACTORS_FIELD.clone(), Arc::new(add_factors))? .try_with_column(SCALE_FACTORS_FIELD.clone(), Arc::new(scale_factors))?; } else { - // Multi-bit RQ is stored for the RaBitQ-Library raw-query estimator. - // Search remains gated until that query path lands. - let ex_codes = rq_codes.ex_codes.ok_or_else(|| { - Error::internal("RabitQ multi-bit quantization did not return ex codes".to_string()) - })?; - let ex_res_dot_dists = rq_codes.ex_res_dot_dists.ok_or_else(|| { - Error::internal( - "RabitQ multi-bit quantization did not return ex dot factors".to_string(), - ) - })?; + // New RQ indexes use the RaBitQ-Library raw-query estimator. + let ex_bits = rabit_ex_bits(self.rq.num_bits())?; + let ex_codes = rq_codes.ex_codes; + let ex_res_dot_dists = rq_codes.ex_res_dot_dists; let rotated_residuals = rq_codes.rotated_residuals.ok_or_else(|| { - Error::internal( - "RabitQ multi-bit quantization did not return rotated residuals".to_string(), - ) - })?; - let ex_code_values = rq_codes.ex_code_values.ok_or_else(|| { - Error::internal( - "RabitQ multi-bit quantization did not return ex code values".to_string(), - ) + Error::internal("RabitQ quantization did not return rotated residuals".to_string()) })?; + let ex_code_values = rq_codes.ex_code_values; + if ex_bits != 0 + && (ex_codes.is_none() || ex_res_dot_dists.is_none() || ex_code_values.is_none()) + { + return Err(Error::internal( + "RabitQ multi-bit quantization did not return split-code values".to_string(), + )); + } let part_ids = batch[PART_ID_COLUMN].as_primitive::(); let rotated_centroids = self.rotated_centroids.as_ref().ok_or_else(|| { - Error::internal("RabitQ multi-bit transformer is 
missing rotated centroids") + Error::internal("RabitQ raw-query transformer is missing rotated centroids") })?; - let ex_bits = rabit_ex_bits(self.rq.num_bits())?; let raw_query_factors = compute_raw_query_factors( self.distance_type, &res_norm_square, &rotated_residuals, rotated_centroids, part_ids, - &ex_code_values, - &ex_res_dot_dists, + ex_code_values.as_deref(), + ex_res_dot_dists.as_deref(), ex_bits, self.rq.dim(), )?; @@ -404,21 +451,28 @@ impl Transformer for RQTransformer { Arc::new(raw_query_factors.scale_factors), )? .try_with_column( + ERROR_FACTORS_FIELD.clone(), + Arc::new(raw_query_factors.error_factors), + )?; + + if let Some(ex_codes) = ex_codes { + batch = batch.try_with_column( crate::vector::bq::storage::rabit_ex_code_field( self.rq.dim(), self.rq.num_bits(), )? .expect("ex-code field should exist for num_bits > 1"), ex_codes, - )? - .try_with_column( - EX_ADD_FACTORS_FIELD.clone(), - Arc::new(raw_query_factors.ex_add_factors), - )? - .try_with_column( - EX_SCALE_FACTORS_FIELD.clone(), - Arc::new(raw_query_factors.ex_scale_factors), )?; + } + if let Some(ex_add_factors) = raw_query_factors.ex_add_factors { + batch = batch + .try_with_column(EX_ADD_FACTORS_FIELD.clone(), Arc::new(ex_add_factors))?; + } + if let Some(ex_scale_factors) = raw_query_factors.ex_scale_factors { + batch = batch + .try_with_column(EX_SCALE_FACTORS_FIELD.clone(), Arc::new(ex_scale_factors))?; + } } let batch = batch @@ -445,8 +499,8 @@ mod tests { use crate::vector::{CENTROID_DIST_COLUMN, PART_ID_COLUMN}; use super::{ - ADD_FACTORS_COLUMN, EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN, RQTransformer, - compute_raw_query_factors, + ADD_FACTORS_COLUMN, ERROR_FACTORS_COLUMN, EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN, + RQTransformer, compute_raw_query_factors, error_factor_value, }; #[test] @@ -457,7 +511,6 @@ mod tests { .unwrap(); let transformer = RQTransformer::new(rq.clone(), DistanceType::L2, centroids, "vector").unwrap(); - 
assert!(transformer.rotated_centroids.is_some()); let residual_vectors = FixedSizeListArray::try_new_from_values( Float32Array::from(vec![ @@ -522,18 +575,27 @@ mod tests { assert!(transformed.column_by_name("vector").is_none()); assert!(transformed.column_by_name(CENTROID_DIST_COLUMN).is_none()); assert!(transformed.column_by_name(ADD_FACTORS_COLUMN).is_some()); + assert!(transformed.column_by_name(ERROR_FACTORS_COLUMN).is_some()); } #[test] - fn test_rq_transformer_caches_rotated_centroids_only_for_multi_bit() { + fn test_rq_transformer_caches_rotated_centroids_for_raw_query() { let centroids = FixedSizeListArray::try_new_from_values(Float32Array::from(vec![0.0f32; 8]), 8) .unwrap(); - let binary_rq = + let raw_query_rq = RabitQuantizer::new_with_rotation::(1, 8, RQRotationType::Fast); - let binary_transformer = - RQTransformer::new(binary_rq, DistanceType::L2, centroids.clone(), "vector").unwrap(); - assert!(binary_transformer.rotated_centroids.is_none()); + let raw_query_transformer = + RQTransformer::new(raw_query_rq, DistanceType::L2, centroids.clone(), "vector") + .unwrap(); + assert_eq!( + raw_query_transformer + .rotated_centroids + .as_ref() + .unwrap() + .len(), + 8 + ); let multi_bit_rq = RabitQuantizer::new_with_rotation::(4, 8, RQRotationType::Fast); @@ -564,8 +626,8 @@ mod tests { &rotated_residuals, &rotated_centroids, &part_ids, - &ex_code_values, - &ex_res_dot_dists, + Some(&ex_code_values), + Some(&ex_res_dot_dists), 1, 2, ) @@ -573,12 +635,17 @@ mod tests { assert!((factors.add_factors.value(0) - 1.6666667).abs() < 1e-5); assert!((factors.scale_factors.value(0) + 6.6666665).abs() < 1e-5); - assert!((factors.ex_add_factors.value(0) - 1.6666667).abs() < 1e-5); - assert!((factors.ex_scale_factors.value(0) + 2.2222223).abs() < 1e-5); + let expected_error = error_factor_value(DistanceType::L2, 5.0, 1.5, 2); + assert!((factors.error_factors.value(0) - expected_error).abs() < 1e-5); + let ex_add_factors = factors.ex_add_factors.unwrap(); + let 
ex_scale_factors = factors.ex_scale_factors.unwrap(); + assert!((ex_add_factors.value(0) - 1.6666667).abs() < 1e-5); + assert!((ex_scale_factors.value(0) + 2.2222223).abs() < 1e-5); assert_eq!(factors.add_factors.value(1), 7.0); assert_eq!(factors.scale_factors.value(1), 0.0); - assert_eq!(factors.ex_add_factors.value(1), 7.0); - assert_eq!(factors.ex_scale_factors.value(1), 0.0); + assert_eq!(factors.error_factors.value(1), 0.0); + assert_eq!(ex_add_factors.value(1), 7.0); + assert_eq!(ex_scale_factors.value(1), 0.0); } #[test] @@ -596,8 +663,8 @@ mod tests { &rotated_residuals, &rotated_centroids, &part_ids, - &ex_code_values, - &ex_res_dot_dists, + Some(&ex_code_values), + Some(&ex_res_dot_dists), 1, 2, ) @@ -605,7 +672,11 @@ mod tests { assert!((factors.add_factors.value(0) + 2.6666667).abs() < 1e-5); assert!((factors.scale_factors.value(0) + 3.3333333).abs() < 1e-5); - assert!((factors.ex_add_factors.value(0) + 2.6666667).abs() < 1e-5); - assert!((factors.ex_scale_factors.value(0) + 1.1111112).abs() < 1e-5); + let expected_error = error_factor_value(DistanceType::Dot, 5.0, 1.5, 2); + assert!((factors.error_factors.value(0) - expected_error).abs() < 1e-5); + let ex_add_factors = factors.ex_add_factors.unwrap(); + let ex_scale_factors = factors.ex_scale_factors.unwrap(); + assert!((ex_add_factors.value(0) + 2.6666667).abs() < 1e-5); + assert!((ex_scale_factors.value(0) + 1.1111112).abs() < 1e-5); } } diff --git a/rust/lance-index/src/vector/distributed/index_merger.rs b/rust/lance-index/src/vector/distributed/index_merger.rs index e93984bbcca..e003cf52599 100755 --- a/rust/lance-index/src/vector/distributed/index_merger.rs +++ b/rust/lance-index/src/vector/distributed/index_merger.rs @@ -1531,7 +1531,7 @@ mod tests { use prost::Message; use crate::vector::bq::RQRotationType; - use crate::vector::bq::storage::RABIT_EX_CODE_COLUMN; + use crate::vector::bq::storage::{RABIT_EX_CODE_COLUMN, RabitQueryEstimator}; use 
crate::vector::bq::transform::{EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN}; lance_testing::define_stage_event_progress!( RecordingProgress, @@ -2317,6 +2317,7 @@ mod tests { code_dim: 16, num_bits: 1, packed: false, + query_estimator: RabitQueryEstimator::RawQuery, }; write_rq_partial_aux( @@ -2452,6 +2453,7 @@ mod tests { code_dim: 16, num_bits: 4, packed: false, + query_estimator: RabitQueryEstimator::RawQuery, }; write_rq_partial_aux( diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index 1443a1f355d..36974180d41 100644 --- a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -136,9 +136,23 @@ impl QueryScratchCapacity { } } +#[derive(Debug)] +pub struct RabitRawQueryContext { + pub code_dim: usize, + pub ex_bits: u8, + pub rotated_query: Vec, + pub dist_table: Vec, + pub ex_dist_table: Vec, + pub sum_q: f32, +} + #[derive(Clone, Copy)] pub enum QueryResidual<'a> { Centroid(&'a dyn arrow_array::Array), + RabitRawQuery { + rotated_centroid: Option<&'a [f32]>, + query: Option<&'a RabitRawQueryContext>, + }, } #[derive(Debug)] @@ -295,7 +309,7 @@ pub trait VectorStore: Send + Sync + Sized + Clone { &'a self, query: ArrayRef, dist_q_c: f32, - _residual: Option>, + _residual: Option>, _f32_scratch: &'a mut Vec, ) -> Self::DistanceCalculator<'a> { self.dist_calculator(query, dist_q_c) diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 3a9afeca886..c7d405ec7b1 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -29,7 +29,9 @@ use lance_index::metrics::NoOpMetricsCollector; use lance_index::optimize::OptimizeOptions; use lance_index::progress::{IndexBuildProgress, noop_progress}; use lance_index::vector::bq::builder::RabitQuantizer; -use lance_index::vector::bq::{RQBuildParams, RQRotationType, validate_supported_rq_num_bits}; +use lance_index::vector::bq::{ + RABIT_BINARY_NUM_BITS, RQBuildParams, RQRotationType, 
validate_supported_rq_num_bits, +}; use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; use lance_index::vector::hnsw::HNSW; use lance_index::vector::ivf::builder::recommended_num_partitions; @@ -552,10 +554,13 @@ async fn prepare_vector_segment_build( stages ))); }; - // Multi-bit RQ quantization/storage internals are kept available for - // split-code preparation, but public index creation stays binary-only - // until multi-bit search support lands. validate_supported_rq_num_bits(rq_params.num_bits)?; + if rq_params.num_bits > RABIT_BINARY_NUM_BITS && params.metric_type == DistanceType::Cosine + { + return Err(Error::not_supported( + "IVF_RQ num_bits>1 cosine index creation is not supported until raw-query cosine search support is implemented", + )); + } } let num_rows = dataset.count_rows(None).await?; diff --git a/rust/lance/src/index/vector/ivf/partition_serde.rs b/rust/lance/src/index/vector/ivf/partition_serde.rs index f8d13a2f0b5..7ae77abaa06 100644 --- a/rust/lance/src/index/vector/ivf/partition_serde.rs +++ b/rust/lance/src/index/vector/ivf/partition_serde.rs @@ -32,7 +32,7 @@ use lance_core::cache::CacheCodecImpl; use lance_core::{Error, Result}; use lance_index::vector::bq::RQRotationType; use lance_index::vector::bq::builder::RabitQuantizer; -use lance_index::vector::bq::storage::RabitQuantizationMetadata; +use lance_index::vector::bq::storage::{RabitQuantizationMetadata, RabitQueryEstimator}; use lance_index::vector::flat::index::{FlatBinQuantizer, FlatMetadata, FlatQuantizer}; use lance_index::vector::pq::ProductQuantizer; use lance_index::vector::pq::storage::ProductQuantizationMetadata; @@ -437,12 +437,18 @@ struct RabitPartitionHeader { distance_type: u8, num_bits: u8, code_dim: u32, + #[serde(default = "default_rabit_query_estimator")] + query_estimator: RabitQueryEstimator, /// 0 = Matrix, 1 = Fast rotation_type: u8, /// Fast rotation signs (only set when rotation_type == Fast). 
fast_rotation_signs: Option>, } +fn default_rabit_query_estimator() -> RabitQueryEstimator { + RabitQueryEstimator::ResidualQuery +} + impl CacheCodecImpl for PartitionEntry { fn serialize(&self, writer: &mut dyn Write) -> Result<()> { let metadata = self.storage.metadata(); @@ -452,6 +458,7 @@ impl CacheCodecImpl for PartitionEntry { distance_type: distance_type_to_u8(distance_type), num_bits: metadata.num_bits, code_dim: metadata.code_dim, + query_estimator: metadata.query_estimator, rotation_type: rotation_type_to_u8(metadata.rotation_type), fast_rotation_signs: metadata.fast_rotation_signs.clone(), }; @@ -506,6 +513,7 @@ impl CacheCodecImpl for PartitionEntry { num_bits: header.num_bits, // The storage batch already has packed codes; skip re-packing. packed: true, + query_estimator: header.query_estimator, }; let storage = ::Storage::try_from_batch( storage_batch, @@ -1047,6 +1055,7 @@ mod tests { assert_eq!(rm.num_bits, m.num_bits); assert_eq!(rm.code_dim, m.code_dim); assert_eq!(rm.rotation_type, m.rotation_type); + assert_eq!(rm.query_estimator, m.query_estimator); assert_eq!(rm.fast_rotation_signs, m.fast_rotation_signs); assert!(rm.packed); assert_eq!( diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index b47b00d409c..7901390f2b2 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -38,6 +38,7 @@ use lance_index::frag_reuse::FragReuseIndex; use lance_index::metrics::{LocalMetricsCollector, MetricsCollector, NoOpMetricsCollector}; use lance_index::vector::VectorIndexCacheEntry; use lance_index::vector::bq::builder::RabitQuantizer; +use lance_index::vector::bq::storage::RabitQueryEstimator; use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; use lance_index::vector::graph::OrderedNode; use lance_index::vector::hnsw::HNSW; @@ -48,7 +49,8 @@ use lance_index::vector::quantizer::{ }; use lance_index::vector::sq::ScalarQuantizer; use 
lance_index::vector::storage::{ - QueryResidual, QueryScratch, QueryScratchCapacity, QueryScratchPool, VectorStore, + QueryResidual, QueryScratch, QueryScratchCapacity, QueryScratchPool, RabitRawQueryContext, + VectorStore, }; use lance_index::vector::v3::subindex::SubIndexType; use lance_index::{ @@ -114,10 +116,18 @@ struct PreparedPartitionSearch { pre_filter: Arc, partition_id: usize, partition_centroid: Option, + rotated_partition_centroid: Option>, + raw_query_context: Option>, part_entry: Arc, _marker: PhantomData<(S, Q)>, } +#[derive(Debug)] +struct RabitSearchCache { + rotated_centroids: Vec, + code_dim: usize, +} + impl DeepSizeOf for IvfIndexState { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { self.index_file_path.deep_size_of_children(context) @@ -531,7 +541,9 @@ pub struct IVFIndex { io_parallelism: usize, scratch_pool: Arc, + use_query_residual: bool, use_residual_scratch: bool, + rq_search_cache: Option>, _marker: PhantomData<(S, Q)>, } @@ -545,6 +557,11 @@ impl DeepSizeOf for IVFIndex { + self.uuid.deep_size_of_children(context) + self.storage.deep_size_of_children(context) + self.scratch_pool.deep_size_of_children(context) + + self + .rq_search_cache + .as_ref() + .map(|cache| cache.rotated_centroids.len() * std::mem::size_of::()) + .unwrap_or_default() // Skipping session since it is a weak ref } } @@ -557,21 +574,86 @@ impl IVFIndex { .get("num_bits") .and_then(|value| value.as_u64()) .unwrap_or(1); - if num_bits > 1 { + if num_bits > 1 && self.distance_type == DistanceType::Cosine { return Err(Error::not_supported( - "IVF_RQ num_bits>1 search is not supported until split-code query support is implemented", + "IVF_RQ num_bits>1 cosine search is not supported until raw-query cosine factors are implemented", )); } } Ok(()) } + fn use_query_residual( + storage: &IvfQuantizationStorage, + distance_type: DistanceType, + ) -> bool { + if Q::quantization_type() == QuantizationType::Rabit + && let Ok(Quantizer::Rabit(rq)) = 
storage.quantizer() + { + return rq.metadata_ref().query_estimator == RabitQueryEstimator::ResidualQuery; + } + Q::use_residual(distance_type) + } + + fn build_rq_search_cache( + ivf: &IvfModel, + storage: &IvfQuantizationStorage, + ) -> Result>> { + if Q::quantization_type() != QuantizationType::Rabit { + return Ok(None); + } + let Quantizer::Rabit(rq) = storage.quantizer()? else { + return Ok(None); + }; + if rq.metadata_ref().query_estimator != RabitQueryEstimator::RawQuery { + return Ok(None); + } + let centroids = ivf + .centroids_array() + .ok_or_else(|| Error::index("IVF_RQ raw-query search requires centroids"))?; + let rotated_centroids = rq.rotate_fsl_to_f32(centroids)?; + Ok(Some(Arc::new(RabitSearchCache { + rotated_centroids, + code_dim: rq.code_dim(), + }))) + } + + fn rotated_partition_centroid(&self, partition_id: usize) -> Option> { + let cache = self.rq_search_cache.as_ref()?; + let start = partition_id.checked_mul(cache.code_dim)?; + let end = start.checked_add(cache.code_dim)?; + cache + .rotated_centroids + .get(start..end) + .map(|centroid| centroid.to_vec()) + } + + fn prepare_rq_raw_query_context( + &self, + query: &ArrayRef, + ) -> Result>> { + if Q::quantization_type() != QuantizationType::Rabit || self.use_query_residual { + return Ok(None); + } + let Quantizer::Rabit(rq) = self.storage.quantizer()? 
else { + return Ok(None); + }; + if rq.metadata_ref().query_estimator != RabitQueryEstimator::RawQuery { + return Ok(None); + } + Ok(Some(Arc::new( + rq.metadata_ref() + .prepare_raw_query_context(query.as_ref())?, + ))) + } + async fn prepare_partition( &self, partition_id: usize, query: &Query, pre_filter: Arc, metrics: &dyn MetricsCollector, + raw_query_context: Option>, ) -> Result> { let (part_entry, ()) = tokio::try_join!( self.load_partition(partition_id, true, metrics), @@ -582,6 +664,8 @@ impl IVFIndex { pre_filter, partition_id, partition_centroid: self.ivf.centroid(partition_id), + rotated_partition_centroid: self.rotated_partition_centroid(partition_id), + raw_query_context, part_entry, _marker: PhantomData, }) @@ -593,6 +677,7 @@ impl IVFIndex { query: &Query, pre_filter: Arc, metrics: &dyn MetricsCollector, + raw_query_context: Option>, ) -> Result> { let part_entry = self.load_partition(partition_id, true, metrics).await?; Ok(PreparedPartitionSearch { @@ -600,13 +685,15 @@ impl IVFIndex { pre_filter, partition_id, partition_centroid: self.ivf.centroid(partition_id), + rotated_partition_centroid: self.rotated_partition_centroid(partition_id), + raw_query_context, part_entry, _marker: PhantomData, }) } fn run_prepared_partition_search( - distance_type: DistanceType, + use_query_residual: bool, use_residual_scratch: bool, prepared: PreparedPartitionSearch, metrics: &dyn MetricsCollector, @@ -617,16 +704,21 @@ impl IVFIndex { pre_filter, partition_id, partition_centroid, + rotated_partition_centroid, + raw_query_context, part_entry, _marker: _, } = prepared; - let residual = Self::residual_for_scratch( + let residual = Self::query_context_for_scratch( + use_query_residual, use_residual_scratch, partition_id, partition_centroid.as_ref(), + rotated_partition_centroid.as_deref(), + raw_query_context.as_deref(), )?; let query = Self::preprocess_partition_query_owned( - distance_type, + use_query_residual, use_residual_scratch, partition_id, 
partition_centroid.as_ref(), @@ -656,7 +748,7 @@ impl IVFIndex { #[allow(clippy::too_many_arguments)] fn accumulate_prepared_partition_search( - distance_type: DistanceType, + use_query_residual: bool, use_residual_scratch: bool, prepared: PreparedPartitionSearch, heap: &mut BinaryHeap>, @@ -668,16 +760,21 @@ impl IVFIndex { pre_filter, partition_id, partition_centroid, + rotated_partition_centroid, + raw_query_context, part_entry, _marker: _, } = prepared; - let residual = Self::residual_for_scratch( + let residual = Self::query_context_for_scratch( + use_query_residual, use_residual_scratch, partition_id, partition_centroid.as_ref(), + rotated_partition_centroid.as_deref(), + raw_query_context.as_deref(), )?; let query = Self::preprocess_partition_query_owned( - distance_type, + use_query_residual, use_residual_scratch, partition_id, partition_centroid.as_ref(), @@ -705,16 +802,26 @@ impl IVFIndex { ) } - fn residual_for_scratch<'a>( + fn query_context_for_scratch<'a>( + use_query_residual: bool, use_residual_scratch: bool, partition_id: usize, partition_centroid: Option<&'a ArrayRef>, + rotated_partition_centroid: Option<&'a [f32]>, + raw_query_context: Option<&'a RabitRawQueryContext>, ) -> Result>> { if use_residual_scratch { let partition_centroid = partition_centroid.ok_or_else(|| { Error::index(format!("partition centroid {partition_id} does not exist")) })?; Ok(Some(QueryResidual::Centroid(partition_centroid.as_ref()))) + } else if !use_query_residual + && (rotated_partition_centroid.is_some() || raw_query_context.is_some()) + { + Ok(Some(QueryResidual::RabitRawQuery { + rotated_centroid: rotated_partition_centroid, + query: raw_query_context, + })) } else { Ok(None) } @@ -732,14 +839,14 @@ impl IVFIndex { } fn preprocess_partition_query( - distance_type: DistanceType, + use_query_residual: bool, use_residual_scratch: bool, partition_id: usize, partition_centroid: Option<&ArrayRef>, query: &Query, ) -> Result { Self::preprocess_partition_query_owned( - 
distance_type, + use_query_residual, use_residual_scratch, partition_id, partition_centroid, @@ -748,13 +855,13 @@ impl IVFIndex { } fn preprocess_partition_query_owned( - distance_type: DistanceType, + use_query_residual: bool, use_residual_scratch: bool, partition_id: usize, partition_centroid: Option<&ArrayRef>, mut query: Query, ) -> Result { - if Q::use_residual(distance_type) { + if use_query_residual { let partition_centroid = partition_centroid.ok_or_else(|| { Error::index(format!("partition centroid {partition_id} does not exist")) })?; @@ -774,19 +881,20 @@ impl IVFIndex { let dim = ivf.dimension(); let dist_table_len = dim * 4; + let max_ex_dist_table_len = dim * 256; let max_partition_len = ivf.lengths.iter().copied().max().unwrap_or_default() as usize; QueryScratchCapacity::new( max_partition_len, - dim + dist_table_len, + dim + dist_table_len + max_ex_dist_table_len, max_partition_len, dist_table_len, ) } - fn use_residual_scratch(ivf: &IvfModel, distance_type: DistanceType) -> bool { + fn use_residual_scratch(ivf: &IvfModel, use_query_residual: bool) -> bool { Q::quantization_type() == QuantizationType::Rabit - && Q::use_residual(distance_type) + && use_query_residual && ivf .centroids_array() .map(|centroids| centroids.value_type() == DataType::Float32) @@ -904,14 +1012,18 @@ impl IVFIndex { .await; let scratch_pool = Arc::new(Self::query_scratch_pool(&ivf)); - let use_residual_scratch = Self::use_residual_scratch(&ivf, distance_type); + let use_query_residual = Self::use_query_residual(&storage, distance_type); + let use_residual_scratch = Self::use_residual_scratch(&ivf, use_query_residual); + let rq_search_cache = Self::build_rq_search_cache(&ivf, &storage)?; Ok(Self { uri: to_local_path(&uri), index_path: uri.as_ref().to_string(), uuid, scratch_pool, + use_query_residual, use_residual_scratch, + rq_search_cache, ivf, reader: index_reader, storage, @@ -938,13 +1050,17 @@ impl IVFIndex { io_parallelism: usize, ) -> Self { let scratch_pool = 
Arc::new(Self::query_scratch_pool(&ivf)); - let use_residual_scratch = Self::use_residual_scratch(&ivf, distance_type); + let use_query_residual = Self::use_query_residual(&storage, distance_type); + let use_residual_scratch = Self::use_residual_scratch(&ivf, use_query_residual); + let rq_search_cache = Self::build_rq_search_cache(&ivf, &storage).unwrap_or(None); Self { uri, index_path, uuid, scratch_pool, + use_query_residual, use_residual_scratch, + rq_search_cache, ivf, reader, storage, @@ -1039,7 +1155,7 @@ impl IVFIndex { #[instrument(level = "debug", skip(self))] pub fn preprocess_query(&self, partition_id: usize, query: &Query) -> Result { Self::preprocess_partition_query( - self.distance_type, + self.use_query_residual, self.use_residual_scratch, partition_id, self.ivf.centroid(partition_id).as_ref(), @@ -1225,15 +1341,19 @@ impl VectorIndex for IVFInd let part_entry = self.load_partition(partition_id, true, metrics).await?; pre_filter.wait_for_ready().await?; - let residual_centroid = if self.use_residual_scratch { - Some(self.ivf.centroid(partition_id).ok_or_else(|| { - Error::index(format!("partition centroid {partition_id} does not exist")) - })?) 
- } else { - None - }; - let query = self.preprocess_query(partition_id, query)?; + let partition_centroid = self.ivf.centroid(partition_id); + let rotated_partition_centroid = self.rotated_partition_centroid(partition_id); + let raw_query_context = self.prepare_rq_raw_query_context(&query.key)?; + let query = Self::preprocess_partition_query( + self.use_query_residual, + self.use_residual_scratch, + partition_id, + partition_centroid.as_ref(), + query, + )?; let scratch_pool = self.scratch_pool.clone(); + let use_query_residual = self.use_query_residual; + let use_residual_scratch = self.use_residual_scratch; let (batch, local_metrics) = spawn_cpu(move || { let param = (&query).into(); let refine_factor = query.refine_factor.unwrap_or(1) as usize; @@ -1245,7 +1365,14 @@ impl VectorIndex for IVFInd .ok_or(Error::internal( "failed to downcast partition entry".to_string(), ))?; - let residual = residual_centroid.as_deref().map(QueryResidual::Centroid); + let residual = Self::query_context_for_scratch( + use_query_residual, + use_residual_scratch, + partition_id, + partition_centroid.as_ref(), + rotated_partition_centroid.as_deref(), + raw_query_context.as_deref(), + )?; let batch = scratch_pool.with_scratch(|scratch| { part.index.search_with_scratch( query.key, @@ -1275,8 +1402,9 @@ impl VectorIndex for IVFInd metrics: &dyn MetricsCollector, ) -> Result { self.ensure_search_supported()?; + let raw_query_context = self.prepare_rq_raw_query_context(&query.key)?; Ok(Box::new( - self.prepare_partition(partition_id, query, pre_filter, metrics) + self.prepare_partition(partition_id, query, pre_filter, metrics, raw_query_context) .await?, )) } @@ -1292,7 +1420,7 @@ impl VectorIndex for IVFInd .map_err(|_| Error::internal("failed to downcast prepared partition search"))?; self.scratch_pool.with_scratch(|scratch| { Self::run_prepared_partition_search( - self.distance_type, + self.use_query_residual, self.use_residual_scratch, *prepared, metrics, @@ -1341,12 +1469,14 @@ impl 
VectorIndex for IVFInd } let prepare_parallelism = get_num_compute_intensive_cpus().max(1); + let raw_query_context = self.prepare_rq_raw_query_context(&query.key)?; if control.is_none() && S::supports_global_topk_heap() { let heap_capacity = query.k * query.refine_factor.unwrap_or(1) as usize; pre_filter.wait_for_ready().await?; let prepare_index = self.clone(); let prepare_metrics = metrics.clone(); + let prepare_raw_query_context = raw_query_context.clone(); let prepared = stream::iter(start_idx..end_idx) .map(move |idx| { let part_id = partitions.value(idx); @@ -1355,6 +1485,7 @@ impl VectorIndex for IVFInd let index = prepare_index.clone(); let pre_filter = pre_filter.clone(); let metrics = prepare_metrics.clone(); + let raw_query_context = prepare_raw_query_context.clone(); async move { index .prepare_partition_without_prefilter_wait( @@ -1362,6 +1493,7 @@ impl VectorIndex for IVFInd &query, pre_filter, metrics.as_ref(), + raw_query_context, ) .await } @@ -1370,7 +1502,7 @@ impl VectorIndex for IVFInd .try_collect::>() .await?; - let distance_type = self.distance_type; + let use_query_residual = self.use_query_residual; let use_residual_scratch = self.use_residual_scratch; let search_metrics = metrics.clone(); let scratch_pool = self.scratch_pool.clone(); @@ -1379,7 +1511,7 @@ impl VectorIndex for IVFInd scratch_pool.with_scratch(|scratch| -> DataFusionResult<()> { for prepared in prepared { Self::accumulate_prepared_partition_search( - distance_type, + use_query_residual, use_residual_scratch, prepared, &mut heap, @@ -1406,6 +1538,7 @@ impl VectorIndex for IVFInd let prepare_index = self.clone(); let prepare_metrics = metrics.clone(); + let prepare_raw_query_context = raw_query_context.clone(); tokio::spawn(async move { let prepare_stream = stream::iter(start_idx..end_idx) .map(move |idx| { @@ -1415,6 +1548,7 @@ impl VectorIndex for IVFInd let index = prepare_index.clone(); let pre_filter = pre_filter.clone(); let metrics = prepare_metrics.clone(); + let 
raw_query_context = prepare_raw_query_context.clone(); async move { index .prepare_partition( @@ -1422,6 +1556,7 @@ impl VectorIndex for IVFInd &query, pre_filter, metrics.as_ref(), + raw_query_context, ) .await } @@ -1437,7 +1572,7 @@ impl VectorIndex for IVFInd } }); - let distance_type = self.distance_type; + let use_query_residual = self.use_query_residual; let use_residual_scratch = self.use_residual_scratch; let search_metrics = metrics.clone(); let batch_tx_for_search = batch_tx.clone(); @@ -1465,7 +1600,7 @@ impl VectorIndex for IVFInd let batch = { Self::run_prepared_partition_search( - distance_type, + use_query_residual, use_residual_scratch, prepared, search_metrics.as_ref(), @@ -1696,8 +1831,8 @@ mod tests { use lance_arrow::FixedSizeListArrayExt; use lance_index::vector::bq::{ RQBuildParams, RQRotationType, - storage::{RABIT_EX_CODE_COLUMN, RabitQuantizationMetadata}, - transform::EX_SCALE_FACTORS_COLUMN, + storage::{RABIT_EX_CODE_COLUMN, RabitQuantizationMetadata, RabitQueryEstimator}, + transform::{EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN}, }; use lance_index::vector::storage::VectorStore; @@ -1716,7 +1851,7 @@ mod tests { }; use lance_core::cache::LanceCache; use lance_core::utils::tempfile::TempStrDir; - use lance_core::{Error, ROW_ID, Result}; + use lance_core::{ROW_ID, Result}; use lance_encoding::decoder::DecoderPlugins; use lance_file::reader::{FileReader, FileReaderOptions}; use lance_file::writer::FileWriter; @@ -3997,8 +4132,7 @@ mod tests { } #[tokio::test] - #[ignore = "IVF_RQ num_bits>1 creation is gated until split-code search support is implemented"] - async fn test_build_ivf_rq_multi_bit_persists_split_codes_and_gates_search() { + async fn test_build_ivf_rq_multi_bit_persists_split_codes_and_searches() { let test_dir = TempStrDir::default(); let test_uri = test_dir.as_str(); let (mut dataset, vectors) = generate_test_dataset::(test_uri, 0.0..1.0).await; @@ -4018,6 +4152,7 @@ mod tests { let index_uuid = 
indices[0].uuid.to_string(); let rq_meta = get_rq_metadata(&dataset, scheduler.clone(), &index_uuid).await; assert_eq!(rq_meta.num_bits, 9); + assert_eq!(rq_meta.query_estimator, RabitQueryEstimator::RawQuery); let reader = open_rq_aux_reader(&dataset, scheduler, &index_uuid).await; let schema = reader.schema(); @@ -4026,22 +4161,18 @@ mod tests { panic!("RQ ex-code field should be FixedSizeList"); }; assert_eq!(ex_code_bytes, 32); + assert!(schema.field(EX_ADD_FACTORS_COLUMN).is_some()); assert!(schema.field(EX_SCALE_FACTORS_COLUMN).is_some()); let query = vectors.value(0); - let err = dataset + let results = dataset .scan() .nearest("vector", query.as_primitive::(), 10) .unwrap() .try_into_batch() .await - .unwrap_err(); - assert!(matches!(err, Error::Execution { .. }), "{err}"); - assert!( - err.to_string() - .contains("num_bits>1 search is not supported"), - "{err}" - ); + .unwrap(); + assert_eq!(results.num_rows(), 10); } #[rstest] From a77fb44b7e1ba14d0131a4a1961a411ce9f10a09 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 4 Jun 2026 11:41:30 +0800 Subject: [PATCH 02/14] feat(index): add ivf rq raw-query factors --- rust/lance-index/src/vector/bq/builder.rs | 194 +-- rust/lance-index/src/vector/bq/storage.rs | 1147 +++++++++++++++-- .../src/vector/distributed/index_merger.rs | 19 +- 3 files changed, 1122 insertions(+), 238 deletions(-) diff --git a/rust/lance-index/src/vector/bq/builder.rs b/rust/lance-index/src/vector/bq/builder.rs index 8f4f3bc03bd..ea3d9ed2667 100644 --- a/rust/lance-index/src/vector/bq/builder.rs +++ b/rust/lance-index/src/vector/bq/builder.rs @@ -21,11 +21,12 @@ use crate::vector::bq::storage::{ RabitQueryEstimator, rabit_binary_code_field, rabit_ex_code_field, }; use crate::vector::bq::transform::{ - ADD_FACTORS_FIELD, EX_ADD_FACTORS_FIELD, EX_SCALE_FACTORS_FIELD, SCALE_FACTORS_FIELD, + ADD_FACTORS_FIELD, ERROR_FACTORS_FIELD, EX_ADD_FACTORS_FIELD, EX_SCALE_FACTORS_FIELD, + SCALE_FACTORS_FIELD, }; use crate::vector::bq::{ 
RQBuildParams, RQRotationType, rabit_binary_code_bytes, rabit_ex_bits, rabit_ex_code_bytes, - rotation::{apply_fast_rotation, fast_rotation_signs_len, random_fast_rotation_signs}, + rotation::{apply_fast_rotation, random_fast_rotation_signs}, validate_rq_num_bits, }; use crate::vector::quantizer::{Quantization, Quantizer, QuantizerBuildParams}; @@ -640,46 +641,6 @@ impl Quantization for RabitQuantizer { )); } - // Reuse a supplied rotation instead of generating a fresh random one. - if let Some(meta) = &params.rotation { - let expected_code_dim = dim * params.num_bits as usize; - if meta.num_bits != params.num_bits || meta.code_dim as usize != expected_code_dim { - return Err(Error::invalid_input(format!( - "supplied RaBitQ rotation does not match build params: rotation \ - num_bits={}, code_dim={}; expected num_bits={}, code_dim={}", - meta.num_bits, meta.code_dim, params.num_bits, expected_code_dim - ))); - } - - match meta.rotation_type { - RQRotationType::Fast => { - let signs = meta.fast_rotation_signs.as_ref().ok_or_else(|| { - Error::invalid_input("supplied fast RaBitQ rotation is missing signs") - })?; - let expected_len = fast_rotation_signs_len(meta.code_dim as usize); - if signs.len() != expected_len { - return Err(Error::invalid_input(format!( - "supplied fast RaBitQ rotation signs length {} does not match \ - expected {} for code_dim={}", - signs.len(), - expected_len, - meta.code_dim - ))); - } - } - RQRotationType::Matrix => { - if meta.rotate_mat.is_none() { - return Err(Error::invalid_input( - "use the fast rotation for distributed builds", - )); - } - } - } - return Ok(Self { - metadata: meta.clone(), - }); - } - let q = match data.as_fixed_size_list().value_type() { DataType::Float16 => Self::new_with_rotation::( params.num_bits, @@ -776,6 +737,9 @@ impl Quantization for RabitQuantizer { fn extra_fields(&self) -> Vec { let mut fields = vec![ADD_FACTORS_FIELD.clone(), SCALE_FACTORS_FIELD.clone()]; + if self.metadata.query_estimator ==
RabitQueryEstimator::RawQuery { + fields.push(ERROR_FACTORS_FIELD.clone()); + } if let Some(ex_code_field) = rabit_ex_code_field(self.code_dim(), self.metadata.num_bits) .expect("RabitQ num_bits should be validated") { @@ -875,6 +839,8 @@ mod tests { use lance_linalg::distance::DistanceType; use rstest::rstest; + use crate::vector::bq::storage::RABIT_EX_CODE_COLUMN; + #[rstest] #[case(8)] #[case(16)] @@ -939,133 +905,45 @@ mod tests { } #[test] - fn test_rabit_quantizer_requires_dim_divisible_by_8() { - let vectors = Float32Array::from(vec![0.0f32; 4 * 30]); - let fsl = FixedSizeListArray::try_new_from_values(vectors, 30).unwrap(); - let params = RQBuildParams::new(1); - - let err = RabitQuantizer::build(&fsl, DistanceType::L2, &params).unwrap_err(); + fn test_rabit_quantizer_extra_fields_include_raw_query_error_factor() { + let q = RabitQuantizer::new_with_rotation::(1, 128, RQRotationType::Fast); + let fields = q.extra_fields(); assert!( - err.to_string() - .contains("vector dimension must be divisible by 8 for IVF_RQ"), - "{}", - err - ); - } - - fn sample_fsl(n: usize, dim: usize) -> FixedSizeListArray { - let values: Vec = (0..n * dim).map(|i| ((i * 31 % 17) as f32) - 8.0).collect(); - FixedSizeListArray::try_new_from_values(Float32Array::from(values), dim as i32).unwrap() - } - - fn quantized_codes(q: &RabitQuantizer, data: &FixedSizeListArray) -> Vec { - use arrow::datatypes::UInt8Type; - q.quantize(data) - .unwrap() - .as_fixed_size_list() - .values() - .as_primitive::() - .values() - .to_vec() - } - - #[test] - fn test_shared_fast_rotation_gives_identical_codes() { - let dim = 32; - let seed = RabitQuantizer::new_with_rotation::(1, dim, RQRotationType::Fast); - let json = serde_json::to_string(&seed.metadata(None)).unwrap(); - let meta: RabitQuantizationMetadata = serde_json::from_str(&json).unwrap(); - - let params = RQBuildParams { - num_bits: 1, - rotation_type: RQRotationType::Fast, - rotation: Some(meta), - }; - let data = sample_fsl(8, dim as usize); -
let q_a = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap(); - let q_b = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap(); - - assert_eq!( - quantized_codes(&q_a, &data), - quantized_codes(&q_b, &data), - "shared rotation must yield identical codes" + fields + .iter() + .any(|field| field.name() == ERROR_FACTORS_FIELD.name()) ); - } - - #[test] - fn test_unpinned_rotation_gives_different_codes() { - let dim = 32; - let params = RQBuildParams::new(1); - let data = sample_fsl(8, dim as usize); - let q_a = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap(); - let q_b = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap(); - - assert_ne!( - quantized_codes(&q_a, &data), - quantized_codes(&q_b, &data), - "independent unpinned rotations must yield different codes" - ); - } - - #[test] - fn test_build_rejects_rotation_with_mismatched_code_dim() { - let seed = RabitQuantizer::new_with_rotation::(1, 16, RQRotationType::Fast); - let params = RQBuildParams { - num_bits: 1, - rotation_type: RQRotationType::Fast, - rotation: Some(seed.metadata(None)), - }; - let data = sample_fsl(4, 32); - let err = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap_err(); assert!( - err.to_string().contains("does not match build params"), - "{}", - err + !fields + .iter() + .any(|field| field.name() == RABIT_EX_CODE_COLUMN) ); - } - #[test] - fn test_build_rejects_fast_rotation_with_bad_signs_length() { - let dim = 16; - let seed = RabitQuantizer::new_with_rotation::(1, dim, RQRotationType::Fast); - let mut meta = seed.metadata(None); - // Corrupt the signs to the wrong length (valid would be 4 * ceil(16/8) = 8).
- meta.fast_rotation_signs = Some(vec![0u8; 7]); - let params = RQBuildParams { - num_bits: 1, - rotation_type: RQRotationType::Fast, - rotation: Some(meta), - }; - let data = sample_fsl(4, dim as usize); - let err = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap_err(); - assert!(err.to_string().contains("signs length"), "{}", err); + let q = RabitQuantizer::new_with_rotation::(3, 128, RQRotationType::Fast); + let fields = q.extra_fields(); + for expected in [ + ERROR_FACTORS_FIELD.name().as_str(), + RABIT_EX_CODE_COLUMN, + EX_ADD_FACTORS_FIELD.name().as_str(), + EX_SCALE_FACTORS_FIELD.name().as_str(), + ] { + assert!( + fields.iter().any(|field| field.name().as_str() == expected), + "missing {expected}" + ); + } } #[test] - fn test_matrix_rotation_lost_through_json_is_rejected() { - let dim = 16; - let seed = RabitQuantizer::new_with_rotation::(1, dim, RQRotationType::Matrix); - let meta = seed.metadata(None); - assert!(meta.rotate_mat.is_some()); - - let json = serde_json::to_string(&meta).unwrap(); - let parsed: RabitQuantizationMetadata = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed.rotation_type, RQRotationType::Matrix); - assert!( - parsed.rotate_mat.is_none(), - "matrix is expected to be dropped by JSON serialization" - ); + fn test_rabit_quantizer_requires_dim_divisible_by_8() { + let vectors = Float32Array::from(vec![0.0f32; 4 * 30]); + let fsl = FixedSizeListArray::try_new_from_values(vectors, 30).unwrap(); + let params = RQBuildParams::new(1); - let params = RQBuildParams { - num_bits: 1, - rotation_type: RQRotationType::Matrix, - rotation: Some(parsed), - }; - let data = sample_fsl(4, dim as usize); - let err = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap_err(); + let err = RabitQuantizer::build(&fsl, DistanceType::L2, &params).unwrap_err(); assert!( err.to_string() - .contains("fast rotation for distributed builds"), + .contains("vector dimension must be divisible by 8 for IVF_RQ"), "{}", err ); diff --git
a/rust/lance-index/src/vector/bq/storage.rs b/rust/lance-index/src/vector/bq/storage.rs index e723b1abfe9..460552fb1bb 100644 --- a/rust/lance-index/src/vector/bq/storage.rs +++ b/rust/lance-index/src/vector/bq/storage.rs @@ -2,9 +2,12 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use std::borrow::Cow; -use std::collections::HashMap; +use std::collections::{BinaryHeap, HashMap}; use std::ops::Sub; -use std::sync::Arc; +use std::sync::{ + Arc, OnceLock, + atomic::{AtomicU64, Ordering}, +}; use arrow::array::AsArray; use arrow::datatypes::{Float16Type, Float32Type, Float64Type, UInt8Type, UInt64Type}; @@ -19,7 +22,7 @@ use itertools::Itertools; use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, RecordBatchExt}; use lance_core::{Error, ROW_ID, Result}; use lance_file::previous::reader::FileReader as PreviousFileReader; -use lance_linalg::distance::{DistanceType, Dot, dot}; +use lance_linalg::distance::{DistanceType, Dot, dot, l2::l2}; use lance_linalg::simd::{ self, dist_table::{BATCH_SIZE, PERM0, PERM0_INVERSE}, @@ -39,12 +42,14 @@ use crate::frag_reuse::FragReuseIndex; use crate::pb; use crate::vector::bq::rotation::{apply_fast_rotation, apply_fast_rotation_in_place}; use crate::vector::bq::transform::{ - ADD_FACTORS_COLUMN, EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN, SCALE_FACTORS_COLUMN, + ADD_FACTORS_COLUMN, ERROR_FACTORS_COLUMN, EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN, + SCALE_FACTORS_COLUMN, }; use crate::vector::bq::{ RQRotationType, rabit_binary_code_bytes, rabit_ex_bits, rabit_ex_code_bytes, validate_rq_num_bits, }; +use crate::vector::graph::{OrderedFloat, OrderedNode}; use crate::vector::pq::storage::transpose; use crate::vector::quantizer::{QuantizerMetadata, QuantizerStorage}; use crate::vector::storage::{DistCalculator, QueryResidual, RabitRawQueryContext, VectorStore}; @@ -54,6 +59,129 @@ pub const RABIT_CODE_COLUMN: &str = "_rabit_codes"; pub const RABIT_EX_CODE_COLUMN: &str = "__ex_codes"; pub const 
SEGMENT_LENGTH: usize = 4; pub const SEGMENT_NUM_CODES: usize = 1 << SEGMENT_LENGTH; +const RABIT_PRUNE_STATS_ENV: &str = "LANCE_RQ_PRUNE_STATS"; +const RABIT_PRUNE_STATS_INTERVAL_ENV: &str = "LANCE_RQ_PRUNE_STATS_INTERVAL"; +const DEFAULT_RABIT_PRUNE_STATS_INTERVAL: u64 = 1024; + +#[derive(Default)] +struct RabitPruneStats { + calls: AtomicU64, + candidates: AtomicU64, + pruned_upper_bound: AtomicU64, + pruned_heap: AtomicU64, + exact: AtomicU64, + exact_rejected: AtomicU64, +} + +#[derive(Default)] +struct RabitPruneBypassStats { + calls: AtomicU64, +} + +static RABIT_PRUNE_STATS: OnceLock = OnceLock::new(); +static RABIT_PRUNE_BYPASS_STATS: OnceLock = OnceLock::new(); +static RABIT_PRUNE_STATS_ENABLED: OnceLock = OnceLock::new(); +static RABIT_PRUNE_STATS_INTERVAL: OnceLock = OnceLock::new(); + +fn rabit_prune_stats_enabled() -> bool { + *RABIT_PRUNE_STATS_ENABLED.get_or_init(|| match std::env::var(RABIT_PRUNE_STATS_ENV) { + Ok(value) => { + let value = value.to_ascii_lowercase(); + !matches!(value.as_str(), "" | "0" | "false" | "off" | "no") + } + Err(_) => false, + }) +} + +fn rabit_prune_stats_interval() -> u64 { + *RABIT_PRUNE_STATS_INTERVAL.get_or_init(|| { + std::env::var(RABIT_PRUNE_STATS_INTERVAL_ENV) + .ok() + .and_then(|value| value.parse::().ok()) + .filter(|interval| *interval > 0) + .unwrap_or(DEFAULT_RABIT_PRUNE_STATS_INTERVAL) + }) +} + +fn ratio(numerator: u64, denominator: u64) -> f64 { + if denominator == 0 { + 0.0 + } else { + numerator as f64 / denominator as f64 + } +} + +fn emit_rabit_prune_stats(message: &str) { + log::warn!( + target: "lance_index::vector::bq::prune_stats", + "{}", + message + ); +} + +fn record_rabit_prune_stats( + candidates: usize, + pruned_upper_bound: usize, + pruned_heap: usize, + exact: usize, + exact_rejected: usize, +) { + if !rabit_prune_stats_enabled() { + return; + } + + let stats = RABIT_PRUNE_STATS.get_or_init(RabitPruneStats::default); + let calls = stats.calls.fetch_add(1, Ordering::Relaxed) + 1; + let 
candidates = stats + .candidates + .fetch_add(candidates as u64, Ordering::Relaxed) + + candidates as u64; + let pruned_upper_bound = stats + .pruned_upper_bound + .fetch_add(pruned_upper_bound as u64, Ordering::Relaxed) + + pruned_upper_bound as u64; + let pruned_heap = stats + .pruned_heap + .fetch_add(pruned_heap as u64, Ordering::Relaxed) + + pruned_heap as u64; + let exact = stats.exact.fetch_add(exact as u64, Ordering::Relaxed) + exact as u64; + let exact_rejected = stats + .exact_rejected + .fetch_add(exact_rejected as u64, Ordering::Relaxed) + + exact_rejected as u64; + let interval = rabit_prune_stats_interval(); + if calls.is_multiple_of(interval) { + let pruned = pruned_upper_bound + pruned_heap; + emit_rabit_prune_stats(&format!( + "ivf_rq_prune_stats calls={} candidates={} pruned={} pruned_upper_bound={} pruned_heap={} prune_ratio={:.6} exact={} exact_ratio={:.6} exact_rejected={} exact_reject_ratio={:.6}", + calls, + candidates, + pruned, + pruned_upper_bound, + pruned_heap, + ratio(pruned, candidates), + exact, + ratio(exact, candidates), + exact_rejected, + ratio(exact_rejected, exact), + )); + } +} + +fn record_rabit_prune_bypass(reason: &'static str) { + if !rabit_prune_stats_enabled() { + return; + } + + let stats = RABIT_PRUNE_BYPASS_STATS.get_or_init(RabitPruneBypassStats::default); + let calls = stats.calls.fetch_add(1, Ordering::Relaxed) + 1; + if calls.is_multiple_of(rabit_prune_stats_interval()) { + emit_rabit_prune_stats(&format!( + "ivf_rq_prune_stats_bypass calls={} reason={}", + calls, reason + )); + } +} #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] @@ -330,14 +458,22 @@ pub struct RabitQuantizationStorage { codes: FixedSizeListArray, add_factors: Float32Array, scale_factors: Float32Array, + error_factors: Option, ex_codes: Option, + packed_ex_codes: Option, ex_add_factors: Option, ex_scale_factors: Option, } impl DeepSizeOf for RabitQuantizationStorage { fn 
deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { - self.metadata.deep_size_of_children(context) + self.batch.get_array_memory_size() + self.metadata.deep_size_of_children(context) + + self.batch.get_array_memory_size() + + self + .packed_ex_codes + .as_ref() + .map(|codes| codes.get_array_memory_size()) + .unwrap_or_default() } } @@ -376,6 +512,25 @@ impl RabitQuantizationStorage { } } + fn raw_query_error( + &self, + dist_q_c: f32, + rotated_query: &[f32], + rotated_centroid: Option<&[f32]>, + ) -> f32 { + match self.distance_type { + DistanceType::L2 => dist_q_c.max(0.0).sqrt(), + DistanceType::Dot => rotated_centroid + .map(|centroid| l2(rotated_query, centroid).sqrt()) + .unwrap_or_else(|| dist_q_c.max(0.0).sqrt()), + DistanceType::Cosine => dist_q_c.max(0.0).sqrt(), + _ => unimplemented!( + "RabitQ does not support distance type: {}", + self.distance_type + ), + } + } + fn distance_calculator_from_parts<'a>( &'a self, dim: usize, @@ -383,11 +538,16 @@ impl RabitQuantizationStorage { ex_dist_table: Cow<'a, [f32]>, sum_q: f32, query_factor: f32, + query_error: f32, ) -> RabitDistCalculator<'a> { let ex_codes = self .ex_codes .as_ref() .map(|codes| codes.values().as_primitive::().values().as_ref()); + let packed_ex_codes = self + .packed_ex_codes + .as_ref() + .map(|codes| codes.values().as_primitive::().values().as_ref()); RabitDistCalculator::new( dim, self.metadata.num_bits, @@ -399,13 +559,18 @@ impl RabitQuantizationStorage { ex_codes, self.add_factors.values(), self.scale_factors.values(), + self.error_factors + .as_ref() + .map(|factors| factors.values().as_ref()), self.ex_add_factors .as_ref() .map(|factors| factors.values().as_ref()), self.ex_scale_factors .as_ref() .map(|factors| factors.values().as_ref()), + packed_ex_codes, query_factor, + query_error, ) } @@ -586,9 +751,12 @@ pub struct RabitDistCalculator<'a> { ex_dist_table: Cow<'a, [f32]>, add_factors: &'a [f32], scale_factors: &'a [f32], + error_factors: Option<&'a [f32]>, 
ex_add_factors: Option<&'a [f32]>, ex_scale_factors: Option<&'a [f32]>, + packed_ex_codes: Option<&'a [u8]>, query_factor: f32, + query_error: f32, sum_q: f32, sqrt_d: f32, @@ -607,9 +775,12 @@ impl<'a> RabitDistCalculator<'a> { ex_codes: Option<&'a [u8]>, add_factors: &'a [f32], scale_factors: &'a [f32], + error_factors: Option<&'a [f32]>, ex_add_factors: Option<&'a [f32]>, ex_scale_factors: Option<&'a [f32]>, + packed_ex_codes: Option<&'a [u8]>, query_factor: f32, + query_error: f32, ) -> Self { Self { dim, @@ -621,13 +792,315 @@ impl<'a> RabitDistCalculator<'a> { ex_dist_table, add_factors, scale_factors, + error_factors, ex_add_factors, ex_scale_factors, + packed_ex_codes, query_factor, + query_error, sqrt_d: (dim as f32 * num_bits as f32).sqrt(), sum_q, } } + + #[allow(clippy::uninit_vec)] + fn binary_distances_with_scratch( + &self, + n: usize, + code_len: usize, + dists: &mut Vec, + quantized_dists: &mut Vec, + quantized_dists_table: &mut Vec, + ) -> usize { + let (qmin, qmax) = quantize_dist_table_into(&self.dist_table, quantized_dists_table); + let remainder = n % BATCH_SIZE; + let simd_len = n - remainder; + quantized_dists.clear(); + quantized_dists.reserve(simd_len); + // SAFETY: sum_4bit_dist_table overwrites each element in the SIMD batch range. + unsafe { + quantized_dists.set_len(simd_len); + } + simd::dist_table::sum_4bit_dist_table( + simd_len, + code_len, + self.codes, + quantized_dists_table, + quantized_dists, + ); + + let range = (qmax - qmin) / 255.0; + let num_tables = quantized_dists_table.len() / SEGMENT_NUM_CODES; + let sum_min = num_tables as f32 * qmin; + dists.clear(); + dists.reserve(n); + // SAFETY: the SIMD section below writes [0, simd_len), and the + // remainder section writes [simd_len, n). 
+ unsafe { + dists.set_len(n); + } + let (simd_dists, remainder_dists) = dists.split_at_mut(simd_len); + simd_dists + .iter_mut() + .zip(quantized_dists.iter()) + .for_each(|(dist, q_dist)| { + *dist = (*q_dist as f32) * range + sum_min; + }); + + remainder_dists + .iter_mut() + .enumerate() + .for_each(|(id, dist)| { + *dist = compute_single_rq_distance( + self.codes, + simd_len + id, + n, + code_len, + &self.dist_table, + ); + }); + simd_len + } + + #[allow(clippy::uninit_vec)] + fn apply_raw_query_multi_bit_distances( + &self, + simd_len: usize, + dists: &mut [f32], + quantized_dists: &mut Vec, + quantized_dists_table: &mut Vec, + ) { + let ex_bits = self.num_bits - 1; + let ex_codes = self + .ex_codes + .expect("raw-query multi-bit RQ requires ex codes"); + let ex_add_factors = self + .ex_add_factors + .expect("raw-query multi-bit RQ requires ex add factors"); + let ex_scale_factors = self + .ex_scale_factors + .expect("raw-query multi-bit RQ requires ex scale factors"); + let ex_code_len = + rabit_ex_code_bytes(self.dim, ex_bits).expect("RabitQ num_bits should be validated"); + let code_scale = (1u32 << ex_bits) as f32; + let code_bias = -(code_scale - 0.5); + + let fastscan_len = if simd_len > 0 && supports_ex_fastscan(ex_bits) { + self.packed_ex_codes + .map(|packed_ex_codes| { + let fastscan_len = simd_len; + let fastscan_code_len = ex_fastscan_code_len(self.dim, ex_bits) + .expect("RabitQ num_bits should be validated"); + let (qmin, qmax, quantization_max) = quantize_ex_fastscan_dist_table_into( + self.dim, + ex_bits, + &self.ex_dist_table, + quantized_dists_table, + ); + quantized_dists.clear(); + quantized_dists.reserve(fastscan_len); + // SAFETY: sum_4bit_dist_table overwrites each element in the SIMD batch range. 
+ unsafe { + quantized_dists.set_len(fastscan_len); + } + simd::dist_table::sum_4bit_dist_table( + fastscan_len, + fastscan_code_len, + packed_ex_codes, + quantized_dists_table, + quantized_dists, + ); + + let range = (qmax - qmin) / quantization_max; + let num_tables = quantized_dists_table.len() / SEGMENT_NUM_CODES; + let sum_min = num_tables as f32 * qmin; + dists + .iter_mut() + .take(fastscan_len) + .zip(quantized_dists.iter()) + .enumerate() + .for_each(|(id, (dist, q_ex_dist))| { + let ex_dist = (*q_ex_dist as f32) * range + sum_min; + let full_dot = code_scale * *dist + ex_dist + code_bias * self.sum_q; + *dist = full_dot * ex_scale_factors[id] + + ex_add_factors[id] + + self.query_factor; + }); + fastscan_len + }) + .unwrap_or_default() + } else { + 0 + }; + + dists + .iter_mut() + .enumerate() + .skip(fastscan_len) + .for_each(|(id, dist)| { + let ex_dist = compute_single_rq_ex_distance( + ex_codes, + id, + ex_code_len, + ex_bits, + self.dim, + &self.ex_dist_table, + ); + let full_dot = code_scale * *dist + ex_dist + code_bias * self.sum_q; + *dist = full_dot * ex_scale_factors[id] + ex_add_factors[id] + self.query_factor; + }); + } + + #[inline] + fn raw_query_binary_distance(&self, id: usize, binary_ip: f32) -> f32 { + (binary_ip - 0.5 * self.sum_q) * self.scale_factors[id] + + self.add_factors[id] + + self.query_factor + } + + #[inline] + fn raw_query_lower_bound(&self, id: usize, binary_ip: f32) -> Option { + let error_factors = self.error_factors?; + Some(self.raw_query_binary_distance(id, binary_ip) - error_factors[id] * self.query_error) + } + + #[inline] + #[allow(clippy::too_many_arguments)] + fn raw_query_multi_bit_exact_distance( + &self, + id: usize, + binary_ip: f32, + ex_bits: u8, + ex_code_len: usize, + ex_codes: &[u8], + ex_add_factors: &[f32], + ex_scale_factors: &[f32], + ) -> f32 { + let ex_dist = compute_single_rq_ex_distance( + ex_codes, + id, + ex_code_len, + ex_bits, + self.dim, + &self.ex_dist_table, + ); + let code_bias = -((1u32 
<< ex_bits) as f32 - 0.5); + let full_dot = (1u32 << ex_bits) as f32 * binary_ip + ex_dist + code_bias * self.sum_q; + full_dot * ex_scale_factors[id] + ex_add_factors[id] + self.query_factor + } + + #[allow(clippy::too_many_arguments)] + fn accumulate_raw_query_multi_bit_topk_with_scratch( + &self, + k: usize, + lower_bound: Option, + upper_bound: Option, + row_ids: impl Iterator, + res: &mut BinaryHeap>, + dists: &mut Vec, + quantized_dists: &mut Vec, + quantized_dists_table: &mut Vec, + ) { + let code_len = rabit_binary_code_bytes(self.dim); + let n = self.codes.len() / code_len; + if n == 0 { + dists.clear(); + quantized_dists.clear(); + return; + } + + self.binary_distances_with_scratch( + n, + code_len, + dists, + quantized_dists, + quantized_dists_table, + ); + + let ex_bits = self.num_bits - 1; + let ex_codes = self + .ex_codes + .expect("raw-query multi-bit RQ requires ex codes"); + let ex_add_factors = self + .ex_add_factors + .expect("raw-query multi-bit RQ requires ex add factors"); + let ex_scale_factors = self + .ex_scale_factors + .expect("raw-query multi-bit RQ requires ex scale factors"); + let ex_code_len = + rabit_ex_code_bytes(self.dim, ex_bits).expect("RabitQ num_bits should be validated"); + let query_lower_bound = lower_bound.unwrap_or(f32::MIN); + let query_upper_bound = upper_bound.unwrap_or(f32::MAX); + let mut max_dist = res.peek().map(|node| node.dist); + let mut candidates = 0; + let mut pruned_upper_bound = 0; + let mut pruned_heap = 0; + let mut exact = 0; + let mut exact_rejected = 0; + + for (id, row_id) in row_ids { + let Some(binary_ip) = dists.get(id).copied() else { + continue; + }; + candidates += 1; + let Some(raw_lower_bound) = self.raw_query_lower_bound(id, binary_ip) else { + continue; + }; + if raw_lower_bound >= query_upper_bound { + pruned_upper_bound += 1; + continue; + } + if res.len() >= k && max_dist.is_some_and(|max_dist| raw_lower_bound >= max_dist.0) { + pruned_heap += 1; + continue; + } + + exact += 1; + let dist 
= self.raw_query_multi_bit_exact_distance( + id, + binary_ip, + ex_bits, + ex_code_len, + ex_codes, + ex_add_factors, + ex_scale_factors, + ); + if dist < query_lower_bound || dist >= query_upper_bound { + exact_rejected += 1; + continue; + } + let dist = OrderedFloat(dist); + if res.len() < k { + res.push(OrderedNode::new(row_id, dist)); + if res.len() == k { + max_dist = res.peek().map(|node| node.dist); + } + } else if max_dist.is_some_and(|max_dist| max_dist > dist) { + res.pop(); + res.push(OrderedNode::new(row_id, dist)); + max_dist = res.peek().map(|node| node.dist); + } + } + record_rabit_prune_stats( + candidates, + pruned_upper_bound, + pruned_heap, + exact, + exact_rejected, + ); + } + + fn raw_query_lower_bound_gating_disabled_reason(&self) -> Option<&'static str> { + if self.query_estimator != RabitQueryEstimator::RawQuery { + Some("residual_query_estimator") + } else if self.num_bits <= 1 { + Some("num_bits_le_one") + } else if self.error_factors.is_none() { + Some("missing_error_factors") + } else { + None + } + } } #[inline] @@ -744,15 +1217,117 @@ fn quantize_dist_table_into(dist_table: &[f32], quantized_dist_table: &mut Vec u8 { debug_assert!(ex_bits > 0); - let mut value = 0u8; let bit_offset = dim_idx * ex_bits as usize; - for bit_idx in 0..ex_bits as usize { - let src_bit = bit_offset + bit_idx; - if (row_codes[src_bit / u8::BITS as usize] >> (src_bit % u8::BITS as usize)) & 1 != 0 { - value |= 1u8 << bit_idx; + let byte_idx = bit_offset / u8::BITS as usize; + let bit_shift = bit_offset % u8::BITS as usize; + let bits = row_codes[byte_idx] as u16 + | row_codes + .get(byte_idx + 1) + .map(|byte| (*byte as u16) << u8::BITS) + .unwrap_or_default(); + let mask = (1u16 << ex_bits) - 1; + ((bits >> bit_shift) & mask) as u8 +} + +fn quantize_ex_fastscan_dist_table_into( + dim: usize, + ex_bits: u8, + ex_dist_table: &[f32], + quantized_dist_table: &mut Vec, +) -> (f32, f32, f32) { + debug_assert!(supports_ex_fastscan(ex_bits)); + + let entries_per_dim 
= 1usize << ex_bits; + debug_assert_eq!(ex_dist_table.len(), dim * entries_per_dim); + let num_split_tables = + ex_fastscan_code_len(dim, ex_bits).expect("RabitQ num_bits should be validated") * 2; + let quantization_max = (u16::MAX as usize / num_split_tables) + .min(u8::MAX as usize) + .max(1) as f32; + + let mut qmin = f32::INFINITY; + let mut qmax = f32::NEG_INFINITY; + for table_idx in 0..num_split_tables { + for code in 0..SEGMENT_NUM_CODES { + let value = ex_fastscan_dist_table_value(dim, ex_bits, ex_dist_table, table_idx, code); + qmin = qmin.min(value); + qmax = qmax.max(value); } } - value + + quantized_dist_table.clear(); + quantized_dist_table.reserve(num_split_tables * SEGMENT_NUM_CODES); + if qmin == qmax { + quantized_dist_table.resize(num_split_tables * SEGMENT_NUM_CODES, 0); + return (qmin, qmax, quantization_max); + } + + let factor = quantization_max / (qmax - qmin); + for table_idx in 0..num_split_tables { + for code in 0..SEGMENT_NUM_CODES { + let value = ex_fastscan_dist_table_value(dim, ex_bits, ex_dist_table, table_idx, code); + quantized_dist_table.push(((value - qmin) * factor).round() as u8); + } + } + + (qmin, qmax, quantization_max) +} + +#[inline] +fn supports_ex_fastscan(ex_bits: u8) -> bool { + matches!(ex_bits, 2 | 4 | 8) +} + +#[inline] +fn ex_fastscan_code_len(dim: usize, ex_bits: u8) -> Option { + match ex_bits { + 2 | 4 | 8 => rabit_ex_code_bytes(dim, ex_bits).ok(), + _ => None, + } +} + +#[inline] +fn ex_fastscan_dist_table_value( + dim: usize, + ex_bits: u8, + ex_dist_table: &[f32], + table_idx: usize, + code: usize, +) -> f32 { + match ex_bits { + 2 => { + let dim_idx = table_idx * 2; + let low = code & 0b11; + let high = (code >> 2) & 0b11; + ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx, low) + + ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx + 1, high) + } + 4 => ex_dist_table_value(ex_dist_table, dim, ex_bits, table_idx, code), + 8 => { + let dim_idx = table_idx / 2; + if table_idx.is_multiple_of(2) { 
+ ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx, code) + } else { + ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx, code << SEGMENT_LENGTH) + } + } + _ => unreachable!("unsupported RabitQ ex_bits={ex_bits} for FastScan"), + } +} + +#[inline] +fn ex_dist_table_value( + ex_dist_table: &[f32], + dim: usize, + ex_bits: u8, + dim_idx: usize, + code: usize, +) -> f32 { + if dim_idx >= dim { + return 0.0; + } + let entries_per_dim = 1usize << ex_bits; + ex_dist_table[dim_idx * entries_per_dim + code] } #[inline] @@ -777,6 +1352,17 @@ fn compute_single_rq_ex_distance( .sum() } +fn maybe_pack_ex_codes( + ex_codes: Option<&FixedSizeListArray>, + ex_bits: u8, +) -> Option { + let ex_codes = ex_codes?; + match ex_bits { + 2 | 4 | 8 => Some(pack_codes(ex_codes)), + _ => None, + } +} + impl DistCalculator for RabitDistCalculator<'_> { #[inline(always)] fn distance(&self, id: u32) -> f32 { @@ -812,17 +1398,15 @@ impl DistCalculator for RabitDistCalculator<'_> { .expect("raw-query multi-bit RQ requires ex scale factors"); let ex_code_len = rabit_ex_code_bytes(self.dim, ex_bits) .expect("RabitQ num_bits should be validated"); - let ex_dist = compute_single_rq_ex_distance( - ex_codes, + self.raw_query_multi_bit_exact_distance( id, - ex_code_len, + dist, ex_bits, - self.dim, - &self.ex_dist_table, - ); - let code_bias = -((1u32 << ex_bits) as f32 - 0.5); - let full_dot = (1u32 << ex_bits) as f32 * dist + ex_dist + code_bias * self.sum_q; - full_dot * ex_scale_factors[id] + ex_add_factors[id] + self.query_factor + ex_code_len, + ex_codes, + ex_add_factors, + ex_scale_factors, + ) } } } @@ -858,73 +1442,187 @@ impl DistCalculator for RabitDistCalculator<'_> { return; } + let simd_len = self.binary_distances_with_scratch( + n, + code_len, + dists, + quantized_dists, + quantized_dists_table, + ); + if self.query_estimator == RabitQueryEstimator::RawQuery && self.num_bits > 1 { - dists.clear(); - dists.reserve(n); - for id in 0..n { - dists.push(self.distance(id as 
u32)); - } - quantized_dists.clear(); - quantized_dists_table.clear(); + self.apply_raw_query_multi_bit_distances( + simd_len, + dists, + quantized_dists, + quantized_dists_table, + ); return; } - let (qmin, qmax) = quantize_dist_table_into(&self.dist_table, quantized_dists_table); - let remainder = n % BATCH_SIZE; - let simd_len = n - remainder; - quantized_dists.clear(); - quantized_dists.reserve(simd_len); - // SAFETY: sum_4bit_dist_table overwrites each element in the SIMD batch range. - unsafe { - quantized_dists.set_len(simd_len); + dists + .iter_mut() + .enumerate() + .for_each(|(id, dist)| match self.query_estimator { + RabitQueryEstimator::ResidualQuery => { + let dist_vq_qr = (2.0 * *dist - self.sum_q) / self.sqrt_d; + *dist = dist_vq_qr * self.scale_factors[id] + + self.add_factors[id] + + self.query_factor; + } + RabitQueryEstimator::RawQuery => { + let binary_dot = *dist - 0.5 * self.sum_q; + *dist = binary_dot * self.scale_factors[id] + + self.add_factors[id] + + self.query_factor; + } + }); + } + + #[allow(clippy::too_many_arguments)] + fn accumulate_topk_with_scratch( + &self, + k: usize, + lower_bound: Option, + upper_bound: Option, + row_id: impl Fn(u32) -> u64, + res: &mut BinaryHeap>, + dists: &mut Vec, + quantized_dists: &mut Vec, + quantized_dists_table: &mut Vec, + ) { + if k == 0 { + return; } - simd::dist_table::sum_4bit_dist_table( - simd_len, - code_len, - self.codes, + if let Some(reason) = self.raw_query_lower_bound_gating_disabled_reason() { + record_rabit_prune_bypass(reason); + self.distance_all_with_scratch(k, dists, quantized_dists, quantized_dists_table); + accumulate_distances_into_heap(k, lower_bound, upper_bound, row_id, res, dists); + return; + } + + let code_len = rabit_binary_code_bytes(self.dim); + let n = self.codes.len() / code_len; + self.accumulate_raw_query_multi_bit_topk_with_scratch( + k, + lower_bound, + upper_bound, + (0..n).map(|id| (id, row_id(id as u32))), + res, + dists, + quantized_dists, 
quantized_dists_table, + ); + } + + #[allow(clippy::too_many_arguments)] + fn accumulate_filtered_topk_with_scratch( + &self, + k: usize, + lower_bound: Option, + upper_bound: Option, + row_ids: impl Iterator, + accept_row: impl Fn(u64) -> bool, + res: &mut BinaryHeap>, + dists: &mut Vec, + quantized_dists: &mut Vec, + quantized_dists_table: &mut Vec, + ) { + if k == 0 { + return; + } + if let Some(reason) = self.raw_query_lower_bound_gating_disabled_reason() { + record_rabit_prune_bypass(reason); + self.distance_all_with_scratch(k, dists, quantized_dists, quantized_dists_table); + accumulate_filtered_distances_into_heap( + k, + lower_bound, + upper_bound, + row_ids, + accept_row, + res, + dists, + ); + return; + } + + self.accumulate_raw_query_multi_bit_topk_with_scratch( + k, + lower_bound, + upper_bound, + row_ids + .filter(|(_, row_id)| accept_row(*row_id)) + .map(|(id, row_id)| (id as usize, row_id)), + res, + dists, quantized_dists, + quantized_dists_table, ); + } +} - let range = (qmax - qmin) / 255.0; - let num_tables = quantized_dists_table.len() / 16; - let sum_min = num_tables as f32 * qmin; - dists.clear(); - dists.reserve(n); - // SAFETY: the SIMD section below writes [0, simd_len), and the - // remainder section writes [simd_len, n). 
- unsafe { - dists.set_len(n); +fn accumulate_distances_into_heap( + k: usize, + lower_bound: Option, + upper_bound: Option, + row_id: impl Fn(u32) -> u64, + res: &mut BinaryHeap>, + dists: &[f32], +) { + let lower_bound = lower_bound.unwrap_or(f32::MIN).into(); + let upper_bound = upper_bound.unwrap_or(f32::MAX).into(); + let mut max_dist = res.peek().map(|node| node.dist); + for (id, dist) in dists.iter().copied().enumerate() { + let dist = OrderedFloat(dist); + if dist < lower_bound || dist >= upper_bound { + continue; } - let (simd_dists, remainder_dists) = dists.split_at_mut(simd_len); - simd_dists - .iter_mut() - .zip(quantized_dists.iter()) - .enumerate() - .for_each(|(id, (dist, q_dist))| { - let dist_vq = (*q_dist as f32) * range + sum_min; - *dist = match self.query_estimator { - RabitQueryEstimator::ResidualQuery => { - let dist_vq_qr = (2.0 * dist_vq - self.sum_q) / self.sqrt_d; - dist_vq_qr * self.scale_factors[id] - + self.add_factors[id] - + self.query_factor - } - RabitQueryEstimator::RawQuery => { - let binary_dot = dist_vq - 0.5 * self.sum_q; - binary_dot * self.scale_factors[id] - + self.add_factors[id] - + self.query_factor - } - }; - }); + if res.len() < k { + res.push(OrderedNode::new(row_id(id as u32), dist)); + if res.len() == k { + max_dist = res.peek().map(|node| node.dist); + } + } else if max_dist.is_some_and(|max_dist| max_dist > dist) { + res.pop(); + res.push(OrderedNode::new(row_id(id as u32), dist)); + max_dist = res.peek().map(|node| node.dist); + } + } +} - remainder_dists - .iter_mut() - .enumerate() - .for_each(|(id, dist)| { - *dist = self.distance((simd_len + id) as u32); - }); +fn accumulate_filtered_distances_into_heap( + k: usize, + lower_bound: Option, + upper_bound: Option, + row_ids: impl Iterator, + accept_row: impl Fn(u64) -> bool, + res: &mut BinaryHeap>, + dists: &[f32], +) { + let lower_bound = lower_bound.unwrap_or(f32::MIN).into(); + let upper_bound = upper_bound.unwrap_or(f32::MAX).into(); + let mut max_dist = 
res.peek().map(|node| node.dist); + for (id, row_id) in row_ids { + if !accept_row(row_id) { + continue; + } + let Some(dist) = dists.get(id as usize).copied() else { + continue; + }; + let dist = OrderedFloat(dist); + if dist < lower_bound || dist >= upper_bound { + continue; + } + if res.len() < k { + res.push(OrderedNode::new(row_id, dist)); + if res.len() == k { + max_dist = res.peek().map(|node| node.dist); + } + } else if max_dist.is_some_and(|max_dist| max_dist > dist) { + res.pop(); + res.push(OrderedNode::new(row_id, dist)); + max_dist = res.peek().map(|node| node.dist); + } } } @@ -975,6 +1673,10 @@ impl VectorStore for RabitQuantizationStorage { RabitQueryEstimator::ResidualQuery => self.residual_query_factor(dist_q_c), RabitQueryEstimator::RawQuery => self.raw_query_factor(dist_q_c, &rotated_qr, None), }; + let query_error = match self.metadata.query_estimator { + RabitQueryEstimator::ResidualQuery => 0.0, + RabitQueryEstimator::RawQuery => self.raw_query_error(dist_q_c, &rotated_qr, None), + }; let sum_q = rotated_qr.into_iter().sum(); self.distance_calculator_from_parts( @@ -983,6 +1685,7 @@ impl VectorStore for RabitQuantizationStorage { Cow::Owned(ex_dist_table), sum_q, query_factor, + query_error, ) } @@ -1008,12 +1711,15 @@ impl VectorStore for RabitQuantizationStorage { debug_assert_eq!(raw_query.ex_bits, self.metadata.num_bits - 1); let query_factor = self.raw_query_factor(dist_q_c, &raw_query.rotated_query, rotated_centroid); + let query_error = + self.raw_query_error(dist_q_c, &raw_query.rotated_query, rotated_centroid); return self.distance_calculator_from_parts( code_dim, Cow::Borrowed(&raw_query.dist_table), Cow::Borrowed(&raw_query.ex_dist_table), raw_query.sum_q, query_factor, + query_error, ); } @@ -1027,6 +1733,7 @@ impl VectorStore for RabitQuantizationStorage { f32_scratch.resize(code_dim + dist_table_len + ex_dist_table_len, 0.0); let query_factor; + let query_error; let sum_q = { let (rotated_qr, remaining) = 
f32_scratch.split_at_mut(code_dim); let (dist_table, ex_dist_table) = remaining.split_at_mut(dist_table_len); @@ -1055,6 +1762,18 @@ impl VectorStore for RabitQuantizationStorage { self.raw_query_factor(dist_q_c, rotated_qr, None) } }; + query_error = match (self.metadata.query_estimator, residual) { + (RabitQueryEstimator::ResidualQuery, _) => 0.0, + ( + RabitQueryEstimator::RawQuery, + Some(QueryResidual::RabitRawQuery { + rotated_centroid, .. + }), + ) => self.raw_query_error(dist_q_c, rotated_qr, rotated_centroid), + (RabitQueryEstimator::RawQuery, _) => { + self.raw_query_error(dist_q_c, rotated_qr, None) + } + }; build_dist_table_direct_into::(rotated_qr, dist_table); build_ex_dist_table_direct_into(rotated_qr, ex_bits, ex_dist_table); rotated_qr.iter().copied().sum() @@ -1069,6 +1788,7 @@ impl VectorStore for RabitQuantizationStorage { ), sum_q, query_factor, + query_error, ) } @@ -1257,6 +1977,9 @@ impl QuantizerStorage for RabitQuantizationStorage { let scale_factors = batch[SCALE_FACTORS_COLUMN] .as_primitive::() .clone(); + let error_factors = batch + .column_by_name(ERROR_FACTORS_COLUMN) + .map(|factors| factors.as_primitive::().clone()); let ex_bits = rabit_ex_bits(metadata.num_bits)?; let mut ex_codes = None; let mut ex_add_factors = None; @@ -1336,6 +2059,7 @@ impl QuantizerStorage for RabitQuantizationStorage { let mut metadata = metadata.clone(); metadata.packed = true; + let packed_ex_codes = maybe_pack_ex_codes(ex_codes.as_ref(), ex_bits); Ok(Self { metadata, @@ -1345,7 +2069,9 @@ impl QuantizerStorage for RabitQuantizationStorage { codes, add_factors, scale_factors, + error_factors, ex_codes, + packed_ex_codes, ex_add_factors, ex_scale_factors, }) @@ -1413,9 +2139,14 @@ impl QuantizerStorage for RabitQuantizationStorage { let scale_factors = batch[SCALE_FACTORS_COLUMN] .as_primitive::() .clone(); + let error_factors = batch + .column_by_name(ERROR_FACTORS_COLUMN) + .map(|factors| factors.as_primitive::().clone()); let ex_codes = batch 
.column_by_name(RABIT_EX_CODE_COLUMN) .map(|codes| codes.as_fixed_size_list().clone()); + let packed_ex_codes = + maybe_pack_ex_codes(ex_codes.as_ref(), rabit_ex_bits(self.metadata.num_bits)?); let ex_add_factors = batch .column_by_name(EX_ADD_FACTORS_COLUMN) .map(|factors| factors.as_primitive::().clone()); @@ -1430,7 +2161,9 @@ impl QuantizerStorage for RabitQuantizationStorage { codes, add_factors, scale_factors, + error_factors, ex_codes, + packed_ex_codes, ex_add_factors, ex_scale_factors, row_ids: new_row_ids, @@ -1541,7 +2274,7 @@ fn get_rq_code( #[cfg(test)] mod tests { use super::*; - use std::collections::HashMap; + use std::collections::{BinaryHeap, HashMap}; use arrow_array::{ArrayRef, Float32Array, Float64Array, UInt64Array}; use lance_core::ROW_ID; @@ -1799,6 +2532,12 @@ mod tests { (0..num_rows).map(|v| v as f32 + 0.5), )) as ArrayRef, ), + ( + ERROR_FACTORS_COLUMN, + Arc::new(Float32Array::from_iter_values( + (0..num_rows).map(|v| v as f32 + 0.25), + )) as ArrayRef, + ), ]) .unwrap() } @@ -1836,6 +2575,12 @@ mod tests { (0..num_rows).map(|v| v as f32 + 0.5), )) as ArrayRef, ), + ( + ERROR_FACTORS_COLUMN, + Arc::new(Float32Array::from_iter_values( + (0..num_rows).map(|v| v as f32 + 0.25), + )) as ArrayRef, + ), (RABIT_EX_CODE_COLUMN, Arc::new(ex_codes) as ArrayRef), ( EX_ADD_FACTORS_COLUMN, @@ -1922,6 +2667,223 @@ mod tests { assert_eq!(distances, vec![104.0, 22.0]); } + fn assert_raw_query_multi_bit_distance_all_uses_fastscan(num_bits: u8) { + let code_dim = 8usize; + let num_rows = BATCH_SIZE + 1; + let ex_bits = rabit_ex_bits(num_bits).unwrap(); + let identity = Float32Array::from_iter_values( + (0..code_dim) + .flat_map(|row| (0..code_dim).map(move |col| if row == col { 1.0 } else { 0.0 })), + ); + let rotate_mat = + FixedSizeListArray::try_new_from_values(identity, code_dim as i32).unwrap(); + let metadata = RabitQuantizationMetadata { + rotate_mat: Some(rotate_mat), + rotate_mat_position: None, + fast_rotation_signs: None, + rotation_type: 
RQRotationType::Matrix, + code_dim: code_dim as u32, + num_bits, + packed: false, + query_estimator: RabitQueryEstimator::RawQuery, + }; + let codes = FixedSizeListArray::try_new_from_values( + UInt8Array::from_iter_values((0..num_rows).map(|idx| (idx * 13) as u8)), + 1, + ) + .unwrap(); + let ex_code_len = rabit_ex_code_bytes(code_dim, ex_bits).unwrap(); + let ex_codes = FixedSizeListArray::try_new_from_values( + UInt8Array::from_iter_values( + (0..num_rows * ex_code_len).map(|idx| (idx * 37 % 251) as u8), + ), + ex_code_len as i32, + ) + .unwrap(); + let batch = RecordBatch::try_from_iter(vec![ + ( + ROW_ID, + Arc::new(UInt64Array::from_iter_values(0..num_rows as u64)) as ArrayRef, + ), + (RABIT_CODE_COLUMN, Arc::new(codes) as ArrayRef), + ( + ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0; num_rows])) as ArrayRef, + ), + ( + SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![1.0; num_rows])) as ArrayRef, + ), + (RABIT_EX_CODE_COLUMN, Arc::new(ex_codes) as ArrayRef), + ( + EX_ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0; num_rows])) as ArrayRef, + ), + ( + EX_SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![1.0; num_rows])) as ArrayRef, + ), + ]) + .unwrap(); + let storage = + RabitQuantizationStorage::try_from_batch(batch, &metadata, DistanceType::L2, None) + .unwrap(); + assert!(storage.packed_ex_codes.is_some()); + + let query = Arc::new(Float32Array::from(vec![1.0; code_dim])) as ArrayRef; + let calc = storage.dist_calculator(query, 0.0); + let mut distances = Vec::new(); + let mut u16_scratch = Vec::new(); + let mut u8_scratch = Vec::new(); + calc.distance_all_with_scratch(0, &mut distances, &mut u16_scratch, &mut u8_scratch); + + assert_eq!(distances.len(), num_rows); + assert_eq!(u16_scratch.len(), BATCH_SIZE); + assert_eq!( + u8_scratch.len(), + ex_fastscan_code_len(code_dim, ex_bits).unwrap() * 2 * SEGMENT_NUM_CODES + ); + for (id, distance) in distances.iter().take(BATCH_SIZE).enumerate() { + let exact = 
calc.distance(id as u32); + assert!( + (*distance - exact).abs() < 10.0, + "distance_all fastscan mismatch for id {id}: actual={distance}, exact={exact}" + ); + } + assert_eq!(distances[BATCH_SIZE], calc.distance(BATCH_SIZE as u32)); + } + + #[test] + fn test_raw_query_multi_bit_distance_all_uses_fastscan_for_split_ex_codes() { + for num_bits in [3, 9] { + assert_raw_query_multi_bit_distance_all_uses_fastscan(num_bits); + } + } + + #[test] + fn test_raw_query_multi_bit_accumulate_topk_uses_lower_bound_gating() { + let code_dim = 8usize; + let num_rows = BATCH_SIZE + 9; + let num_bits = 3; + let ex_bits = rabit_ex_bits(num_bits).unwrap(); + let identity = Float32Array::from_iter_values( + (0..code_dim) + .flat_map(|row| (0..code_dim).map(move |col| if row == col { 1.0 } else { 0.0 })), + ); + let rotate_mat = + FixedSizeListArray::try_new_from_values(identity, code_dim as i32).unwrap(); + let metadata = RabitQuantizationMetadata { + rotate_mat: Some(rotate_mat), + rotate_mat_position: None, + fast_rotation_signs: None, + rotation_type: RQRotationType::Matrix, + code_dim: code_dim as u32, + num_bits, + packed: false, + query_estimator: RabitQueryEstimator::RawQuery, + }; + let codes = FixedSizeListArray::try_new_from_values( + UInt8Array::from_iter_values((0..num_rows).map(|idx| (idx * 19) as u8)), + 1, + ) + .unwrap(); + let ex_code_len = rabit_ex_code_bytes(code_dim, ex_bits).unwrap(); + let ex_codes = FixedSizeListArray::try_new_from_values( + UInt8Array::from_iter_values( + (0..num_rows * ex_code_len).map(|idx| (idx * 29 % 251) as u8), + ), + ex_code_len as i32, + ) + .unwrap(); + let batch = make_test_batch_with_ex(codes, ex_codes) + .replace_column_by_name( + ERROR_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![1000.0; num_rows])), + ) + .unwrap(); + let storage = + RabitQuantizationStorage::try_from_batch(batch, &metadata, DistanceType::L2, None) + .unwrap(); + let query = Arc::new(Float32Array::from(vec![1.0; code_dim])) as ArrayRef; + let calc = 
storage.dist_calculator(query, 4.0); + assert!( + calc.raw_query_lower_bound_gating_disabled_reason() + .is_none() + ); + + let k = 5; + let mut binary_ips = Vec::new(); + let mut binary_u16_scratch = Vec::new(); + let mut binary_u8_scratch = Vec::new(); + calc.binary_distances_with_scratch( + num_rows, + rabit_binary_code_bytes(code_dim), + &mut binary_ips, + &mut binary_u16_scratch, + &mut binary_u8_scratch, + ); + let ex_codes = calc.ex_codes.unwrap(); + let ex_add_factors = calc.ex_add_factors.unwrap(); + let ex_scale_factors = calc.ex_scale_factors.unwrap(); + let mut expected = binary_ips + .iter() + .copied() + .enumerate() + .map(|(id, binary_ip)| { + ( + id, + calc.raw_query_multi_bit_exact_distance( + id, + binary_ip, + ex_bits, + ex_code_len, + ex_codes, + ex_add_factors, + ex_scale_factors, + ), + ) + }) + .collect::>(); + expected.sort_by(|left, right| left.1.total_cmp(&right.1)); + expected.truncate(k); + let mut expected = expected + .into_iter() + .map(|(id, dist)| (id as u64, dist)) + .collect::>(); + expected.sort_by(|left, right| left.0.cmp(&right.0)); + + let mut heap = BinaryHeap::with_capacity(k); + let mut distances = Vec::new(); + let mut u16_scratch = Vec::new(); + let mut u8_scratch = Vec::new(); + calc.accumulate_topk_with_scratch( + k, + None, + None, + |id| id as u64, + &mut heap, + &mut distances, + &mut u16_scratch, + &mut u8_scratch, + ); + let mut actual = heap + .into_iter() + .map(|node| (node.id, node.dist.0)) + .collect::>(); + actual.sort_by(|left, right| left.0.cmp(&right.0)); + + assert_eq!(actual.len(), expected.len()); + for ((actual_id, actual_dist), (expected_id, expected_dist)) in + actual.into_iter().zip(expected) + { + assert_eq!(actual_id, expected_id); + assert!( + (actual_dist - expected_dist).abs() < 1e-5, + "actual={actual_dist}, expected={expected_dist}" + ); + } + } + #[test] fn test_raw_query_one_bit_distance_uses_binary_factors_without_ex_columns() { let code_dim = 8usize; @@ -2106,6 +3068,31 @@ mod tests { 
.value_length(), 64 ); + assert!(stored_batch.column_by_name(ERROR_FACTORS_COLUMN).is_some()); + } + + #[test] + fn test_try_from_batch_accepts_missing_error_factors_for_compatibility() { + let original_codes = make_test_codes(50, 64); + let code_dim = original_codes.value_length() as usize * 8; + let ex_codes = make_test_ex_codes(original_codes.len(), code_dim, 9); + let mut metadata = make_test_metadata(code_dim); + metadata.num_bits = 9; + let batch = make_test_batch_with_ex(original_codes, ex_codes) + .drop_column(ERROR_FACTORS_COLUMN) + .unwrap(); + + let storage = + RabitQuantizationStorage::try_from_batch(batch, &metadata, DistanceType::L2, None) + .unwrap(); + let query = Arc::new(Float32Array::from(vec![1.0; code_dim])) as ArrayRef; + let calc = storage.dist_calculator(query, 4.0); + + assert!(storage.error_factors.is_none()); + assert_eq!( + calc.raw_query_lower_bound_gating_disabled_reason(), + Some("missing_error_factors") + ); } #[test] @@ -2190,5 +3177,11 @@ mod tests { .values()[..5], &[1.5, 2.5, 3.5, 5.5, 6.5] ); + assert_eq!( + &remapped_batch[ERROR_FACTORS_COLUMN] + .as_primitive::() + .values()[..5], + &[0.25, 1.25, 2.25, 4.25, 5.25] + ); } } diff --git a/rust/lance-index/src/vector/distributed/index_merger.rs b/rust/lance-index/src/vector/distributed/index_merger.rs index e003cf52599..4f59c83bfcf 100755 --- a/rust/lance-index/src/vector/distributed/index_merger.rs +++ b/rust/lance-index/src/vector/distributed/index_merger.rs @@ -20,11 +20,12 @@ use std::sync::Arc; use crate::IndexMetadata as IndexMetaSchema; use crate::pb; use crate::vector::bq::storage::{ - RABIT_CODE_COLUMN, RABIT_METADATA_KEY, RabitQuantizationMetadata, pack_codes, - rabit_binary_code_field, rabit_ex_code_field, + RABIT_CODE_COLUMN, RABIT_METADATA_KEY, RabitQuantizationMetadata, RabitQueryEstimator, + pack_codes, rabit_binary_code_field, rabit_ex_code_field, }; use crate::vector::bq::transform::{ - ADD_FACTORS_FIELD, EX_ADD_FACTORS_FIELD, EX_SCALE_FACTORS_FIELD, 
SCALE_FACTORS_FIELD, + ADD_FACTORS_FIELD, ERROR_FACTORS_FIELD, EX_ADD_FACTORS_FIELD, EX_SCALE_FACTORS_FIELD, + SCALE_FACTORS_FIELD, }; use crate::vector::bq::validate_rq_num_bits; use crate::vector::flat::index::FlatMetadata; @@ -307,6 +308,9 @@ pub async fn init_writer_for_rq( ADD_FACTORS_FIELD.clone(), SCALE_FACTORS_FIELD.clone(), ]; + if rq_meta.query_estimator == RabitQueryEstimator::RawQuery { + fields.push(ERROR_FACTORS_FIELD.clone()); + } if let Some(ex_code_field) = rabit_ex_code_field(rq_meta.rotated_dim(), rq_meta.num_bits)? { fields.push(ex_code_field); fields.push(EX_ADD_FACTORS_FIELD.clone()); @@ -2083,6 +2087,9 @@ mod tests { ADD_FACTORS_FIELD.clone(), SCALE_FACTORS_FIELD.clone(), ]; + if metadata.query_estimator == RabitQueryEstimator::RawQuery { + fields.push(ERROR_FACTORS_FIELD.clone()); + } if let Some(field) = ex_code_field { fields.push(field); fields.push(EX_ADD_FACTORS_FIELD.clone()); @@ -2117,6 +2124,7 @@ mod tests { let mut codes = Vec::with_capacity(total_rows * num_bytes); let mut add_factors = Vec::with_capacity(total_rows); let mut scale_factors = Vec::with_capacity(total_rows); + let mut error_factors = Vec::with_capacity(total_rows); let mut ex_codes = ex_code_bytes.map(|num_bytes| Vec::with_capacity(total_rows * num_bytes)); let mut ex_add_factors = Vec::with_capacity(total_rows); @@ -2132,6 +2140,7 @@ mod tests { } add_factors.push(pid as f32 + row_offset as f32 * 0.1); scale_factors.push(pid as f32 + row_offset as f32 * 0.2); + error_factors.push(pid as f32 + row_offset as f32 * 0.3); if let (Some(ex_codes), Some(ex_code_bytes)) = (ex_codes.as_mut(), ex_code_bytes) { for b in 0..ex_code_bytes { ex_codes.push((17 + pid + row_offset + b) as u8); @@ -2151,6 +2160,9 @@ mod tests { Arc::new(Float32Array::from(add_factors)), Arc::new(Float32Array::from(scale_factors)), ]; + if metadata.query_estimator == RabitQueryEstimator::RawQuery { + columns.push(Arc::new(Float32Array::from(error_factors))); + } if let (Some(ex_codes), 
Some(ex_code_bytes)) = (ex_codes, ex_code_bytes) { columns.push(Arc::new(FixedSizeListArray::try_new_from_values( UInt8Array::from(ex_codes), @@ -2530,6 +2542,7 @@ mod tests { panic!("RQ ex-code field should be FixedSizeList"); }; assert_eq!(*ex_code_bytes, 6); + assert!(schema.field_with_name(ERROR_FACTORS_FIELD.name()).is_ok()); assert!(schema.field_with_name(EX_ADD_FACTORS_COLUMN).is_ok()); assert!(schema.field_with_name(EX_SCALE_FACTORS_COLUMN).is_ok()); checked_split_columns = true; From fbd2e71b5fdbdad24a1f262faca6100b047e1c2b Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 4 Jun 2026 18:50:17 +0800 Subject: [PATCH 03/14] perf(index): add ivf rq multi-bit raw-query fastscan --- rust/lance-linalg/src/simd/dist_table.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/rust/lance-linalg/src/simd/dist_table.rs b/rust/lance-linalg/src/simd/dist_table.rs index bfc05fc2f26..646adc6bdcd 100644 --- a/rust/lance-linalg/src/simd/dist_table.rs +++ b/rust/lance-linalg/src/simd/dist_table.rs @@ -10,7 +10,7 @@ use std::arch::x86_64::*; use lance_core::utils::cpu::{SIMD_SUPPORT, SimdSupport}; pub const PERM0: [usize; 16] = [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15]; -pub const PERM0_INVERSE: [usize; 16] = [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]; +pub const PERM0_INVERSE: [usize; 16] = [0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15]; pub const BATCH_SIZE: usize = 32; // This function is used to sum the distance table for 4-bit codes. 
@@ -270,6 +270,13 @@ unsafe extern "C" { mod tests { use super::*; + #[test] + fn test_perm0_inverse_matches_perm0() { + for (idx, &value) in PERM0.iter().enumerate() { + assert_eq!(PERM0_INVERSE[value], idx); + } + } + #[test] fn test_sum_4bit_dist_table_basic() { // we have 32 vectors From cdc3375935a664e74edf53f7ab0c69b6870fe11f Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 4 Jun 2026 22:59:56 +0800 Subject: [PATCH 04/14] perf(index): gate ivf rq ex-code boosting --- rust/lance-index/src/vector/flat/index.rs | 25 +++----------- rust/lance-index/src/vector/storage.rs | 40 +++++++++++++++++++++++ 2 files changed, 45 insertions(+), 20 deletions(-) diff --git a/rust/lance-index/src/vector/flat/index.rs b/rust/lance-index/src/vector/flat/index.rs index 6ebab23688f..1fb2f9bdf3a 100644 --- a/rust/lance-index/src/vector/flat/index.rs +++ b/rust/lance-index/src/vector/flat/index.rs @@ -284,31 +284,16 @@ impl IvfSubIndex for FlatIndex { match prefilter.is_empty() { true => { - dist_calc.distance_all_with_scratch( + dist_calc.accumulate_topk_with_scratch( k, + params.lower_bound, + params.upper_bound, + |id| storage.row_id(id), + res, &mut scratch.distances, &mut scratch.u16, &mut scratch.u8, ); - let dists = scratch.distances.iter().copied(); - - if is_range_query { - let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into(); - let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into(); - - for (&row_id, dist) in row_ids.zip(dists) { - let dist = dist.into(); - if dist < lower_bound || dist >= upper_bound { - continue; - } - push_candidate_global(res, k, row_id, dist, &mut max_dist); - } - } else { - for (&row_id, dist) in row_ids.zip(dists) { - let dist = dist.into(); - push_candidate_global(res, k, row_id, dist, &mut max_dist); - } - } } false => { let row_addr_mask = prefilter.mask(); diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index 36974180d41..8e07f45f4d4 100644 --- 
a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -18,6 +18,7 @@ use lance_linalg::distance::DistanceType; use prost::Message; use std::{ any::Any, + collections::BinaryHeap, mem::size_of, ops::{Deref, DerefMut}, sync::Arc, @@ -63,6 +64,45 @@ pub trait DistCalculator { } fn prefetch(&self, _id: u32) {} + + #[allow(clippy::too_many_arguments)] + fn accumulate_topk_with_scratch( + &self, + k: usize, + lower_bound: Option, + upper_bound: Option, + row_id: impl Fn(u32) -> u64, + res: &mut BinaryHeap>, + dists: &mut Vec, + u16_scratch: &mut Vec, + u8_scratch: &mut Vec, + ) { + if k == 0 { + return; + } + + self.distance_all_with_scratch(k, dists, u16_scratch, u8_scratch); + let lower_bound = lower_bound.unwrap_or(f32::MIN).into(); + let upper_bound = upper_bound.unwrap_or(f32::MAX).into(); + let mut max_dist = res.peek().map(|node| node.dist); + + for (id, dist) in dists.iter().copied().enumerate() { + let dist = OrderedFloat(dist); + if dist < lower_bound || dist >= upper_bound { + continue; + } + if res.len() < k { + res.push(OrderedNode::new(row_id(id as u32), dist)); + if res.len() == k { + max_dist = res.peek().map(|node| node.dist); + } + } else if max_dist.is_some_and(|max_dist| max_dist > dist) { + res.pop(); + res.push(OrderedNode::new(row_id(id as u32), dist)); + max_dist = res.peek().map(|node| node.dist); + } + } + } } pub const STORAGE_METADATA_KEY: &str = "storage_metadata"; From 761941e069f178ed73083c719a2b47a8fb52436b Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Fri, 5 Jun 2026 00:52:10 +0800 Subject: [PATCH 05/14] perf(index): apply ivf rq gating to prefilter --- rust/lance-index/src/vector/flat/index.rs | 59 +++++------------------ rust/lance-index/src/vector/storage.rs | 42 ++++++++++++++++ 2 files changed, 53 insertions(+), 48 deletions(-) diff --git a/rust/lance-index/src/vector/flat/index.rs b/rust/lance-index/src/vector/flat/index.rs index 1fb2f9bdf3a..19f302027c1 100644 --- 
a/rust/lance-index/src/vector/flat/index.rs +++ b/rust/lance-index/src/vector/flat/index.rs @@ -48,29 +48,6 @@ fn push_candidate_local( } } -#[inline(always)] -fn push_candidate_global( - res: &mut BinaryHeap>, - k: usize, - row_id: u64, - dist: OrderedFloat, - max_dist: &mut Option, -) { - if k == 0 { - return; - } - if res.len() < k { - res.push(OrderedNode::new(row_id, dist)); - if res.len() == k { - *max_dist = res.peek().map(|node| node.dist); - } - } else if max_dist.is_some_and(|max_dist| max_dist > dist) { - res.pop(); - res.push(OrderedNode::new(row_id, dist)); - *max_dist = res.peek().map(|node| node.dist); - } -} - /// A Flat index is any index that stores no metadata, and /// during query, it simply scans over the storage and returns the top k results #[derive(Debug, Clone, Default, DeepSizeOf)] @@ -271,7 +248,6 @@ impl IvfSubIndex for FlatIndex { scratch: &mut QueryScratch, metrics: &dyn MetricsCollector, ) -> Result<()> { - let is_range_query = params.lower_bound.is_some() || params.upper_bound.is_some(); let row_ids = storage.row_ids(); let dist_calc = storage.dist_calculator_with_scratch( query, @@ -279,7 +255,6 @@ impl IvfSubIndex for FlatIndex { residual, &mut scratch.query_f32, ); - let mut max_dist = res.peek().map(|node| node.dist); metrics.record_comparisons(storage.len()); match prefilter.is_empty() { @@ -297,29 +272,17 @@ impl IvfSubIndex for FlatIndex { } false => { let row_addr_mask = prefilter.mask(); - if is_range_query { - let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into(); - let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into(); - for (id, &row_addr) in row_ids.enumerate() { - if !row_addr_mask.selected(row_addr) { - continue; - } - let dist = dist_calc.distance(id as u32).into(); - if dist < lower_bound || dist >= upper_bound { - continue; - } - - push_candidate_global(res, k, row_addr, dist, &mut max_dist); - } - } else { - for (id, &row_addr) in row_ids.enumerate() { - if !row_addr_mask.selected(row_addr) { - 
continue; - } - let dist = dist_calc.distance(id as u32).into(); - push_candidate_global(res, k, row_addr, dist, &mut max_dist); - } - } + dist_calc.accumulate_filtered_topk_with_scratch( + k, + params.lower_bound, + params.upper_bound, + row_ids.enumerate().map(|(id, &row_id)| (id as u32, row_id)), + |row_id| row_addr_mask.selected(row_id), + res, + &mut scratch.distances, + &mut scratch.u16, + &mut scratch.u8, + ); } }; Ok(()) diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index 8e07f45f4d4..8caac81e00d 100644 --- a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -103,6 +103,48 @@ pub trait DistCalculator { } } } + + #[allow(clippy::too_many_arguments)] + fn accumulate_filtered_topk_with_scratch( + &self, + k: usize, + lower_bound: Option, + upper_bound: Option, + row_ids: impl Iterator, + accept_row: impl Fn(u64) -> bool, + res: &mut BinaryHeap>, + _dists: &mut Vec, + _u16_scratch: &mut Vec, + _u8_scratch: &mut Vec, + ) { + if k == 0 { + return; + } + + let lower_bound = lower_bound.unwrap_or(f32::MIN).into(); + let upper_bound = upper_bound.unwrap_or(f32::MAX).into(); + let mut max_dist = res.peek().map(|node| node.dist); + + for (id, row_id) in row_ids { + if !accept_row(row_id) { + continue; + } + let dist = OrderedFloat(self.distance(id)); + if dist < lower_bound || dist >= upper_bound { + continue; + } + if res.len() < k { + res.push(OrderedNode::new(row_id, dist)); + if res.len() == k { + max_dist = res.peek().map(|node| node.dist); + } + } else if max_dist.is_some_and(|max_dist| max_dist > dist) { + res.pop(); + res.push(OrderedNode::new(row_id, dist)); + max_dist = res.peek().map(|node| node.dist); + } + } + } } pub const STORAGE_METADATA_KEY: &str = "storage_metadata"; From 325d7ba4fe1710d80aed7c9a56b012a0127ef253 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Fri, 5 Jun 2026 01:24:43 +0800 Subject: [PATCH 06/14] fix(index): preserve ivf rq error factors --- 
rust/lance-index/src/vector/bq/builder.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/lance-index/src/vector/bq/builder.rs b/rust/lance-index/src/vector/bq/builder.rs index ea3d9ed2667..d7815352d58 100644 --- a/rust/lance-index/src/vector/bq/builder.rs +++ b/rust/lance-index/src/vector/bq/builder.rs @@ -17,8 +17,8 @@ use rand_distr::Distribution; use rayon::prelude::*; use crate::vector::bq::storage::{ - RABIT_CODE_COLUMN, RABIT_METADATA_KEY, RabitQuantizationMetadata, RabitQuantizationStorage, - RabitQueryEstimator, rabit_binary_code_field, rabit_ex_code_field, + RABIT_CODE_COLUMN, RABIT_EX_CODE_COLUMN, RABIT_METADATA_KEY, RabitQuantizationMetadata, + RabitQuantizationStorage, RabitQueryEstimator, rabit_binary_code_field, rabit_ex_code_field, }; use crate::vector::bq::transform::{ ADD_FACTORS_FIELD, ERROR_FACTORS_FIELD, EX_ADD_FACTORS_FIELD, EX_SCALE_FACTORS_FIELD, From a167031b84436e455e7fad009068c5d4d6abe949 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Fri, 5 Jun 2026 01:46:37 +0800 Subject: [PATCH 07/14] chore(index): clean ivf rq test import --- rust/lance-index/src/vector/bq/builder.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/lance-index/src/vector/bq/builder.rs b/rust/lance-index/src/vector/bq/builder.rs index d7815352d58..ea3d9ed2667 100644 --- a/rust/lance-index/src/vector/bq/builder.rs +++ b/rust/lance-index/src/vector/bq/builder.rs @@ -17,8 +17,8 @@ use rand_distr::Distribution; use rayon::prelude::*; use crate::vector::bq::storage::{ - RABIT_CODE_COLUMN, RABIT_EX_CODE_COLUMN, RABIT_METADATA_KEY, RabitQuantizationMetadata, - RabitQuantizationStorage, RabitQueryEstimator, rabit_binary_code_field, rabit_ex_code_field, + RABIT_CODE_COLUMN, RABIT_METADATA_KEY, RabitQuantizationMetadata, RabitQuantizationStorage, + RabitQueryEstimator, rabit_binary_code_field, rabit_ex_code_field, }; use crate::vector::bq::transform::{ ADD_FACTORS_FIELD, ERROR_FACTORS_FIELD, EX_ADD_FACTORS_FIELD, 
EX_SCALE_FACTORS_FIELD, From 763ecc09c53fba9ce9b085aae67e6b68f378b90b Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Mon, 8 Jun 2026 13:36:31 +0800 Subject: [PATCH 08/14] perf(index): optimize ivf rq raw-query search --- rust/lance-index/src/vector/bq/storage.rs | 188 +++++++++++++++--- rust/lance/src/index/vector.rs | 10 +- .../src/index/vector/ivf/partition_serde.rs | 5 +- rust/lance/src/index/vector/ivf/v2.rs | 173 +++++++++++----- 4 files changed, 285 insertions(+), 91 deletions(-) diff --git a/rust/lance-index/src/vector/bq/storage.rs b/rust/lance-index/src/vector/bq/storage.rs index 460552fb1bb..50346d65510 100644 --- a/rust/lance-index/src/vector/bq/storage.rs +++ b/rust/lance-index/src/vector/bq/storage.rs @@ -531,6 +531,25 @@ impl RabitQuantizationStorage { } } + fn uses_raw_query_lower_bound_gating(&self) -> bool { + self.metadata.query_estimator == RabitQueryEstimator::RawQuery + && self.metadata.num_bits > 1 + && self.error_factors.is_some() + } + + fn raw_query_error_for_gating( + &self, + dist_q_c: f32, + rotated_query: &[f32], + rotated_centroid: Option<&[f32]>, + ) -> f32 { + if self.uses_raw_query_lower_bound_gating() { + self.raw_query_error(dist_q_c, rotated_query, rotated_centroid) + } else { + 0.0 + } + } + fn distance_calculator_from_parts<'a>( &'a self, dim: usize, @@ -862,6 +881,79 @@ impl<'a> RabitDistCalculator<'a> { simd_len } + #[inline] + fn binary_distance_factor_params(&self) -> (f32, f32) { + match self.query_estimator { + RabitQueryEstimator::ResidualQuery => (2.0 / self.sqrt_d, -self.sum_q / self.sqrt_d), + RabitQueryEstimator::RawQuery => (1.0, -0.5 * self.sum_q), + } + } + + #[allow(clippy::uninit_vec)] + fn one_bit_distances_with_scratch( + &self, + n: usize, + code_len: usize, + dists: &mut Vec, + quantized_dists: &mut Vec, + quantized_dists_table: &mut Vec, + ) { + let (qmin, qmax) = quantize_dist_table_into(&self.dist_table, quantized_dists_table); + let remainder = n % BATCH_SIZE; + let simd_len = n - remainder; + 
quantized_dists.clear(); + quantized_dists.reserve(simd_len); + // SAFETY: sum_4bit_dist_table overwrites each element in the SIMD batch range. + unsafe { + quantized_dists.set_len(simd_len); + } + simd::dist_table::sum_4bit_dist_table( + simd_len, + code_len, + self.codes, + quantized_dists_table, + quantized_dists, + ); + + let range = (qmax - qmin) / 255.0; + let num_tables = quantized_dists_table.len() / SEGMENT_NUM_CODES; + let sum_min = num_tables as f32 * qmin; + let (binary_distance_multiplier, binary_distance_offset) = + self.binary_distance_factor_params(); + dists.clear(); + dists.reserve(n); + // SAFETY: the SIMD section below writes [0, simd_len), and the + // remainder section writes [simd_len, n). + unsafe { + dists.set_len(n); + } + let (simd_dists, remainder_dists) = dists.split_at_mut(simd_len); + simd_dists + .iter_mut() + .zip(quantized_dists.iter()) + .enumerate() + .for_each(|(id, (dist, q_dist))| { + let binary_dist = (*q_dist as f32) * range + sum_min; + *dist = (binary_dist * binary_distance_multiplier + binary_distance_offset) + * self.scale_factors[id] + + self.add_factors[id] + + self.query_factor; + }); + + remainder_dists + .iter_mut() + .enumerate() + .for_each(|(offset, dist)| { + let id = simd_len + offset; + let binary_dist = + compute_single_rq_distance(self.codes, id, n, code_len, &self.dist_table); + *dist = (binary_dist * binary_distance_multiplier + binary_distance_offset) + * self.scale_factors[id] + + self.add_factors[id] + + self.query_factor; + }); + } + #[allow(clippy::uninit_vec)] fn apply_raw_query_multi_bit_distances( &self, @@ -1442,6 +1534,17 @@ impl DistCalculator for RabitDistCalculator<'_> { return; } + if self.query_estimator == RabitQueryEstimator::ResidualQuery || self.num_bits == 1 { + self.one_bit_distances_with_scratch( + n, + code_len, + dists, + quantized_dists, + quantized_dists_table, + ); + return; + } + let simd_len = self.binary_distances_with_scratch( n, code_len, @@ -1450,33 +1553,12 @@ impl 
DistCalculator for RabitDistCalculator<'_> { quantized_dists_table, ); - if self.query_estimator == RabitQueryEstimator::RawQuery && self.num_bits > 1 { - self.apply_raw_query_multi_bit_distances( - simd_len, - dists, - quantized_dists, - quantized_dists_table, - ); - return; - } - - dists - .iter_mut() - .enumerate() - .for_each(|(id, dist)| match self.query_estimator { - RabitQueryEstimator::ResidualQuery => { - let dist_vq_qr = (2.0 * *dist - self.sum_q) / self.sqrt_d; - *dist = dist_vq_qr * self.scale_factors[id] - + self.add_factors[id] - + self.query_factor; - } - RabitQueryEstimator::RawQuery => { - let binary_dot = *dist - 0.5 * self.sum_q; - *dist = binary_dot * self.scale_factors[id] - + self.add_factors[id] - + self.query_factor; - } - }); + self.apply_raw_query_multi_bit_distances( + simd_len, + dists, + quantized_dists, + quantized_dists_table, + ); } #[allow(clippy::too_many_arguments)] @@ -1675,7 +1757,9 @@ impl VectorStore for RabitQuantizationStorage { }; let query_error = match self.metadata.query_estimator { RabitQueryEstimator::ResidualQuery => 0.0, - RabitQueryEstimator::RawQuery => self.raw_query_error(dist_q_c, &rotated_qr, None), + RabitQueryEstimator::RawQuery => { + self.raw_query_error_for_gating(dist_q_c, &rotated_qr, None) + } }; let sum_q = rotated_qr.into_iter().sum(); @@ -1711,8 +1795,11 @@ impl VectorStore for RabitQuantizationStorage { debug_assert_eq!(raw_query.ex_bits, self.metadata.num_bits - 1); let query_factor = self.raw_query_factor(dist_q_c, &raw_query.rotated_query, rotated_centroid); - let query_error = - self.raw_query_error(dist_q_c, &raw_query.rotated_query, rotated_centroid); + let query_error = self.raw_query_error_for_gating( + dist_q_c, + &raw_query.rotated_query, + rotated_centroid, + ); return self.distance_calculator_from_parts( code_dim, Cow::Borrowed(&raw_query.dist_table), @@ -1769,9 +1856,9 @@ impl VectorStore for RabitQuantizationStorage { Some(QueryResidual::RabitRawQuery { rotated_centroid, .. 
}), - ) => self.raw_query_error(dist_q_c, rotated_qr, rotated_centroid), + ) => self.raw_query_error_for_gating(dist_q_c, rotated_qr, rotated_centroid), (RabitQueryEstimator::RawQuery, _) => { - self.raw_query_error(dist_q_c, rotated_qr, None) + self.raw_query_error_for_gating(dist_q_c, rotated_qr, None) } }; build_dist_table_direct_into::(rotated_qr, dist_table); @@ -1958,6 +2045,10 @@ impl QuantizerStorage for RabitQuantizationStorage { distance_type: DistanceType, _fri: Option>, ) -> Result { + let distance_type = match (metadata.query_estimator, distance_type) { + (RabitQueryEstimator::RawQuery, DistanceType::Cosine) => DistanceType::L2, + _ => distance_type, + }; validate_rq_num_bits(metadata.num_bits)?; let row_ids = batch[ROW_ID].as_primitive::().clone(); let codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list().clone(); @@ -3003,6 +3094,39 @@ mod tests { assert_codes_eq(stored_codes, &expected_codes); } + #[test] + fn test_try_from_batch_uses_l2_for_cosine() { + let original_codes = make_test_codes(50, 64); + let metadata = make_test_metadata(original_codes.value_length() as usize * 8); + + let storage = RabitQuantizationStorage::try_from_batch( + make_test_batch(original_codes), + &metadata, + DistanceType::Cosine, + None, + ) + .unwrap(); + + assert_eq!(storage.distance_type(), DistanceType::L2); + } + + #[test] + fn test_try_from_batch_keeps_cosine_for_legacy_residual_query() { + let original_codes = make_test_codes(50, 64); + let mut metadata = make_test_metadata(original_codes.value_length() as usize * 8); + metadata.query_estimator = RabitQueryEstimator::ResidualQuery; + + let storage = RabitQuantizationStorage::try_from_batch( + make_test_batch(original_codes), + &metadata, + DistanceType::Cosine, + None, + ) + .unwrap(); + + assert_eq!(storage.distance_type(), DistanceType::Cosine); + } + #[test] fn test_try_from_batch_requires_ex_columns_for_multi_bit_rq() { let original_codes = make_test_codes(50, 64); diff --git a/rust/lance/src/index/vector.rs 
b/rust/lance/src/index/vector.rs index c7d405ec7b1..ba4e1f77e91 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -29,9 +29,7 @@ use lance_index::metrics::NoOpMetricsCollector; use lance_index::optimize::OptimizeOptions; use lance_index::progress::{IndexBuildProgress, noop_progress}; use lance_index::vector::bq::builder::RabitQuantizer; -use lance_index::vector::bq::{ - RABIT_BINARY_NUM_BITS, RQBuildParams, RQRotationType, validate_supported_rq_num_bits, -}; +use lance_index::vector::bq::{RQBuildParams, RQRotationType, validate_supported_rq_num_bits}; use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; use lance_index::vector::hnsw::HNSW; use lance_index::vector::ivf::builder::recommended_num_partitions; @@ -555,12 +553,6 @@ async fn prepare_vector_segment_build( ))); }; validate_supported_rq_num_bits(rq_params.num_bits)?; - if rq_params.num_bits > RABIT_BINARY_NUM_BITS && params.metric_type == DistanceType::Cosine - { - return Err(Error::not_supported( - "IVF_RQ num_bits>1 cosine index creation is not supported until raw-query cosine search support is implemented", - )); - } } let num_rows = dataset.count_rows(None).await?; diff --git a/rust/lance/src/index/vector/ivf/partition_serde.rs b/rust/lance/src/index/vector/ivf/partition_serde.rs index 7ae77abaa06..2cf3719e12a 100644 --- a/rust/lance/src/index/vector/ivf/partition_serde.rs +++ b/rust/lance/src/index/vector/ivf/partition_serde.rs @@ -1098,7 +1098,9 @@ mod tests { #[test] fn test_ivf_index_state_roundtrip() { - use crate::index::vector::ivf::v2::{IvfIndexState, IvfStateEntryBox}; + use crate::index::vector::ivf::v2::{ + IvfIndexState, IvfStateEntryBox, empty_rabit_search_cache_cell, + }; use lance_index::vector::flat::index::FlatQuantizer; use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::quantizer::QuantizationType; @@ -1123,6 +1125,7 @@ mod tests { cache_key_prefix: "prefix/".to_string(), index_file_size: 1024, 
aux_file_size: 512, + rq_search_cache: empty_rabit_search_cache_cell(), }; let entry = IvfStateEntryBox(Arc::new(state)); diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 7901390f2b2..c415a841cff 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -8,7 +8,7 @@ use std::marker::PhantomData; use std::{ any::Any, collections::{BinaryHeap, HashMap}, - sync::Arc, + sync::{Arc, Mutex}, }; use crate::index::vector::{IndexFileVersion, builder::index_type_string}; @@ -38,6 +38,7 @@ use lance_index::frag_reuse::FragReuseIndex; use lance_index::metrics::{LocalMetricsCollector, MetricsCollector, NoOpMetricsCollector}; use lance_index::vector::VectorIndexCacheEntry; use lance_index::vector::bq::builder::RabitQuantizer; +use lance_index::vector::bq::rabit_ex_bits; use lance_index::vector::bq::storage::RabitQueryEstimator; use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; use lance_index::vector::graph::OrderedNode; @@ -78,6 +79,8 @@ use tracing::{info, instrument}; use super::{IvfIndexPartitionStatistics, IvfIndexStatistics, maybe_centroids_for_stats}; +pub(crate) type RabitSearchCacheCell = Arc>>>>; + /// Serializable state of an IVF index, sufficient to reconstruct the index /// without re-reading global buffers from object storage. /// @@ -109,6 +112,8 @@ pub(crate) struct IvfIndexState { /// when reconstructing from cache. pub(crate) index_file_size: u64, pub(crate) aux_file_size: u64, + /// Runtime-only cache, intentionally excluded from the CacheCodec wire format. 
+ pub(crate) rq_search_cache: RabitSearchCacheCell, } struct PreparedPartitionSearch { @@ -116,18 +121,48 @@ struct PreparedPartitionSearch { pre_filter: Arc, partition_id: usize, partition_centroid: Option, - rotated_partition_centroid: Option>, + rq_search_cache: Option>, raw_query_context: Option>, part_entry: Arc, _marker: PhantomData<(S, Q)>, } #[derive(Debug)] -struct RabitSearchCache { +pub(crate) struct RabitSearchCache { rotated_centroids: Vec, code_dim: usize, } +pub(crate) fn empty_rabit_search_cache_cell() -> RabitSearchCacheCell { + Arc::new(Mutex::new(None)) +} + +fn rabit_search_cache_cell(cache: Option>) -> RabitSearchCacheCell { + Arc::new(Mutex::new(Some(cache))) +} + +fn rotated_partition_centroid_slice( + cache: Option<&RabitSearchCache>, + partition_id: usize, +) -> Option<&[f32]> { + let cache = cache?; + let start = partition_id.checked_mul(cache.code_dim)?; + let end = start.checked_add(cache.code_dim)?; + cache.rotated_centroids.get(start..end) +} + +fn rabit_ex_dist_table_len(dim: usize, num_bits: u8) -> usize { + rabit_ex_bits(num_bits) + .map(|ex_bits| { + if ex_bits == 0 { + 0 + } else { + dim * (1usize << usize::from(ex_bits)) + } + }) + .unwrap_or(dim * 256) +} + impl DeepSizeOf for IvfIndexState { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { self.index_file_path.deep_size_of_children(context) @@ -137,6 +172,13 @@ impl DeepSizeOf for IvfIndexState { + self.sub_index_metadata.deep_size_of_children(context) + self.metadata.deep_size_of_children(context) + self.cache_key_prefix.deep_size_of_children(context) + + self + .rq_search_cache + .lock() + .ok() + .and_then(|cache| cache.as_ref().and_then(|cache| cache.as_ref().cloned())) + .map(|cache| cache.rotated_centroids.len() * std::mem::size_of::()) + .unwrap_or_default() } } @@ -268,6 +310,7 @@ impl CacheCodecImpl for IvfStateEntryBox { cache_key_prefix: header.cache_key_prefix, index_file_size: header.index_file_size, aux_file_size: header.aux_file_size, + 
rq_search_cache: empty_rabit_search_cache_cell(), }))) } @@ -567,22 +610,6 @@ impl DeepSizeOf for IVFIndex { } impl IVFIndex { - fn ensure_search_supported(&self) -> Result<()> { - if Q::quantization_type() == QuantizationType::Rabit { - let metadata = serde_json::to_value(self.storage.metadata())?; - let num_bits = metadata - .get("num_bits") - .and_then(|value| value.as_u64()) - .unwrap_or(1); - if num_bits > 1 && self.distance_type == DistanceType::Cosine { - return Err(Error::not_supported( - "IVF_RQ num_bits>1 cosine search is not supported until raw-query cosine factors are implemented", - )); - } - } - Ok(()) - } - fn use_query_residual( storage: &IvfQuantizationStorage, distance_type: DistanceType, @@ -618,14 +645,20 @@ impl IVFIndex { }))) } - fn rotated_partition_centroid(&self, partition_id: usize) -> Option> { - let cache = self.rq_search_cache.as_ref()?; - let start = partition_id.checked_mul(cache.code_dim)?; - let end = start.checked_add(cache.code_dim)?; - cache - .rotated_centroids - .get(start..end) - .map(|centroid| centroid.to_vec()) + fn rq_search_cache_from_state( + state: &IvfIndexState, + storage: &IvfQuantizationStorage, + ) -> Result>> { + let mut cache = state + .rq_search_cache + .lock() + .map_err(|_| Error::internal("RQ search cache lock was poisoned".to_string()))?; + if let Some(cache) = cache.as_ref() { + return Ok(cache.clone()); + } + let built = Self::build_rq_search_cache(&state.ivf, storage)?; + *cache = Some(built.clone()); + Ok(built) } fn prepare_rq_raw_query_context( @@ -664,7 +697,7 @@ impl IVFIndex { pre_filter, partition_id, partition_centroid: self.ivf.centroid(partition_id), - rotated_partition_centroid: self.rotated_partition_centroid(partition_id), + rq_search_cache: self.rq_search_cache.clone(), raw_query_context, part_entry, _marker: PhantomData, @@ -685,7 +718,7 @@ impl IVFIndex { pre_filter, partition_id, partition_centroid: self.ivf.centroid(partition_id), - rotated_partition_centroid: 
self.rotated_partition_centroid(partition_id), + rq_search_cache: self.rq_search_cache.clone(), raw_query_context, part_entry, _marker: PhantomData, @@ -704,17 +737,19 @@ impl IVFIndex { pre_filter, partition_id, partition_centroid, - rotated_partition_centroid, + rq_search_cache, raw_query_context, part_entry, _marker: _, } = prepared; + let rotated_partition_centroid = + rotated_partition_centroid_slice(rq_search_cache.as_deref(), partition_id); let residual = Self::query_context_for_scratch( use_query_residual, use_residual_scratch, partition_id, partition_centroid.as_ref(), - rotated_partition_centroid.as_deref(), + rotated_partition_centroid, raw_query_context.as_deref(), )?; let query = Self::preprocess_partition_query_owned( @@ -760,17 +795,19 @@ impl IVFIndex { pre_filter, partition_id, partition_centroid, - rotated_partition_centroid, + rq_search_cache, raw_query_context, part_entry, _marker: _, } = prepared; + let rotated_partition_centroid = + rotated_partition_centroid_slice(rq_search_cache.as_deref(), partition_id); let residual = Self::query_context_for_scratch( use_query_residual, use_residual_scratch, partition_id, partition_centroid.as_ref(), - rotated_partition_centroid.as_deref(), + rotated_partition_centroid, raw_query_context.as_deref(), )?; let query = Self::preprocess_partition_query_owned( @@ -874,19 +911,25 @@ impl IVFIndex { Ok(query) } - fn query_scratch_capacity(ivf: &IvfModel) -> QueryScratchCapacity { + fn query_scratch_capacity( + ivf: &IvfModel, + storage: &IvfQuantizationStorage, + ) -> QueryScratchCapacity { if Q::quantization_type() != QuantizationType::Rabit { return QueryScratchCapacity::default(); } let dim = ivf.dimension(); let dist_table_len = dim * 4; - let max_ex_dist_table_len = dim * 256; + let ex_dist_table_len = match storage.quantizer() { + Ok(Quantizer::Rabit(rq)) => rabit_ex_dist_table_len(dim, rq.metadata_ref().num_bits), + _ => dim * 256, + }; let max_partition_len = 
ivf.lengths.iter().copied().max().unwrap_or_default() as usize; QueryScratchCapacity::new( max_partition_len, - dim + dist_table_len + max_ex_dist_table_len, + dim + dist_table_len + ex_dist_table_len, max_partition_len, dist_table_len, ) @@ -901,10 +944,10 @@ impl IVFIndex { .unwrap_or(false) } - fn query_scratch_pool(ivf: &IvfModel) -> QueryScratchPool { + fn query_scratch_pool(ivf: &IvfModel, storage: &IvfQuantizationStorage) -> QueryScratchPool { QueryScratchPool::with_capacity( get_num_compute_intensive_cpus(), - Self::query_scratch_capacity(ivf), + Self::query_scratch_capacity(ivf, storage), ) } @@ -1011,7 +1054,7 @@ impl IVFIndex { ) .await; - let scratch_pool = Arc::new(Self::query_scratch_pool(&ivf)); + let scratch_pool = Arc::new(Self::query_scratch_pool(&ivf, &storage)); let use_query_residual = Self::use_query_residual(&storage, distance_type); let use_residual_scratch = Self::use_residual_scratch(&ivf, use_query_residual); let rq_search_cache = Self::build_rq_search_cache(&ivf, &storage)?; @@ -1048,11 +1091,11 @@ impl IVFIndex { distance_type: DistanceType, index_cache: LanceCache, io_parallelism: usize, + rq_search_cache: Option>, ) -> Self { - let scratch_pool = Arc::new(Self::query_scratch_pool(&ivf)); + let scratch_pool = Arc::new(Self::query_scratch_pool(&ivf, &storage)); let use_query_residual = Self::use_query_residual(&storage, distance_type); let use_residual_scratch = Self::use_residual_scratch(&ivf, use_query_residual); - let rq_search_cache = Self::build_rq_search_cache(&ivf, &storage).unwrap_or(None); Self { uri, index_path, @@ -1179,6 +1222,7 @@ impl IVFIndex { cache_key_prefix: self.index_cache.prefix().to_string(), index_file_size: self.reader.metadata().file_size(), aux_file_size: self.storage.reader().metadata().file_size(), + rq_search_cache: rabit_search_cache_cell(self.rq_search_cache.clone()), })) } } @@ -1337,12 +1381,11 @@ impl VectorIndex for IVFInd pre_filter: Arc, metrics: &dyn MetricsCollector, ) -> Result { - 
self.ensure_search_supported()?; let part_entry = self.load_partition(partition_id, true, metrics).await?; pre_filter.wait_for_ready().await?; let partition_centroid = self.ivf.centroid(partition_id); - let rotated_partition_centroid = self.rotated_partition_centroid(partition_id); + let rq_search_cache = self.rq_search_cache.clone(); let raw_query_context = self.prepare_rq_raw_query_context(&query.key)?; let query = Self::preprocess_partition_query( self.use_query_residual, @@ -1365,12 +1408,14 @@ impl VectorIndex for IVFInd .ok_or(Error::internal( "failed to downcast partition entry".to_string(), ))?; + let rotated_partition_centroid = + rotated_partition_centroid_slice(rq_search_cache.as_deref(), partition_id); let residual = Self::query_context_for_scratch( use_query_residual, use_residual_scratch, partition_id, partition_centroid.as_ref(), - rotated_partition_centroid.as_deref(), + rotated_partition_centroid, raw_query_context.as_deref(), )?; let batch = scratch_pool.with_scratch(|scratch| { @@ -1401,7 +1446,6 @@ impl VectorIndex for IVFInd pre_filter: Arc, metrics: &dyn MetricsCollector, ) -> Result { - self.ensure_search_supported()?; let raw_query_context = self.prepare_rq_raw_query_context(&query.key)?; Ok(Box::new( self.prepare_partition(partition_id, query, pre_filter, metrics, raw_query_context) @@ -1414,7 +1458,6 @@ impl VectorIndex for IVFInd prepared: PreparedPartitionSearchHandle, metrics: &dyn MetricsCollector, ) -> Result { - self.ensure_search_supported()?; let prepared = prepared .downcast::>() .map_err(|_| Error::internal("failed to downcast prepared partition search"))?; @@ -1453,7 +1496,6 @@ impl VectorIndex for IVFInd control: Option>, metrics: Arc, ) -> Result { - self.ensure_search_supported()?; if partitions.len() != q_c_dists.len() { return Err(Error::invalid_input(format!( "partition count {} does not match centroid distance count {}", @@ -1796,6 +1838,7 @@ async fn reconstruct_typed( state.distance_type, None, ); + let rq_search_cache 
= IVFIndex::::rq_search_cache_from_state(state, &storage)?; let index = IVFIndex::::from_cached_state( to_local_path(&index_path), @@ -1808,6 +1851,7 @@ async fn reconstruct_typed( state.distance_type, index_cache, io_parallelism, + rq_search_cache, ); Ok(Arc::new(index)) } @@ -1889,6 +1933,32 @@ mod tests { lance_testing::define_stage_event_progress!(RecordingProgress, IndexBuildProgress, Result<()>); + #[test] + fn test_rotated_partition_centroid_slice_borrows_cache() { + let cache = super::RabitSearchCache { + rotated_centroids: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + code_dim: 2, + }; + + let centroid = super::rotated_partition_centroid_slice(Some(&cache), 1).unwrap(); + + assert_eq!(centroid, &[3.0, 4.0]); + assert_eq!(centroid.as_ptr(), cache.rotated_centroids[2..].as_ptr()); + assert!(super::rotated_partition_centroid_slice(Some(&cache), 3).is_none()); + assert!(super::rotated_partition_centroid_slice(None, 0).is_none()); + } + + #[test] + fn test_rabit_ex_dist_table_len_uses_num_bits() { + let dim = 960; + + assert_eq!(super::rabit_ex_dist_table_len(dim, 1), 0); + assert_eq!(super::rabit_ex_dist_table_len(dim, 3), dim * 4); + assert_eq!(super::rabit_ex_dist_table_len(dim, 5), dim * 16); + assert_eq!(super::rabit_ex_dist_table_len(dim, 7), dim * 64); + assert_eq!(super::rabit_ex_dist_table_len(dim, 9), dim * 256); + } + async fn generate_test_dataset( test_uri: &str, range: Range, @@ -4131,15 +4201,20 @@ mod tests { test_remap(params.clone(), nlist, recall_requirement).await; } + #[rstest] + #[case::l2(DistanceType::L2)] + #[case::cosine(DistanceType::Cosine)] #[tokio::test] - async fn test_build_ivf_rq_multi_bit_persists_split_codes_and_searches() { + async fn test_build_ivf_rq_multi_bit_persists_split_codes_and_searches( + #[case] distance_type: DistanceType, + ) { let test_dir = TempStrDir::default(); let test_uri = test_dir.as_str(); let (mut dataset, vectors) = generate_test_dataset::(test_uri, 0.0..1.0).await; let ivf_params = IvfBuildParams::new(4); let 
rq_params = RQBuildParams::with_rotation_type(9, RQRotationType::Fast); - let params = VectorIndexParams::with_ivf_rq_params(DistanceType::L2, ivf_params, rq_params); + let params = VectorIndexParams::with_ivf_rq_params(distance_type, ivf_params, rq_params); dataset .create_index(&["vector"], IndexType::Vector, None, ¶ms, true) .await From 42b122a38b97851820de71ce1a5bf9ae72d284c2 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Mon, 8 Jun 2026 16:45:43 +0800 Subject: [PATCH 09/14] test(index): update ivf rq cosine expectations --- python/python/tests/test_vector_index.py | 30 +++++++++++++------ .../src/index/vector/ivf/partition_serde.rs | 7 ++++- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 2760d55e842..a12ad6a0b43 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -1067,7 +1067,7 @@ def test_create_ivf_rq_skip_transpose(): assert stats["indices"][0]["sub_index"]["packed"] is False -def test_create_ivf_rq_multi_bit_searches_l2_and_gates_cosine(): +def test_create_ivf_rq_multi_bit_searches_l2_and_cosine(): ds = lance.write_dataset(create_table(), "memory://") ds = ds.create_index( @@ -1090,14 +1090,26 @@ def test_create_ivf_rq_multi_bit_searches_l2_and_gates_cosine(): assert result.num_rows == 10 cosine_ds = lance.write_dataset(create_table(), "memory://") - with pytest.raises(NotImplementedError, match="num_bits>1 cosine index creation"): - cosine_ds.create_index( - "vector", - index_type="IVF_RQ", - metric="cosine", - num_partitions=4, - num_bits=9, - ) + cosine_ds = cosine_ds.create_index( + "vector", + index_type="IVF_RQ", + metric="cosine", + num_partitions=4, + num_bits=9, + ) + cosine_stats = cosine_ds.stats.index_stats("vector_idx") + assert cosine_stats["indices"][0]["sub_index"]["num_bits"] == 9 + assert cosine_stats["indices"][0]["sub_index"]["query_estimator"] == "raw_query" + + cosine_result = 
cosine_ds.to_table( + nearest={ + "column": "vector", + "q": np.random.randn(128).astype(np.float32), + "k": 10, + "metric": "cosine", + } + ) + assert cosine_result.num_rows == 10 def test_create_ivf_rq_requires_dim_divisible_by_8(): diff --git a/rust/lance/src/index/vector/ivf/partition_serde.rs b/rust/lance/src/index/vector/ivf/partition_serde.rs index 2cf3719e12a..83ced18c598 100644 --- a/rust/lance/src/index/vector/ivf/partition_serde.rs +++ b/rust/lance/src/index/vector/ivf/partition_serde.rs @@ -1082,6 +1082,11 @@ mod tests { fn test_rabitq_distance_types() { for dt in [DistanceType::L2, DistanceType::Cosine, DistanceType::Dot] { let storage = make_rabit_storage_fast(10, 32, dt); + let expected_distance_type = if dt == DistanceType::Cosine { + DistanceType::L2 + } else { + dt + }; let entry = PartitionEntry:: { index: FlatIndex::default(), storage, @@ -1092,7 +1097,7 @@ mod tests { &bytes::Bytes::from(bytes), ) .unwrap(); - assert_eq!(restored.storage.distance_type(), dt); + assert_eq!(restored.storage.distance_type(), expected_distance_type); } } From b7eb77602f59581f161fec8c83e56cd6187dea1e Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Mon, 8 Jun 2026 17:15:13 +0800 Subject: [PATCH 10/14] fix(index): reuse supplied ivf rq rotation --- rust/lance-index/src/vector/bq/builder.rs | 116 +++++++++++++++++++++- 1 file changed, 115 insertions(+), 1 deletion(-) diff --git a/rust/lance-index/src/vector/bq/builder.rs b/rust/lance-index/src/vector/bq/builder.rs index ea3d9ed2667..8cfe3f3eff1 100644 --- a/rust/lance-index/src/vector/bq/builder.rs +++ b/rust/lance-index/src/vector/bq/builder.rs @@ -26,7 +26,7 @@ use crate::vector::bq::transform::{ }; use crate::vector::bq::{ RQBuildParams, RQRotationType, rabit_binary_code_bytes, rabit_ex_bits, rabit_ex_code_bytes, - rotation::{apply_fast_rotation, random_fast_rotation_signs}, + rotation::{apply_fast_rotation, fast_rotation_signs_len, random_fast_rotation_signs}, validate_rq_num_bits, }; use 
crate::vector::quantizer::{Quantization, Quantizer, QuantizerBuildParams}; @@ -266,6 +266,65 @@ impl RabitQuantizer { &self.metadata } + fn from_supplied_rotation(params: &RQBuildParams, dim: usize) -> Result> { + let Some(metadata) = params.rotation.as_ref() else { + return Ok(None); + }; + + if metadata.num_bits != params.num_bits { + return Err(Error::invalid_input(format!( + "rabitq_model num_bits={} does not match requested num_bits={}", + metadata.num_bits, params.num_bits + ))); + } + + let rotated_dim = metadata.rotated_dim(); + if rotated_dim != dim { + return Err(Error::invalid_input(format!( + "rabitq_model dimension={} does not match vector dimension={}", + rotated_dim, dim + ))); + } + + match metadata.rotation_type { + RQRotationType::Fast => { + let signs = metadata.fast_rotation_signs.as_ref().ok_or_else(|| { + Error::invalid_input( + "rabitq_model fast rotation is missing fast_rotation_signs".to_string(), + ) + })?; + let expected_len = fast_rotation_signs_len(dim); + if signs.len() != expected_len { + return Err(Error::invalid_input(format!( + "rabitq_model fast_rotation_signs length={} does not match expected length={} for dimension={}", + signs.len(), + expected_len, + dim + ))); + } + } + RQRotationType::Matrix => { + let rotate_mat = metadata.rotate_mat.as_ref().ok_or_else(|| { + Error::invalid_input( + "rabitq_model matrix rotation is missing rotate_mat".to_string(), + ) + })?; + if rotate_mat.len() != dim || rotate_mat.value_length() != dim as i32 { + return Err(Error::invalid_input(format!( + "rabitq_model matrix rotation shape=({}, {}) does not match vector dimension={}", + rotate_mat.len(), + rotate_mat.value_length(), + dim + ))); + } + } + } + + Ok(Some(Self { + metadata: metadata.clone(), + })) + } + #[inline] fn fast_rotation_signs(&self) -> &[u8] { self.metadata @@ -640,6 +699,9 @@ impl Quantization for RabitQuantizer { "vector dimension must be divisible by 8 for IVF_RQ", )); } + if let Some(q) = Self::from_supplied_rotation(params, 
dim)? { + return Ok(q); + } let q = match data.as_fixed_size_list().value_type() { DataType::Float16 => Self::new_with_rotation::( @@ -949,6 +1011,58 @@ mod tests { ); } + #[test] + fn test_rabit_quantizer_reuses_supplied_rotation() { + let vectors = Float32Array::from(vec![0.0f32; 4 * 32]); + let fsl = FixedSizeListArray::try_new_from_values(vectors, 32).unwrap(); + let supplied = + RabitQuantizer::new_with_rotation::(3, 32, RQRotationType::Fast) + .metadata(None); + let supplied_signs = supplied.fast_rotation_signs.clone(); + + let mut params = RQBuildParams::with_rotation_type(3, RQRotationType::Fast); + params.rotation = Some(supplied); + + let quantizer = RabitQuantizer::build(&fsl, DistanceType::L2, ¶ms).unwrap(); + let metadata = quantizer.metadata_ref(); + assert_eq!(metadata.num_bits, 3); + assert_eq!(metadata.rotation_type, RQRotationType::Fast); + assert_eq!(metadata.fast_rotation_signs, supplied_signs); + } + + #[test] + fn test_rabit_quantizer_validates_supplied_rotation() { + let vectors = Float32Array::from(vec![0.0f32; 4 * 32]); + let fsl = FixedSizeListArray::try_new_from_values(vectors, 32).unwrap(); + let supplied = + RabitQuantizer::new_with_rotation::(3, 32, RQRotationType::Fast) + .metadata(None); + + let mut wrong_num_bits = supplied.clone(); + wrong_num_bits.num_bits = 1; + let mut params = RQBuildParams::with_rotation_type(3, RQRotationType::Fast); + params.rotation = Some(wrong_num_bits); + let err = RabitQuantizer::build(&fsl, DistanceType::L2, ¶ms).unwrap_err(); + assert!( + err.to_string() + .contains("does not match requested num_bits") + ); + + let mut wrong_dim = supplied.clone(); + wrong_dim.code_dim = 64; + let mut params = RQBuildParams::with_rotation_type(3, RQRotationType::Fast); + params.rotation = Some(wrong_dim); + let err = RabitQuantizer::build(&fsl, DistanceType::L2, ¶ms).unwrap_err(); + assert!(err.to_string().contains("does not match vector dimension")); + + let mut wrong_sign_len = supplied; + 
wrong_sign_len.fast_rotation_signs.as_mut().unwrap().pop(); + let mut params = RQBuildParams::with_rotation_type(3, RQRotationType::Fast); + params.rotation = Some(wrong_sign_len); + let err = RabitQuantizer::build(&fsl, DistanceType::L2, ¶ms).unwrap_err(); + assert!(err.to_string().contains("fast_rotation_signs length")); + } + #[test] fn test_rabit_quantizer_accepts_multi_bit_range() { let vectors = Float32Array::from(vec![0.0f32; 4 * 32]); From 266b784e57c79f9c758673d857b9c8fd5977824c Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Mon, 8 Jun 2026 18:00:24 +0800 Subject: [PATCH 11/14] fix(linalg): handle odd dist table code lengths --- python/python/tests/compat/compat_decorator.py | 10 ++++++++-- python/python/tests/compat/test_vector_indices.py | 7 +++++++ rust/lance-linalg/src/simd/dist_table.rs | 8 ++++++-- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/python/python/tests/compat/compat_decorator.py b/python/python/tests/compat/compat_decorator.py index 0ab35672410..fdfe09a6879 100644 --- a/python/python/tests/compat/compat_decorator.py +++ b/python/python/tests/compat/compat_decorator.py @@ -153,6 +153,10 @@ def skip_read_after_current_write(self, version: str) -> bool: """Return True to skip the old-version read after current-version writes.""" return False + def skip_write_after_current_write(self, version: str) -> bool: + """Return True to skip the old-version write after current-version writes.""" + return False + def skip_downgrade(self, version: str) -> bool: """Return True to skip the current-write -> old-read downgrade test.""" return False @@ -333,8 +337,10 @@ def test_func({sig_params}): obj.create() # Old version: verify can read venv = venv_factory.get_venv(version) - venv.execute_method(obj, "check_read", obj.compat_env(version, "check_read")) - venv.execute_method(obj, "check_write", obj.compat_env(version, "check_write")) + if not obj.skip_read_after_current_write(version): + venv.execute_method(obj, "check_read", 
obj.compat_env(version, "check_read")) + if not obj.skip_write_after_current_write(version): + venv.execute_method(obj, "check_write", obj.compat_env(version, "check_write")) ''' else: # upgrade_downgrade func_body = f''' diff --git a/python/python/tests/compat/test_vector_indices.py b/python/python/tests/compat/test_vector_indices.py index e381a3ce554..5c8535daccc 100644 --- a/python/python/tests/compat/test_vector_indices.py +++ b/python/python/tests/compat/test_vector_indices.py @@ -274,6 +274,13 @@ def current_env(self, method_name: str): return {"LANCE_COMPAT_CURRENT_RUNTIME": "1"} return {} + def skip_write_after_current_write(self, version: str) -> bool: + # Newly written IVF_RQ indexes carry raw-query estimator metadata and + # split-code schema that older runtimes can query but cannot optimize. + # The upgrade_downgrade variant still covers old 1-bit residual-query + # indexes being read and rewritten by the current runtime. + return True + def create(self): """Create dataset with IVF_RQ vector index.""" shutil.rmtree(self.path, ignore_errors=True) diff --git a/rust/lance-linalg/src/simd/dist_table.rs b/rust/lance-linalg/src/simd/dist_table.rs index 646adc6bdcd..addd90381d7 100644 --- a/rust/lance-linalg/src/simd/dist_table.rs +++ b/rust/lance-linalg/src/simd/dist_table.rs @@ -149,6 +149,10 @@ unsafe fn sum_dist_table_32bytes_batch_avx2(codes: &[u8], dist_table: &[u8], dis accu2 = _mm256_add_epi16(accu2, res_hi); accu3 = _mm256_add_epi16(accu3, _mm256_srli_epi16(res_hi, 8)); + if i + 32 >= codes.len() { + continue; + } + // load the left 32 bytes of codes and lut c = _mm256_loadu_si256(codes.as_ptr().add(i + 32) as *const __m256i); lut_vec = _mm256_loadu_si256(dist_table.as_ptr().add(i + 32) as *const __m256i); @@ -352,7 +356,7 @@ mod tests { // directly since that's what the function sees. 
// code_len=16 → DIM=128, code_len=192 → DIM=1536, // code_len=512 → DIM=4096, code_len=8192 → DIM=65536 - for code_len in [2, 16, 96, 192, 512, 1024, 8192] { + for code_len in [1, 2, 3, 16, 95, 96, 192, 512, 1024, 8192] { let n = BATCH_SIZE; // 32 vectors per batch // Each code byte produces 2 lookups; cap values so @@ -386,7 +390,7 @@ mod tests { use rand::{Rng, SeedableRng}; let mut rng = rand::rngs::StdRng::seed_from_u64(123); - for code_len in [16, 192, 1024] { + for code_len in [1, 3, 16, 191, 192, 1024] { let n = BATCH_SIZE * 10; // 320 vectors = 10 batches let max_val = (u16::MAX as usize / (2 * code_len)).min(255) as u8; From 221301ba332ea4e3fef50301338991116a719a88 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Mon, 8 Jun 2026 18:18:25 +0800 Subject: [PATCH 12/14] test(lance): stabilize child input stream timing --- rust/lance/src/io/exec/utils.rs | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/rust/lance/src/io/exec/utils.rs b/rust/lance/src/io/exec/utils.rs index 5def0fb254d..5cd291046e1 100644 --- a/rust/lance/src/io/exec/utils.rs +++ b/rust/lance/src/io/exec/utils.rs @@ -579,11 +579,7 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)])); let n_batches: usize = 3; - // child_delay is intentionally several times larger than transform_delay - // so the assertion tolerates significant `std::thread::sleep` overshoot - // on busy CI runners (we've seen ~2-3x overshoot on macOS Actions). 
let child_delay = Duration::from_millis(150); - let transform_delay = Duration::from_millis(30); let counter = Arc::new(AtomicUsize::new(0)); let s = schema.clone(); @@ -607,10 +603,7 @@ mod tests { let stream = InstrumentedChildInputStream::new( child, schema, - move |batch| async move { - std::thread::sleep(transform_delay); - Ok(batch) - }, + move |batch| async move { Ok(batch) }, 1, 0, &metrics, @@ -625,16 +618,10 @@ mod tests { .expect("elapsed_compute should be recorded"); let elapsed = Duration::from_nanos(elapsed_ns as u64); - // Expect ~ transform_delay * n. The upper bound is set generously to - // absorb sleep overshoot on slow CI (~4-5x per call) while still - // cleanly rejecting any version that double-counts child poll time, - // which would yield ~ (transform_delay + child_delay) * n. - let upper = Duration::from_millis(400); - assert!( - elapsed >= transform_delay * (n_batches as u32 - 1), - "elapsed_compute={:?} too low; transform time was not measured", - elapsed, - ); + // The transform is immediate, so `elapsed_compute` should stay well + // below even one child poll delay. A version that double-counts child + // input time would include roughly `child_delay * n_batches`. 
+ let upper = child_delay; assert!( elapsed < upper, "elapsed_compute={:?} >= {:?}; child input time was double-counted", From 4910e176a5f9da8f9d3e259ac4ca0eb2bd8d00f1 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Mon, 8 Jun 2026 19:52:09 +0800 Subject: [PATCH 13/14] fix(index): address ivf rq review feedback --- docs/src/format/index/vector/index.md | 2 + python/python/tests/test_vector_index.py | 65 +++++++++++------------- rust/lance-linalg/src/simd/dist_table.c | 13 +++-- rust/lance/src/index/vector/ivf/v2.rs | 52 +++++++++++++------ 4 files changed, 80 insertions(+), 52 deletions(-) diff --git a/docs/src/format/index/vector/index.md b/docs/src/format/index/vector/index.md index 48bd27163a0..7aaf9b55996 100644 --- a/docs/src/format/index/vector/index.md +++ b/docs/src/format/index/vector/index.md @@ -198,6 +198,7 @@ Compresses vectors using RabitQ with random rotation and binary quantization for | `_rabit_codes` | list[dimension / 8] | false | Binary quantized codes (1 bit per dimension, packed into bytes) | | `__add_factors` | float32 | false | Additive correction factors for distance computation | | `__scale_factors` | float32 | false | Scale correction factors for distance computation | +| `__error_factors` | float32 | false for `raw_query` | Error factors for raw-query lower-bound pruning | | `__ex_codes` | list[ceil(dimension * (num_bits - 1) / 8)] | false for `num_bits > 1` | Extra RabitQ code bits for multi-bit RQ | | `__add_factors_ex` | float32 | false for `num_bits > 1` | Additive correction factors for ex-code distance computation | | `__scale_factors_ex` | float32 | false for `num_bits > 1` | Scale correction factors for ex-code distance computation | @@ -358,6 +359,7 @@ pa.schema([ pa.field("_rabit_codes", pa.list(pa.uint8(), list_size=16)), # dimension/8 = 128/8 = 16 bytes pa.field("__add_factors", pa.float32()), pa.field("__scale_factors", pa.float32()), + pa.field("__error_factors", pa.float32()), ]) ``` diff --git 
a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index a12ad6a0b43..2e4a3dd4648 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -1067,50 +1067,47 @@ def test_create_ivf_rq_skip_transpose(): assert stats["indices"][0]["sub_index"]["packed"] is False -def test_create_ivf_rq_multi_bit_searches_l2_and_cosine(): - ds = lance.write_dataset(create_table(), "memory://") +def _assert_recall_at_least(ds, query, metric=None, k=10, recall_requirement=0.5): + nearest = {"column": "vector", "q": query, "k": k} + if metric is not None: + nearest["metric"] = metric - ds = ds.create_index( - "vector", - index_type="IVF_RQ", - num_partitions=4, - num_bits=9, + gt_ids = ds.to_table(nearest=nearest, columns=["id"])["id"].to_numpy() + create_index_kwargs = { + "index_type": "IVF_RQ", + "num_partitions": 4, + "num_bits": 9, + } + if metric is not None: + create_index_kwargs["metric"] = metric + indexed = ds.create_index("vector", **create_index_kwargs) + result_ids = indexed.to_table(nearest=nearest, columns=["id"])["id"].to_numpy() + + assert result_ids.shape[0] == k + recall = len(set(gt_ids) & set(result_ids)) / k + assert recall >= recall_requirement, ( + f"recall={recall}, gt={gt_ids}, result={result_ids}" ) + return indexed + + +def test_create_ivf_rq_multi_bit_searches_l2_and_cosine(): + rng = np.random.default_rng(42) + mat = rng.standard_normal((1000, 128)).astype(np.float32) + tbl = vec_to_table(data=mat).append_column("id", pa.array(range(len(mat)))) + + ds = lance.write_dataset(tbl, "memory://") + ds = _assert_recall_at_least(ds, mat[0]) stats = ds.stats.index_stats("vector_idx") assert stats["indices"][0]["sub_index"]["num_bits"] == 9 assert stats["indices"][0]["sub_index"]["query_estimator"] == "raw_query" - result = ds.to_table( - nearest={ - "column": "vector", - "q": np.random.randn(128).astype(np.float32), - "k": 10, - } - ) - assert result.num_rows == 10 - - cosine_ds = 
lance.write_dataset(create_table(), "memory://") - cosine_ds = cosine_ds.create_index( - "vector", - index_type="IVF_RQ", - metric="cosine", - num_partitions=4, - num_bits=9, - ) + cosine_ds = lance.write_dataset(tbl, "memory://") + cosine_ds = _assert_recall_at_least(cosine_ds, mat[1], metric="cosine") cosine_stats = cosine_ds.stats.index_stats("vector_idx") assert cosine_stats["indices"][0]["sub_index"]["num_bits"] == 9 assert cosine_stats["indices"][0]["sub_index"]["query_estimator"] == "raw_query" - cosine_result = cosine_ds.to_table( - nearest={ - "column": "vector", - "q": np.random.randn(128).astype(np.float32), - "k": 10, - "metric": "cosine", - } - ) - assert cosine_result.num_rows == 10 - def test_create_ivf_rq_requires_dim_divisible_by_8(): vectors = np.zeros((1000, 30), dtype=np.float32).tolist() diff --git a/rust/lance-linalg/src/simd/dist_table.c b/rust/lance-linalg/src/simd/dist_table.c index 9e7fc2b2205..e8be8e52068 100644 --- a/rust/lance-linalg/src/simd/dist_table.c +++ b/rust/lance-linalg/src/simd/dist_table.c @@ -23,8 +23,15 @@ void sum_4bit_dist_table_32bytes_batch_avx512(const uint8_t *codes, __m512i accu3 = _mm512_setzero_si512(); for (size_t i = 0; i < code_length; i += 64) { - c = _mm512_loadu_si512(&codes[i]); - lut = _mm512_loadu_si512(&dist_table[i]); + const size_t remaining = code_length - i; + if (remaining >= 64) { + c = _mm512_loadu_si512(&codes[i]); + lut = _mm512_loadu_si512(&dist_table[i]); + } else { + const __mmask64 load_mask = (UINT64_C(1) << remaining) - 1; + c = _mm512_maskz_loadu_epi8(load_mask, &codes[i]); + lut = _mm512_maskz_loadu_epi8(load_mask, &dist_table[i]); + } lo = _mm512_and_si512(c, lo_mask); hi = _mm512_and_si512(_mm512_srli_epi16(c, 4), lo_mask); @@ -50,4 +57,4 @@ void sum_4bit_dist_table_32bytes_batch_avx512(const uint8_t *codes, ret = _mm512_add_epi16(ret, _mm512_shuffle_i64x2(ret1, ret2, 0xDD)); _mm512_storeu_si512(dists, ret); -} \ No newline at end of file +} diff --git 
a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index c415a841cff..6218d39fbad 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -38,8 +38,8 @@ use lance_index::frag_reuse::FragReuseIndex; use lance_index::metrics::{LocalMetricsCollector, MetricsCollector, NoOpMetricsCollector}; use lance_index::vector::VectorIndexCacheEntry; use lance_index::vector::bq::builder::RabitQuantizer; -use lance_index::vector::bq::rabit_ex_bits; -use lance_index::vector::bq::storage::RabitQueryEstimator; +use lance_index::vector::bq::storage::{RabitQueryEstimator, SEGMENT_NUM_CODES}; +use lance_index::vector::bq::{rabit_ex_bits, rabit_ex_code_bytes}; use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; use lance_index::vector::graph::OrderedNode; use lance_index::vector::hnsw::HNSW; @@ -163,6 +163,19 @@ fn rabit_ex_dist_table_len(dim: usize, num_bits: u8) -> usize { .unwrap_or(dim * 256) } +fn rabit_u8_scratch_len(dim: usize, num_bits: u8) -> usize { + let binary_dist_table_len = dim * 4; + let ex_dist_table_len = rabit_ex_bits(num_bits) + .ok() + .and_then(|ex_bits| match ex_bits { + 2 | 4 | 8 => rabit_ex_code_bytes(dim, ex_bits).ok(), + _ => None, + }) + .map(|ex_code_len| ex_code_len * 2 * SEGMENT_NUM_CODES) + .unwrap_or_default(); + binary_dist_table_len.max(ex_dist_table_len) +} + impl DeepSizeOf for IvfIndexState { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { self.index_file_path.deep_size_of_children(context) @@ -921,9 +934,15 @@ impl IVFIndex { let dim = ivf.dimension(); let dist_table_len = dim * 4; - let ex_dist_table_len = match storage.quantizer() { - Ok(Quantizer::Rabit(rq)) => rabit_ex_dist_table_len(dim, rq.metadata_ref().num_bits), - _ => dim * 256, + let (ex_dist_table_len, u8_scratch_len) = match storage.quantizer() { + Ok(Quantizer::Rabit(rq)) => { + let num_bits = rq.metadata_ref().num_bits; + ( + rabit_ex_dist_table_len(dim, 
num_bits), + rabit_u8_scratch_len(dim, num_bits), + ) + } + _ => (dim * 256, dim * 32), }; let max_partition_len = ivf.lengths.iter().copied().max().unwrap_or_default() as usize; @@ -931,7 +950,7 @@ impl IVFIndex { max_partition_len, dim + dist_table_len + ex_dist_table_len, max_partition_len, - dist_table_len, + u8_scratch_len, ) } @@ -1959,6 +1978,17 @@ mod tests { assert_eq!(super::rabit_ex_dist_table_len(dim, 9), dim * 256); } + #[test] + fn test_rabit_u8_scratch_len_includes_ex_fastscan_tables() { + let dim = 960; + + assert_eq!(super::rabit_u8_scratch_len(dim, 1), dim * 4); + assert_eq!(super::rabit_u8_scratch_len(dim, 3), dim * 8); + assert_eq!(super::rabit_u8_scratch_len(dim, 5), dim * 16); + assert_eq!(super::rabit_u8_scratch_len(dim, 7), dim * 4); + assert_eq!(super::rabit_u8_scratch_len(dim, 9), dim * 32); + } + async fn generate_test_dataset( test_uri: &str, range: Range, @@ -4239,15 +4269,7 @@ mod tests { assert!(schema.field(EX_ADD_FACTORS_COLUMN).is_some()); assert!(schema.field(EX_SCALE_FACTORS_COLUMN).is_some()); - let query = vectors.value(0); - let results = dataset - .scan() - .nearest("vector", query.as_primitive::(), 10) - .unwrap() - .try_into_batch() - .await - .unwrap(); - assert_eq!(results.num_rows(), 10); + test_recall::(params, 4, 0.5, "vector", &dataset, vectors).await; } #[rstest] From 0a258b754c443b08a6f7dcc884ac8f163a3732d4 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Tue, 9 Jun 2026 12:11:00 +0800 Subject: [PATCH 14/14] fix(index): skip unsafe ivf rq downgrade reads --- python/python/tests/compat/test_vector_indices.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/python/tests/compat/test_vector_indices.py b/python/python/tests/compat/test_vector_indices.py index 5c8535daccc..e97d6be8bf6 100644 --- a/python/python/tests/compat/test_vector_indices.py +++ b/python/python/tests/compat/test_vector_indices.py @@ -274,13 +274,16 @@ def current_env(self, method_name: str): return 
{"LANCE_COMPAT_CURRENT_RUNTIME": "1"} return {} - def skip_write_after_current_write(self, version: str) -> bool: + def skip_read_after_current_write(self, version: str) -> bool: # Newly written IVF_RQ indexes carry raw-query estimator metadata and - # split-code schema that older runtimes can query but cannot optimize. + # split-code schema that older runtimes cannot query or optimize safely. # The upgrade_downgrade variant still covers old 1-bit residual-query # indexes being read and rewritten by the current runtime. return True + def skip_write_after_current_write(self, version: str) -> bool: + return True + def create(self): """Create dataset with IVF_RQ vector index.""" shutil.rmtree(self.path, ignore_errors=True)