diff --git a/crates/ole-core/Cargo.toml b/crates/ole-core/Cargo.toml index 8825fcb5..2282787f 100644 --- a/crates/ole-core/Cargo.toml +++ b/crates/ole-core/Cargo.toml @@ -12,7 +12,6 @@ test-utils = [] [dependencies] rand.workspace = true -itybity.workspace = true thiserror.workspace = true serde = { workspace = true, features = ["derive"] } hybrid-array = { workspace = true, features = ["serde"] } diff --git a/crates/ole-core/src/lib.rs b/crates/ole-core/src/lib.rs index 1eb2c539..063aa94a 100644 --- a/crates/ole-core/src/lib.rs +++ b/crates/ole-core/src/lib.rs @@ -31,7 +31,6 @@ pub use role::{ROLEReceiver, ROLEReceiverOutput, ROLESender, ROLESenderOutput}; pub use sender::{Sender, SenderError}; use hybrid_array::Array; -use itybity::ToBits; use mpz_fields::Field; use serde::{Deserialize, Serialize}; @@ -107,31 +106,41 @@ where /// Creates a new OLE share for the receiver. /// + /// The original ROT choice bits must be used directly rather than + /// converting to a field element and back, because for prime fields where + /// `2^k > p`, `from_lsb0_iter` reduces mod p which can produce different + /// bits than the ROT choices. The multiplicative share `b` is computed + /// from the choice bits using field arithmetic, which naturally reduces. + /// /// # Arguments /// - /// * `input` - Input value, `b`. - /// * `masks` - Chosen correlation masks. + /// * `choices` - Original ROT choice bits (LSB-first). + /// * `masks` - Chosen correlation masks from ROT. /// * `corr` - Masked correlation from the sender. #[inline] pub(crate) fn new_ole_receiver( - input: F, + choices: &[bool], masks: Array, corr: MaskedCorrelation, ) -> Self { - let delta_i = input.iter_lsb0(); + let delta_i = choices.iter(); let t_delta_i = masks.iter(); let corr = corr.0.iter(); - // Compute additive share, `y`. - let add = delta_i.zip(corr).zip(t_delta_i).enumerate().fold( - F::zero(), - |acc, (i, ((delta, &u), &t))| { + // Compute additive share `y` and multiplicative share `b` together. + let (add, mul) = delta_i.zip(corr).zip(t_delta_i).enumerate().fold( + (F::zero(), F::zero()), + |(add, mul), (i, ((&delta, &u), &t))| { + let two_pow_i = F::two_pow(i as u32); let delta = if delta { F::one() } else { F::zero() }; - acc + F::two_pow(i as u32) * (delta * u + t) + ( + add + two_pow_i * (delta * u + t), + mul + two_pow_i * delta, + ) }, ); - Self { add, mul: input } + Self { add, mul } } /// Adjusts the multiplicative share to the target. @@ -174,6 +183,41 @@ mod tests { test_ole::(); } + /// Verifies OLE correctness when receiver's ROT choice bits represent a + /// value >= p (the P256 field prime). Before the fix, `from_lsb0_iter` + /// would panic (ark-ff rejects out-of-range BigInt), and even with + /// reduction it would produce bits mismatched with the ROT choices. + #[test] + fn test_ole_p256_choices_exceed_prime() { + use rand::Rng as _; + + let mut rng = StdRng::seed_from_u64(42); + + // All-ones is 2^256 - 1, which is >= p for P256. + let choices_all_ones: Vec = vec![true; 256]; + + // Build sender masks and receiver correlation from random ROT keys, + // simulating what the ROT protocol would produce. + let sender_input: P256 = rng.random(); + let masks_pairs: Array<[P256; 2], ::BitSize> = + Array::from_fn(|_| [rng.random(), rng.random()]); + + // Receiver's ROT messages correspond to the original choice bits. + let receiver_masks: Array::BitSize> = Array::from_fn(|i| { + if choices_all_ones[i] { + masks_pairs[i][1] + } else { + masks_pairs[i][0] + } + }); + + let (sender_share, corr) = OLEShare::new_ole_sender(sender_input, masks_pairs); + let receiver_share = + OLEShare::new_ole_receiver(&choices_all_ones, receiver_masks, corr); + + assert_ole(sender_share, receiver_share); + } + fn test_ole() where StandardUniform: Distribution, diff --git a/crates/ole-core/src/receiver.rs b/crates/ole-core/src/receiver.rs index 46a6b1a9..1002506d 100644 --- a/crates/ole-core/src/receiver.rs +++ b/crates/ole-core/src/receiver.rs @@ -103,7 +103,7 @@ where .zip(masks) .map(|((bits, corr), mask)| { OLEShare::new_ole_receiver( - F::from_lsb0_iter(bits.iter().copied()), + bits, Array::::try_from(corr) .expect("slice should have length of bit size of field element"), mask,