diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem.lean index 206d9295..d987542e 100644 --- a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem.lean +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem.lean @@ -1 +1,27 @@ import LibcruxIotMlKem.Extraction.Funs +import LibcruxIotMlKem.Util.SliceSpecs +import LibcruxIotMlKem.Util.LoopSpecs +import LibcruxIotMlKem.Util.CreateI +import LibcruxIotMlKem.Spec +import LibcruxIotMlKem.Spec.Pure +import LibcruxIotMlKem.Spec.Commute +import LibcruxIotMlKem.Spec.StateIso +import LibcruxIotMlKem.Spec.AlgEquiv +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Sampling +import LibcruxIotMlKem.Serialize +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeAsPlusE +import LibcruxIotMlKem.Matrix.ComputeMessage.FC +import LibcruxIotMlKem.Matrix.ComputeVectorU.FC +import LibcruxIotMlKem.Matrix.ComputeRingElementV.FC diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/InvertNtt.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/InvertNtt.lean new file mode 100644 index 00000000..2fb32a02 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/InvertNtt.lean @@ -0,0 +1,3813 @@ +/- + # `InvertNtt.lean` — extracted from `FCTargets.lean` §invert_ntt. +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett + +set_option mvcgen.warning false +set_option linter.unusedVariables false +set_option linter.unusedSectionVars false + +namespace libcrux_iot_ml_kem.InvertNtt +open libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec + +/-! ## §L3i — Inverse-NTT driver loops. + + Mirror of §L3 for the inverse direction. Each `invert_ntt_at_layer_N` + is a 16-iter loop over `round ∈ [0, 16)` that DECREMENTS `zeta_i` by + `4` (layer 1), `2` (layer 2), or `1` (layer 3) per chunk. Layer 4+ + is a nested cross-chunk butterfly (deferred to Task H). Each chunk + dispatches to the corresponding `inv_ntt_layer_N_step_fc` (FCTargets + L2.9-11, just closed in Task E). + + Top-level composer `invert_ntt_montgomery` (Task I) calls these in + sequence: layer 1 with `zeta_i = 64`, layer 2 with `zeta_i = 32`, + layer 3 with `zeta_i = 16`, etc. The FC posts expose both the + output `zeta_i.val` (so the composer can chain) and the + `Spec.invert_ntt_layer_N_pure` equation. -/ + +/-! ### L3i.1 — Loop scaffolding for `invert_ntt_at_layer_1_portable_fc`. + + Mirror of §L3.1 scaffolding but with `zeta_i` DECREASING (4 per chunk) + and reads in reverse order (`zeta_i - 4k - {1,2,3,4}`). The Acc type + `(Std.Usize, PolynomialRingElement)` matches the impl's loop state. -/ + +namespace Layer1FC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Local `usize_sub_ok_eq` helper (mirror of `Layer1FC.usize_add_ok_eq` but for sub). -/ +theorem usize_sub_ok_eq (x y : Std.Usize) + (h_ge : y.val ≤ x.val) : + ∃ z : Std.Usize, (x - y : Result Std.Usize) = .ok z ∧ z.val = x.val - y.val := by + have hT := Std.Usize.sub_spec h_ge + obtain ⟨z, h_eq, h_v⟩ := Std.WP.spec_imp_exists hT + exact ⟨z, h_eq, h_v.1⟩ + +/-- Step-local accumulator. -/ +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC loop invariant for `invert_ntt_at_layer_1_portable_fc`. + Tracks DECREASING `zeta_i`: at outer iter `k`, `acc.1.val = zeta_i_0.val - 4 * k.val`. + Chunks `< k.val` are FC-equal to the per-chunk inverse step; chunks `≥ k.val` + are unchanged from `re`. Per-lane output bound `≤ 3328` on every chunk. -/ +def inv + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + acc.1.val = zeta_i_0.val - 4 * k.val + ∧ (∀ j : Nat, j < k.val → + lift_chunk (acc.2.coefficients.val[j]!) + = Spec.chunk_inv_ntt_layer_1_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val - 4 * j - 1)) + (Spec.zeta_at (zeta_i_0.val - 4 * j - 2)) + (Spec.zeta_at (zeta_i_0.val - 4 * j - 3)) + (Spec.zeta_at (zeta_i_0.val - 4 * j - 4))) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ c : Nat, c < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328)) + +/-- Step-post for `loop_range_spec_usize`. -/ +def step_post + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv zeta_i_0 re iter'.start acc').holds + | .done y => (inv zeta_i_0 re 16#usize y).holds + +end Layer1FC + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for the inverse layer-1 driver. Given a valid + loop state `(acc, k)` with `k.val < 16`, decreases `zeta_i` by 4 and records + the FC equation for chunk `k.val`, leaving chunks `> k.val` unchanged. -/ +theorem invert_ntt_at_layer_1_step_lemma_fc + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((re.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 13312) + (h_zeta_lo : 64 ≤ zeta_i_0.val) + (h_zeta_hi : zeta_i_0.val ≤ 128) + (acc : Layer1FC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (Layer1FC.inv zeta_i_0 re k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_1_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer1FC.step_post zeta_i_0 re k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.2.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_zeta_acc, h_acc_done, h_acc_undone, h_acc_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_1_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some round = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + have h_um : (1#usize : Std.Usize).val = 1 := rfl + have h_um2 : (2#usize : Std.Usize).val = 2 := rfl + have h_um3 : (3#usize : Std.Usize).val = 3 := rfl + -- acc.1.val = zeta_i_0.val - 4*k.val, with k.val ≤ 15 ⇒ acc.1.val ≥ zeta_i_0.val - 60 ≥ 4. + have h_acc1_ge_4 : 4 ≤ acc.1.val := by + rw [h_zeta_acc] + have h_k_le_15 : k.val ≤ 15 := by omega + omega + -- (1) `zeta_i - 1` ⇒ `zi1` with `zi1.val = acc.1.val - 1`. + have h_z_ge : (1#usize : Std.Usize).val ≤ acc.1.val := by rw [h_um]; omega + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := + Layer1FC.usize_sub_ok_eq acc.1 1#usize h_z_ge + have h_zi1_val_arith : zi1.val = acc.1.val - 1 := by rw [h_zi1_val, h_um] + -- zi1.val < 128: zi1.val = acc.1.val - 1 = zeta_i_0 - 4k - 1 ≤ zeta_i_0 - 1 ≤ 127. + have h_zi1_lt : zi1.val < 128 := by + rw [h_zi1_val_arith, h_zeta_acc]; omega + -- (2) `index_mut_usize re.coefficients k`. + have h_idx : + Aeneas.Std.Array.index_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.2.coefficients k (by rw [h_coef_len]; exact hk_16) + have h_imt_ok : + Aeneas.Std.Array.index_mut_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!, acc.2.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx]; rfl + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.2.coefficients.val[k.val]! with ht_def + -- (3) `polynomial.zeta zi1` ⇒ `z1` at index `zi1.val = acc.1.val - 1`. + obtain ⟨z1, h_z1_eq, h_z1_v, h_z1_bd, h_z1_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zi1 h_zi1_lt) + -- (4) `zi1 - 1` ⇒ `zi3.val = zi1.val - 1 = acc.1.val - 2`. + have h_zi3_ge : (1#usize : Std.Usize).val ≤ zi1.val := by + rw [h_um, h_zi1_val_arith]; omega + obtain ⟨zi3, h_zi3_eq, h_zi3_val⟩ := + Layer1FC.usize_sub_ok_eq zi1 1#usize h_zi3_ge + have h_zi3_val_arith : zi3.val = acc.1.val - 2 := by + rw [h_zi3_val, h_um, h_zi1_val_arith]; omega + have h_zi3_lt : zi3.val < 128 := by + rw [h_zi3_val_arith, h_zeta_acc]; omega + -- (5) `polynomial.zeta zi3` ⇒ `z2`. + obtain ⟨z2, h_z2_eq, h_z2_v, h_z2_bd, h_z2_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zi3 h_zi3_lt) + -- (6) `zi1 - 2` ⇒ `zi5.val = zi1.val - 2 = acc.1.val - 3`. + have h_zi5_ge : (2#usize : Std.Usize).val ≤ zi1.val := by + rw [h_um2, h_zi1_val_arith]; omega + obtain ⟨zi5, h_zi5_eq, h_zi5_val⟩ := + Layer1FC.usize_sub_ok_eq zi1 2#usize h_zi5_ge + have h_zi5_val_arith : zi5.val = acc.1.val - 3 := by + rw [h_zi5_val, h_um2, h_zi1_val_arith]; omega + have h_zi5_lt : zi5.val < 128 := by + rw [h_zi5_val_arith, h_zeta_acc]; omega + -- (7) `polynomial.zeta zi5` ⇒ `z3`. + obtain ⟨z3, h_z3_eq, h_z3_v, h_z3_bd, h_z3_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zi5 h_zi5_lt) + -- (8) `zi1 - 3` ⇒ `zi7.val = zi1.val - 3 = acc.1.val - 4`. + have h_zi7_ge : (3#usize : Std.Usize).val ≤ zi1.val := by + rw [h_um3, h_zi1_val_arith]; omega + obtain ⟨zi7, h_zi7_eq, h_zi7_val⟩ := + Layer1FC.usize_sub_ok_eq zi1 3#usize h_zi7_ge + have h_zi7_val_arith : zi7.val = acc.1.val - 4 := by + rw [h_zi7_val, h_um3, h_zi1_val_arith]; omega + have h_zi7_lt : zi7.val < 128 := by + rw [h_zi7_val_arith, h_zeta_acc]; omega + -- (9) `polynomial.zeta zi7` ⇒ `z4`. + obtain ⟨z4, h_z4_eq, h_z4_v, h_z4_bd, h_z4_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zi7 h_zi7_lt) + -- (10) `inv_ntt_layer_1_step t z1 z2 z3 z4`. Pre: t's lanes ≤ 13312 via h_pre + undone. + have h_t_eq : t = re.coefficients.val[k.val]! := by + show acc.2.coefficients.val[k.val]! = re.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 13312 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_pre k.val hk_16 ℓ hℓ + -- @[reducible] portable_ops_inst forwards to vector.portable.ntt.inv_ntt_layer_1_step. + obtain ⟨t1, h_t1_eq, h_t1_lift, h_t1_bnd⟩ := + triple_exists_ok_fc (inv_ntt_layer_1_step_fc t z1 z2 z3 z4 + ⟨h_z1_bd, h_z2_bd, h_z3_bd, h_z4_bd⟩ h_t_bd) + -- Compose entire body. Loop output for `cont` is `(iter', zi7, re')` (3-tuple in + -- the impl's loop-state, the Acc holds the latter two as a pair). + set acc' : Layer1FC.Acc := (zi7, { coefficients := acc.2.coefficients.set k t1 }) + with hacc'_def + have h_body : + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_1_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_1_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp [Aeneas.Std.bind_tc_ok, h_zi1_eq, h_imt_ok, h_z1_eq, h_zi3_eq, + h_z2_eq, h_zi5_eq, h_z3_eq, h_zi7_eq, h_z4_eq] + show (do + let t1' ← + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_1_step t z1 z2 z3 z4 + Result.ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + zi7, + ({ coefficients := acc.2.coefficients.set k t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_t1_eq]; rfl + apply triple_of_ok_fc h_body + show Layer1FC.step_post zeta_i_0 re k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold Layer1FC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (Layer1FC.inv zeta_i_0 re s acc').holds + have h_inv_pure : + acc'.1.val = zeta_i_0.val - 4 * s.val + ∧ (∀ j : Nat, j < s.val → + lift_chunk (acc'.2.coefficients.val[j]!) + = Spec.chunk_inv_ntt_layer_1_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val - 4 * j - 1)) + (Spec.zeta_at (zeta_i_0.val - 4 * j - 2)) + (Spec.zeta_at (zeta_i_0.val - 4 * j - 3)) + (Spec.zeta_at (zeta_i_0.val - 4 * j - 4))) + ∧ (∀ j : Nat, s.val ≤ j → j < 16 → + acc'.2.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ c : Nat, c < s.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc'.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + refine ⟨?_, ?_, ?_, ?_⟩ + · -- acc'.1 = zi7, zi7.val = acc.1.val - 4 = zeta_i_0.val - 4 * (k.val + 1). + show zi7.val = zeta_i_0.val - 4 * s.val + rw [h_zi7_val_arith, h_zeta_acc, hs_val] + have h_k_le_15 : k.val ≤ 15 := by omega + omega + · -- All j < s.val are FC-equal. + intro j hj + rw [hs_val] at hj + show lift_chunk ((acc.2.coefficients.set k t1).val[j]!) = _ + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: unchanged by set; use h_acc_done. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + rw [h_set_ne_val] + exact h_acc_done j hj_lt_k + · -- j = k.val: it's t1; use h_t1_lift + h_t_eq + zeta_lift identities. + subst hj_eq_k + have h_set_eq_val : + (acc.2.coefficients.set k t1).val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + rw [h_set_eq_val, h_t1_lift, h_t_eq] + -- Need: Spec.chunk_inv_ntt_layer_1_step_pure (lift_chunk re.coefficients[k]) + -- (lift_fe_mont z1..z4) = Spec.chunk_inv_ntt_layer_1_step_pure (...) + -- (Spec.zeta_at (zeta_i_0 - 4k - 1..4)). + have h_k_le_15 : k.val ≤ 15 := by omega + have h_zi1_z : zi1.val = zeta_i_0.val - 4 * k.val - 1 := by + rw [h_zi1_val_arith, h_zeta_acc] + have h_zi3_z : zi3.val = zeta_i_0.val - 4 * k.val - 2 := by + rw [h_zi3_val_arith, h_zeta_acc] + have h_zi5_z : zi5.val = zeta_i_0.val - 4 * k.val - 3 := by + rw [h_zi5_val_arith, h_zeta_acc] + have h_zi7_z : zi7.val = zeta_i_0.val - 4 * k.val - 4 := by + rw [h_zi7_val_arith, h_zeta_acc] + rw [show lift_fe_mont z1 = Spec.zeta_at (zeta_i_0.val - 4 * k.val - 1) + from by rw [← h_zi1_z]; exact h_z1_lift] + rw [show lift_fe_mont z2 = Spec.zeta_at (zeta_i_0.val - 4 * k.val - 2) + from by rw [← h_zi3_z]; exact h_z2_lift] + rw [show lift_fe_mont z3 = Spec.zeta_at (zeta_i_0.val - 4 * k.val - 3) + from by rw [← h_zi5_z]; exact h_z3_lift] + rw [show lift_fe_mont z4 = Spec.zeta_at (zeta_i_0.val - 4 * k.val - 4) + from by rw [← h_zi7_z]; exact h_z4_lift] + · -- All j ≥ s.val are unchanged. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + show (acc.2.coefficients.set k t1).val[j]! = re.coefficients.val[j]! + rw [h_set_ne_val] + exact h_acc_undone j h_ge' hj_lt + · -- Per-lane output bound on every chunk. + intro c hc ℓ hℓ + show ((acc'.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328 + show (((acc.2.coefficients.set k t1).val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328 + by_cases h_ck : c = k.val + · -- At touched chunk: value is t1, bounded by h_t1_bnd. + have h_set_eq_val : + (acc.2.coefficients.set k t1).val[c]! = t1 := by + rw [h_ck] + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + rw [h_set_eq_val]; exact h_t1_bnd ℓ hℓ + · -- At untouched chunk: value preserved from acc.2, bounded by h_acc_bnd. + have h_ne : k.val ≠ c := Ne.symm h_ck + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[c]! = acc.2.coefficients.val[c]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k c t1 h_ne + rw [h_set_ne_val]; exact h_acc_bnd c (by omega) ℓ hℓ + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_1_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (ControlFlow.done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_1_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_fc h_body + show Layer1FC.step_post zeta_i_0 re k (.done acc) + unfold Layer1FC.step_post + show (Layer1FC.inv zeta_i_0 re 16#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + acc.1.val = zeta_i_0.val - 4 * (16#usize : Std.Usize).val + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (acc.2.coefficients.val[j]!) + = Spec.chunk_inv_ntt_layer_1_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val - 4 * j - 1)) + (Spec.zeta_at (zeta_i_0.val - 4 * j - 2)) + (Spec.zeta_at (zeta_i_0.val - 4 * j - 3)) + (Spec.zeta_at (zeta_i_0.val - 4 * j - 4))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ c : Nat, c < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + refine ⟨?_, ?_, ?_, ?_⟩ + · rw [h_zeta_acc, hk_eq, h16] + · intro j hj; rw [h16] at hj + apply h_acc_done j; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + · intro c hc ℓ hℓ; exact h_acc_bnd c (by omega) ℓ hℓ + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +/-- L3i.1 — `invert_ntt_at_layer_1` driver: 16-chunk loop, per-chunk + 4-zeta-lookup decreasing `zeta_i` by 4. Post: `p.1.val = zeta_i.val - 64` + (the output zeta_i, for composition) and `lift_poly p.2 = + Spec.invert_ntt_layer_1_pure (lift_poly re) zeta_i`. + + Tightening preconditions (added by the proof author): + - `h_zeta_lo : 64 ≤ zeta_i.val` — needed for the 4 subtractions per chunk + to succeed without Nat underflow at every iter (worst case at iter k=15 + is `acc.1.val = zeta_i - 60`, and we further subtract 4). + - `h_zeta_hi : zeta_i.val ≤ 128` — so `polynomial.zeta` index (worst case + `zeta_i - 1` at iter k=0) is `< 128`. + - `h_bnd` per-lane bound `≤ 13312` on `re`'s chunks — matches + `inv_ntt_layer_1_step_fc`'s precondition. The loop invariant tracks + only PROCESSED chunks as ≤ 3328; the initial state at `k = 0` is + vacuously satisfied (no processed chunks yet). Each iteration's + `inv_ntt_layer_1_step_fc` establishes the bound at the touched chunk; + at k=16 all chunks are covered, giving the output bound ≤ 3328. -/ +@[spec high] +theorem invert_ntt_at_layer_1_portable_fc + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 13312) + (h_zeta_lo : 64 ≤ zeta_i.val) + (h_zeta_hi : zeta_i.val ≤ 128) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_1 + (vectortraitsOperationsInst := portable_ops_inst) zeta_i re + ⦃ ⇓ p => ⌜ p.1.val = zeta_i.val - 64 + ∧ lift_poly p.2 = Spec.invert_ntt_layer_1_pure (lift_poly re) zeta_i + ∧ (∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3328) ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_1 + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_1_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_1_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + iter1 acc1.1 acc1.2) + (β := Layer1FC.Acc) + (zeta_i, re) + 0#usize 16#usize + (Layer1FC.inv zeta_i re) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_, ?_⟩ + · -- zeta-thread invariant at k=0: zeta_i = zeta_i - 4*0. + show zeta_i.val = zeta_i.val - 4 * (0#usize : Std.Usize).val + show zeta_i.val = zeta_i.val - 4 * 0 + omega + · intro j hj + exact absurd hj (Nat.not_lt_zero j) + · intro _ _ _ + trivial + · -- Initial bound: vacuous at k=0 (no processed chunks yet). + intro c hc + exact absurd hc (Nat.not_lt_zero c)) + ?_) + · -- Post entailment: at k=16, the invariant gives all 16 FC equations + zeta_i = zeta_i_0 - 64. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (Layer1FC.inv zeta_i re 16#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + r.1.val = zeta_i.val - 4 * (16#usize : Std.Usize).val + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (r.2.coefficients.val[j]!) + = Spec.chunk_inv_ntt_layer_1_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i.val - 4 * j - 1)) + (Spec.zeta_at (zeta_i.val - 4 * j - 2)) + (Spec.zeta_at (zeta_i.val - 4 * j - 3)) + (Spec.zeta_at (zeta_i.val - 4 * j - 4))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + r.2.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ c : Nat, c < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((r.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + Layer1FC.inv] using h_inv_holds + obtain ⟨h_zeta_eq, h_done, _h_undone, h_done_bnd⟩ := h_inv + have h16 : (16#usize : Std.Usize).val = 16 := rfl + refine ⟨?_, ?_, ?_⟩ + · -- p.1.val = zeta_i.val - 64. + rw [h_zeta_eq, h16] + · -- The chunks equation. + unfold Spec.invert_ntt_layer_1_pure + set chunks_arr : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_inv_ntt_layer_1_step_pure (Spec.chunk_at (lift_poly re) k) + (Spec.zeta_at (zeta_i.val - 4 * k - 1)) + (Spec.zeta_at (zeta_i.val - 4 * k - 2)) + (Spec.zeta_at (zeta_i.val - 4 * k - 3)) + (Spec.zeta_at (zeta_i.val - 4 * k - 4)))) + (by simp) with hchunks_def + have h_chunks_len : chunks_arr.val.length = 16 := by + show ((List.range 16).map _).length = 16 + simp + have h_chunks_get : ∀ k : Nat, (hk : k < 16) → + chunks_arr.val[k]'(by rw [h_chunks_len]; exact hk) + = lift_chunk (r.2.coefficients.val[k]!) := by + intro k hk + show ((List.range 16).map (fun k => + Spec.chunk_inv_ntt_layer_1_step_pure (Spec.chunk_at (lift_poly re) k) + (Spec.zeta_at (zeta_i.val - 4 * k - 1)) + (Spec.zeta_at (zeta_i.val - 4 * k - 2)) + (Spec.zeta_at (zeta_i.val - 4 * k - 3)) + (Spec.zeta_at (zeta_i.val - 4 * k - 4))))[k]'_ = _ + rw [List.getElem_map, List.getElem_range] + rw [chunk_at_lift_poly_fc re k hk] + exact (h_done k hk).symm + have h_final := flatten_chunks_eq_lift_poly_fc r.2 chunks_arr h_chunks_len h_chunks_get + exact h_final.symm + · -- Per-lane output bound from the loop invariant's strengthened conjunct. + exact h_done_bnd + · -- Step lemma application: dispatch invert_ntt_at_layer_1_step_lemma_fc. + intro acc k _h_ge h_le hinv + have h_step := invert_ntt_at_layer_1_step_lemma_fc zeta_i re h_bnd h_zeta_lo h_zeta_hi + acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer1FC.step_post zeta_i re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer1FC.step_post] using hP + · have hP : Layer1FC.step_post zeta_i re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer1FC.step_post] using hP + +/-! ### L3i.2 — Loop scaffolding for `invert_ntt_at_layer_2_portable_fc`. + + Mirror of §L3i.1 scaffolding but with `zeta_i` DECREASING (2 per chunk) + and reads in reverse order (`zeta_i - 2k - {1,2}`). -/ + +namespace Layer2FC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Local `usize_sub_ok_eq` helper (mirror of `Layer1FC.usize_sub_ok_eq`). -/ +theorem usize_sub_ok_eq (x y : Std.Usize) + (h_ge : y.val ≤ x.val) : + ∃ z : Std.Usize, (x - y : Result Std.Usize) = .ok z ∧ z.val = x.val - y.val := by + have hT := Std.Usize.sub_spec h_ge + obtain ⟨z, h_eq, h_v⟩ := Std.WP.spec_imp_exists hT + exact ⟨z, h_eq, h_v.1⟩ + +/-- Step-local accumulator. -/ +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC loop invariant for `invert_ntt_at_layer_2_portable_fc`. + Tracks DECREASING `zeta_i`: at outer iter `k`, `acc.1.val = zeta_i_0.val - 2 * k.val`. + Chunks `< k.val` are FC-equal to the per-chunk inverse step; chunks `≥ k.val` + are unchanged from `re`. Per-lane output bound `≤ 3328` on every chunk. -/ +def inv + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + acc.1.val = zeta_i_0.val - 2 * k.val + ∧ (∀ j : Nat, j < k.val → + lift_chunk (acc.2.coefficients.val[j]!) + = Spec.chunk_inv_ntt_layer_2_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val - 2 * j - 1)) + (Spec.zeta_at (zeta_i_0.val - 2 * j - 2))) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ c : Nat, c < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328)) + +/-- Step-post for `loop_range_spec_usize`. -/ +def step_post + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv zeta_i_0 re iter'.start acc').holds + | .done y => (inv zeta_i_0 re 16#usize y).holds + +end Layer2FC + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for the inverse layer-2 driver. Given a valid + loop state `(acc, k)` with `k.val < 16`, decreases `zeta_i` by 2 and records + the FC equation for chunk `k.val`, leaving chunks `> k.val` unchanged. -/ +theorem invert_ntt_at_layer_2_step_lemma_fc + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((re.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 13312) + (h_zeta_lo : 32 ≤ zeta_i_0.val) + (h_zeta_hi : zeta_i_0.val ≤ 128) + (acc : Layer2FC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (Layer2FC.inv zeta_i_0 re k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_2_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer2FC.step_post zeta_i_0 re k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.2.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_zeta_acc, h_acc_done, h_acc_undone, h_acc_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_2_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some round = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + have h_um : (1#usize : Std.Usize).val = 1 := rfl + -- acc.1.val = zeta_i_0.val - 2*k.val, with k.val ≤ 15 ⇒ acc.1.val ≥ zeta_i_0.val - 30 ≥ 2. + have h_acc1_ge_2 : 2 ≤ acc.1.val := by + rw [h_zeta_acc] + have h_k_le_15 : k.val ≤ 15 := by omega + omega + -- (1) `zeta_i - 1` ⇒ `zi1` with `zi1.val = acc.1.val - 1`. + have h_z_ge : (1#usize : Std.Usize).val ≤ acc.1.val := by rw [h_um]; omega + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := + Layer2FC.usize_sub_ok_eq acc.1 1#usize h_z_ge + have h_zi1_val_arith : zi1.val = acc.1.val - 1 := by rw [h_zi1_val, h_um] + -- zi1.val < 128: zi1.val = acc.1.val - 1 = zeta_i_0 - 2k - 1 ≤ zeta_i_0 - 1 ≤ 127. + have h_zi1_lt : zi1.val < 128 := by + rw [h_zi1_val_arith, h_zeta_acc]; omega + -- (2) `index_mut_usize re.coefficients k`. + have h_idx : + Aeneas.Std.Array.index_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.2.coefficients k (by rw [h_coef_len]; exact hk_16) + have h_imt_ok : + Aeneas.Std.Array.index_mut_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!, acc.2.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx]; rfl + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.2.coefficients.val[k.val]! with ht_def + -- (3) `polynomial.zeta zi1` ⇒ `z1` at index `zi1.val = acc.1.val - 1`. + obtain ⟨z1, h_z1_eq, h_z1_v, h_z1_bd, h_z1_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zi1 h_zi1_lt) + -- (4) `zi1 - 1` ⇒ `zi3.val = zi1.val - 1 = acc.1.val - 2`. + have h_zi3_ge : (1#usize : Std.Usize).val ≤ zi1.val := by + rw [h_um, h_zi1_val_arith]; omega + obtain ⟨zi3, h_zi3_eq, h_zi3_val⟩ := + Layer2FC.usize_sub_ok_eq zi1 1#usize h_zi3_ge + have h_zi3_val_arith : zi3.val = acc.1.val - 2 := by + rw [h_zi3_val, h_um, h_zi1_val_arith]; omega + have h_zi3_lt : zi3.val < 128 := by + rw [h_zi3_val_arith, h_zeta_acc]; omega + -- (5) `polynomial.zeta zi3` ⇒ `z2`. + obtain ⟨z2, h_z2_eq, h_z2_v, h_z2_bd, h_z2_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zi3 h_zi3_lt) + -- (6) `inv_ntt_layer_2_step t z1 z2`. Pre: t's lanes ≤ 13312 via h_pre + undone. + have h_t_eq : t = re.coefficients.val[k.val]! := by + show acc.2.coefficients.val[k.val]! = re.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 13312 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_pre k.val hk_16 ℓ hℓ + -- @[reducible] portable_ops_inst forwards to vector.portable.ntt.inv_ntt_layer_2_step. + obtain ⟨t1, h_t1_eq, h_t1_lift, h_t1_bnd⟩ := + triple_exists_ok_fc (inv_ntt_layer_2_step_fc t z1 z2 + ⟨h_z1_bd, h_z2_bd⟩ h_t_bd) + -- Compose entire body. Loop output for `cont` is `(iter', zi3, re')`. + set acc' : Layer2FC.Acc := (zi3, { coefficients := acc.2.coefficients.set k t1 }) + with hacc'_def + have h_body : + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_2_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_2_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp [Aeneas.Std.bind_tc_ok, h_zi1_eq, h_imt_ok, h_z1_eq, h_zi3_eq, + h_z2_eq] + show (do + let t1' ← + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_2_step t z1 z2 + Result.ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + zi3, + ({ coefficients := acc.2.coefficients.set k t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_t1_eq]; rfl + apply triple_of_ok_fc h_body + show Layer2FC.step_post zeta_i_0 re k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold Layer2FC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (Layer2FC.inv zeta_i_0 re s acc').holds + have h_inv_pure : + acc'.1.val = zeta_i_0.val - 2 * s.val + ∧ (∀ j : Nat, j < s.val → + lift_chunk (acc'.2.coefficients.val[j]!) + = Spec.chunk_inv_ntt_layer_2_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val - 2 * j - 1)) + (Spec.zeta_at (zeta_i_0.val - 2 * j - 2))) + ∧ (∀ j : Nat, s.val ≤ j → j < 16 → + acc'.2.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ c : Nat, c < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((acc'.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + refine ⟨?_, ?_, ?_, ?_⟩ + · -- acc'.1 = zi3, zi3.val = acc.1.val - 2 = zeta_i_0.val - 2 * (k.val + 1). + show zi3.val = zeta_i_0.val - 2 * s.val + rw [h_zi3_val_arith, h_zeta_acc, hs_val] + have h_k_le_15 : k.val ≤ 15 := by omega + omega + · -- All j < s.val are FC-equal. + intro j hj + rw [hs_val] at hj + show lift_chunk ((acc.2.coefficients.set k t1).val[j]!) = _ + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: unchanged by set; use h_acc_done. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + rw [h_set_ne_val] + exact h_acc_done j hj_lt_k + · -- j = k.val: it's t1; use h_t1_lift + h_t_eq + zeta_lift identities. + subst hj_eq_k + have h_set_eq_val : + (acc.2.coefficients.set k t1).val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + rw [h_set_eq_val, h_t1_lift, h_t_eq] + have h_k_le_15 : k.val ≤ 15 := by omega + have h_zi1_z : zi1.val = zeta_i_0.val - 2 * k.val - 1 := by + rw [h_zi1_val_arith, h_zeta_acc] + have h_zi3_z : zi3.val = zeta_i_0.val - 2 * k.val - 2 := by + rw [h_zi3_val_arith, h_zeta_acc] + rw [show lift_fe_mont z1 = Spec.zeta_at (zeta_i_0.val - 2 * k.val - 1) + from by rw [← h_zi1_z]; exact h_z1_lift] + rw [show lift_fe_mont z2 = Spec.zeta_at (zeta_i_0.val - 2 * k.val - 2) + from by rw [← h_zi3_z]; exact h_z2_lift] + · -- All j ≥ s.val are unchanged. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + show (acc.2.coefficients.set k t1).val[j]! = re.coefficients.val[j]! + rw [h_set_ne_val] + exact h_acc_undone j h_ge' hj_lt + · -- Per-lane output bound on every chunk. + intro c hc ℓ hℓ + show ((acc'.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328 + show (((acc.2.coefficients.set k t1).val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328 + by_cases h_ck : c = k.val + · have h_set_eq_val : + (acc.2.coefficients.set k t1).val[c]! = t1 := by + rw [h_ck] + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + rw [h_set_eq_val]; exact h_t1_bnd ℓ hℓ + · have h_ne : k.val ≠ c := Ne.symm h_ck + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[c]! = acc.2.coefficients.val[c]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k c t1 h_ne + rw [h_set_ne_val]; exact h_acc_bnd c hc ℓ hℓ + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_2_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (ControlFlow.done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_2_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_fc h_body + show Layer2FC.step_post zeta_i_0 re k (.done acc) + unfold Layer2FC.step_post + show (Layer2FC.inv zeta_i_0 re 16#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + acc.1.val = zeta_i_0.val - 2 * (16#usize : Std.Usize).val + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (acc.2.coefficients.val[j]!) + = Spec.chunk_inv_ntt_layer_2_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val - 2 * j - 1)) + (Spec.zeta_at (zeta_i_0.val - 2 * j - 2))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ c : Nat, c < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + refine ⟨?_, ?_, ?_, ?_⟩ + · rw [h_zeta_acc, hk_eq, h16] + · intro j hj; rw [h16] at hj + apply h_acc_done j; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + · exact h_acc_bnd + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +/-- L3i.2 — `invert_ntt_at_layer_2` driver: 16-chunk loop, per-chunk + 2-zeta-lookup decreasing `zeta_i` by 2. Mirror of `L3i.1` and forward + `ntt_at_layer_2_portable_fc`. Locked POST exposes + `p.1.val = zeta_i.val - 32` (output zeta_i for composer chaining) + + `lift_poly p.2 = Spec.invert_ntt_layer_2_pure (lift_poly re) zeta_i`. + + Tightening preconditions (added by the proof author): + - `h_zeta_lo : 32 ≤ zeta_i.val` — needed for the 2 subtractions per chunk + to succeed without Nat underflow at every iter (worst case at iter k=15 + is `acc.1.val = zeta_i - 30`, and we further subtract 2). + - `h_zeta_hi : zeta_i.val ≤ 128` — so `polynomial.zeta` index (worst case + `zeta_i - 1` at iter k=0) is `< 128`. + - `h_bnd` per-lane bound `≤ 13312` on `re`'s chunks — matches + `inv_ntt_layer_2_step_fc`'s precondition. + - `h_bnd_strict` per-lane bound `≤ 3328` on `re`'s chunks — needed for + the strengthened POST's output-bound conjunct. -/ +@[spec high] +theorem invert_ntt_at_layer_2_portable_fc + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 13312) + (h_bnd_strict : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 3328) + (h_zeta_lo : 32 ≤ zeta_i.val) + (h_zeta_hi : zeta_i.val ≤ 128) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_2 + (vectortraitsOperationsInst := portable_ops_inst) zeta_i re + ⦃ ⇓ p => ⌜ p.1.val = zeta_i.val - 32 + ∧ lift_poly p.2 = Spec.invert_ntt_layer_2_pure (lift_poly re) zeta_i + ∧ (∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3328) ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_2 + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_2_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_2_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + iter1 acc1.1 acc1.2) + (β := Layer2FC.Acc) + (zeta_i, re) + 0#usize 16#usize + (Layer2FC.inv zeta_i re) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_, ?_⟩ + · -- zeta-thread invariant at k=0: zeta_i = zeta_i - 2*0. + show zeta_i.val = zeta_i.val - 2 * (0#usize : Std.Usize).val + show zeta_i.val = zeta_i.val - 2 * 0 + omega + · intro j hj + exact absurd hj (Nat.not_lt_zero j) + · intro _ _ _ + trivial + · -- Initial bound: acc.2 = re at k=0, so bound from h_bnd_strict. + intro c hc ℓ hℓ + exact h_bnd_strict c hc ℓ hℓ) + ?_) + · -- Post entailment: at k=16, the invariant gives all 16 FC equations + zeta_i = zeta_i_0 - 32. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (Layer2FC.inv zeta_i re 16#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + r.1.val = zeta_i.val - 2 * (16#usize : Std.Usize).val + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (r.2.coefficients.val[j]!) + = Spec.chunk_inv_ntt_layer_2_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i.val - 2 * j - 1)) + (Spec.zeta_at (zeta_i.val - 2 * j - 2))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + r.2.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ c : Nat, c < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((r.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + Layer2FC.inv] using h_inv_holds + obtain ⟨h_zeta_eq, h_done, _h_undone, h_done_bnd⟩ := h_inv + have h16 : (16#usize : Std.Usize).val = 16 := rfl + refine ⟨?_, ?_, ?_⟩ + · -- p.1.val = zeta_i.val - 32. + rw [h_zeta_eq, h16] + · -- The chunks equation. + unfold Spec.invert_ntt_layer_2_pure + set chunks_arr : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_inv_ntt_layer_2_step_pure (Spec.chunk_at (lift_poly re) k) + (Spec.zeta_at (zeta_i.val - 2 * k - 1)) + (Spec.zeta_at (zeta_i.val - 2 * k - 2)))) + (by simp) with hchunks_def + have h_chunks_len : chunks_arr.val.length = 16 := by + show ((List.range 16).map _).length = 16 + simp + have h_chunks_get : ∀ k : Nat, (hk : k < 16) → + chunks_arr.val[k]'(by rw [h_chunks_len]; exact hk) + = lift_chunk (r.2.coefficients.val[k]!) := by + intro k hk + show ((List.range 16).map (fun k => + Spec.chunk_inv_ntt_layer_2_step_pure (Spec.chunk_at (lift_poly re) k) + (Spec.zeta_at (zeta_i.val - 2 * k - 1)) + (Spec.zeta_at (zeta_i.val - 2 * k - 2))))[k]'_ = _ + rw [List.getElem_map, List.getElem_range] + rw [chunk_at_lift_poly_fc re k hk] + exact (h_done k hk).symm + have h_final := flatten_chunks_eq_lift_poly_fc r.2 chunks_arr h_chunks_len h_chunks_get + exact h_final.symm + · -- Per-lane output bound from the loop invariant's strengthened conjunct. + exact h_done_bnd + · -- Step lemma application: dispatch invert_ntt_at_layer_2_step_lemma_fc. + intro acc k _h_ge h_le hinv + have h_step := invert_ntt_at_layer_2_step_lemma_fc zeta_i re h_bnd h_zeta_lo h_zeta_hi + acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer2FC.step_post zeta_i re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer2FC.step_post] using hP + · have hP : Layer2FC.step_post zeta_i re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer2FC.step_post] using hP + +/-! ### L3i.3 — Loop scaffolding for `invert_ntt_at_layer_3_portable_fc`. + + Mirror of §L3i.2 scaffolding simplified to **1 zeta per chunk** (instead of 2) + with `zeta_i` DECREASING by 1 per iter and reads in reverse order at + `zeta_i - k - 1`. -/ + +namespace Layer3FC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Local `usize_sub_ok_eq` helper (mirror of `Layer2FC.usize_sub_ok_eq`). -/ +theorem usize_sub_ok_eq (x y : Std.Usize) + (h_ge : y.val ≤ x.val) : + ∃ z : Std.Usize, (x - y : Result Std.Usize) = .ok z ∧ z.val = x.val - y.val := by + have hT := Std.Usize.sub_spec h_ge + obtain ⟨z, h_eq, h_v⟩ := Std.WP.spec_imp_exists hT + exact ⟨z, h_eq, h_v.1⟩ + +/-- Step-local accumulator. -/ +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC loop invariant for `invert_ntt_at_layer_3_portable_fc`. + Tracks DECREASING `zeta_i`: at outer iter `k`, `acc.1.val = zeta_i_0.val - k.val`. + Chunks `< k.val` are FC-equal to the per-chunk inverse step; chunks `≥ k.val` + are unchanged from `re`. Per-lane output bound `≤ 3328` on every chunk. -/ +def inv + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + acc.1.val = zeta_i_0.val - k.val + ∧ (∀ j : Nat, j < k.val → + lift_chunk (acc.2.coefficients.val[j]!) + = Spec.chunk_inv_ntt_layer_3_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val - j - 1))) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ c : Nat, c < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328)) + +/-- Step-post for `loop_range_spec_usize`. -/ +def step_post + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv zeta_i_0 re iter'.start acc').holds + | .done y => (inv zeta_i_0 re 16#usize y).holds + +end Layer3FC + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for the inverse layer-3 driver. Given a valid + loop state `(acc, k)` with `k.val < 16`, decreases `zeta_i` by 1 and records + the FC equation for chunk `k.val`, leaving chunks `> k.val` unchanged. -/ +theorem invert_ntt_at_layer_3_step_lemma_fc + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((re.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 13312) + (h_zeta_lo : 16 ≤ zeta_i_0.val) + (h_zeta_hi : zeta_i_0.val ≤ 128) + (acc : Layer3FC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (Layer3FC.inv zeta_i_0 re k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_3_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer3FC.step_post zeta_i_0 re k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.2.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_zeta_acc, h_acc_done, h_acc_undone, h_acc_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_3_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some round = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + have h_um : (1#usize : Std.Usize).val = 1 := rfl + -- acc.1.val = zeta_i_0.val - k.val, with k.val ≤ 15 ⇒ acc.1.val ≥ zeta_i_0.val - 15 ≥ 1. + have h_acc1_ge_1 : 1 ≤ acc.1.val := by + rw [h_zeta_acc] + have h_k_le_15 : k.val ≤ 15 := by omega + omega + -- (1) `zeta_i - 1` ⇒ `zi1` with `zi1.val = acc.1.val - 1`. + have h_z_ge : (1#usize : Std.Usize).val ≤ acc.1.val := by rw [h_um]; omega + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := + Layer3FC.usize_sub_ok_eq acc.1 1#usize h_z_ge + have h_zi1_val_arith : zi1.val = acc.1.val - 1 := by rw [h_zi1_val, h_um] + -- zi1.val < 128: zi1.val = acc.1.val - 1 = zeta_i_0 - k - 1 ≤ zeta_i_0 - 1 ≤ 127. + have h_zi1_lt : zi1.val < 128 := by + rw [h_zi1_val_arith, h_zeta_acc]; omega + -- (2) `index_mut_usize re.coefficients k`. + have h_idx : + Aeneas.Std.Array.index_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.2.coefficients k (by rw [h_coef_len]; exact hk_16) + have h_imt_ok : + Aeneas.Std.Array.index_mut_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!, acc.2.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx]; rfl + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.2.coefficients.val[k.val]! with ht_def + -- (3) `polynomial.zeta zi1` ⇒ `z1` at index `zi1.val = acc.1.val - 1`. + obtain ⟨z1, h_z1_eq, h_z1_v, h_z1_bd, h_z1_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zi1 h_zi1_lt) + -- (4) `inv_ntt_layer_3_step t z1`. Pre: t's lanes ≤ 13312 via h_pre + undone. + have h_t_eq : t = re.coefficients.val[k.val]! := by + show acc.2.coefficients.val[k.val]! = re.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 13312 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_pre k.val hk_16 ℓ hℓ + -- @[reducible] portable_ops_inst forwards to vector.portable.ntt.inv_ntt_layer_3_step. + obtain ⟨t1, h_t1_eq, h_t1_lift, h_t1_bnd⟩ := + triple_exists_ok_fc (inv_ntt_layer_3_step_fc t z1 h_z1_bd h_t_bd) + -- Compose entire body. Loop output for `cont` is `(iter', zi1, re')`. + set acc' : Layer3FC.Acc := (zi1, { coefficients := acc.2.coefficients.set k t1 }) + with hacc'_def + have h_body : + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_3_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_3_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp [Aeneas.Std.bind_tc_ok, h_zi1_eq, h_imt_ok, h_z1_eq] + show (do + let t1' ← + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_3_step t z1 + Result.ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + zi1, + ({ coefficients := acc.2.coefficients.set k t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_t1_eq]; rfl + apply triple_of_ok_fc h_body + show Layer3FC.step_post zeta_i_0 re k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold Layer3FC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (Layer3FC.inv zeta_i_0 re s acc').holds + have h_inv_pure : + acc'.1.val = zeta_i_0.val - s.val + ∧ (∀ j : Nat, j < s.val → + lift_chunk (acc'.2.coefficients.val[j]!) + = Spec.chunk_inv_ntt_layer_3_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val - j - 1))) + ∧ (∀ j : Nat, s.val ≤ j → j < 16 → + acc'.2.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ c : Nat, c < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((acc'.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + refine ⟨?_, ?_, ?_, ?_⟩ + · -- acc'.1 = zi1, zi1.val = acc.1.val - 1 = zeta_i_0.val - (k.val + 1). + show zi1.val = zeta_i_0.val - s.val + rw [h_zi1_val_arith, h_zeta_acc, hs_val] + have h_k_le_15 : k.val ≤ 15 := by omega + omega + · -- All j < s.val are FC-equal. + intro j hj + rw [hs_val] at hj + show lift_chunk ((acc.2.coefficients.set k t1).val[j]!) = _ + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: unchanged by set; use h_acc_done. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + rw [h_set_ne_val] + exact h_acc_done j hj_lt_k + · -- j = k.val: it's t1; use h_t1_lift + h_t_eq + zeta_lift identity. + subst hj_eq_k + have h_set_eq_val : + (acc.2.coefficients.set k t1).val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + rw [h_set_eq_val, h_t1_lift, h_t_eq] + have h_k_le_15 : k.val ≤ 15 := by omega + have h_zi1_z : zi1.val = zeta_i_0.val - k.val - 1 := by + rw [h_zi1_val_arith, h_zeta_acc] + rw [show lift_fe_mont z1 = Spec.zeta_at (zeta_i_0.val - k.val - 1) + from by rw [← h_zi1_z]; exact h_z1_lift] + · -- All j ≥ s.val are unchanged. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + show (acc.2.coefficients.set k t1).val[j]! = re.coefficients.val[j]! + rw [h_set_ne_val] + exact h_acc_undone j h_ge' hj_lt + · -- Per-lane output bound on every chunk. + intro c hc ℓ hℓ + show ((acc'.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328 + show (((acc.2.coefficients.set k t1).val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328 + by_cases h_ck : c = k.val + · have h_set_eq_val : + (acc.2.coefficients.set k t1).val[c]! = t1 := by + rw [h_ck] + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + rw [h_set_eq_val]; exact h_t1_bnd ℓ hℓ + · have h_ne : k.val ≠ c := Ne.symm h_ck + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[c]! = acc.2.coefficients.val[c]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k c t1 h_ne + rw [h_set_ne_val]; exact h_acc_bnd c hc ℓ hℓ + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_3_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (ControlFlow.done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_3_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_fc h_body + show Layer3FC.step_post zeta_i_0 re k (.done acc) + unfold Layer3FC.step_post + show (Layer3FC.inv zeta_i_0 re 16#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + acc.1.val = zeta_i_0.val - (16#usize : Std.Usize).val + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (acc.2.coefficients.val[j]!) + = Spec.chunk_inv_ntt_layer_3_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val - j - 1))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ c : Nat, c < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + refine ⟨?_, ?_, ?_, ?_⟩ + · rw [h_zeta_acc, hk_eq, h16] + · intro j hj; rw [h16] at hj + apply h_acc_done j; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + · exact h_acc_bnd + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +/-- L3i.3 — `invert_ntt_at_layer_3` driver: 16-chunk loop, per-chunk + 1-zeta-lookup decreasing `zeta_i` by 1. Mirror of `L3i.2` and forward + `ntt_at_layer_3_portable_fc`. Locked POST exposes + `p.1.val = zeta_i.val - 16` (output zeta_i for composer chaining) + + `lift_poly p.2 = Spec.invert_ntt_layer_3_pure (lift_poly re) zeta_i`. + + Tightening preconditions (added by the proof author): + - `h_zeta_lo : 16 ≤ zeta_i.val` — needed for the 1 subtraction per chunk + to succeed without Nat underflow at every iter (worst case at iter k=15 + is `acc.1.val = zeta_i - 15`, and we further subtract 1). + - `h_zeta_hi : zeta_i.val ≤ 128` — so `polynomial.zeta` index (worst case + `zeta_i - 1` at iter k=0) is `< 128`. + - `h_bnd` per-lane bound `≤ 13312` on `re`'s chunks — matches + `inv_ntt_layer_3_step_fc`'s precondition. + - `h_bnd_strict` per-lane bound `≤ 3328` on `re`'s chunks — needed for + the strengthened POST's output-bound conjunct. -/ +@[spec high] +theorem invert_ntt_at_layer_3_portable_fc + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 13312) + (h_bnd_strict : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 3328) + (h_zeta_lo : 16 ≤ zeta_i.val) + (h_zeta_hi : zeta_i.val ≤ 128) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_3 + (vectortraitsOperationsInst := portable_ops_inst) zeta_i re + ⦃ ⇓ p => ⌜ p.1.val = zeta_i.val - 16 + ∧ lift_poly p.2 = Spec.invert_ntt_layer_3_pure (lift_poly re) zeta_i + ∧ (∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3328) ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_3 + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_3_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_3_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + iter1 acc1.1 acc1.2) + (β := Layer3FC.Acc) + (zeta_i, re) + 0#usize 16#usize + (Layer3FC.inv zeta_i re) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_, ?_⟩ + · -- zeta-thread invariant at k=0: zeta_i = zeta_i - 0. + show zeta_i.val = zeta_i.val - (0#usize : Std.Usize).val + show zeta_i.val = zeta_i.val - 0 + omega + · intro j hj + exact absurd hj (Nat.not_lt_zero j) + · intro _ _ _ + trivial + · -- Initial bound: acc.2 = re at k=0, so bound from h_bnd_strict. + intro c hc ℓ hℓ + exact h_bnd_strict c hc ℓ hℓ) + ?_) + · -- Post entailment: at k=16, the invariant gives all 16 FC equations + zeta_i = zeta_i_0 - 16. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (Layer3FC.inv zeta_i re 16#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + r.1.val = zeta_i.val - (16#usize : Std.Usize).val + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (r.2.coefficients.val[j]!) + = Spec.chunk_inv_ntt_layer_3_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i.val - j - 1))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + r.2.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ c : Nat, c < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((r.2.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + Layer3FC.inv] using h_inv_holds + obtain ⟨h_zeta_eq, h_done, _h_undone, h_done_bnd⟩ := h_inv + have h16 : (16#usize : Std.Usize).val = 16 := rfl + refine ⟨?_, ?_, ?_⟩ + · -- p.1.val = zeta_i.val - 16. + rw [h_zeta_eq, h16] + · -- The chunks equation. + unfold Spec.invert_ntt_layer_3_pure + set chunks_arr : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_inv_ntt_layer_3_step_pure (Spec.chunk_at (lift_poly re) k) + (Spec.zeta_at (zeta_i.val - k - 1)))) + (by simp) with hchunks_def + have h_chunks_len : chunks_arr.val.length = 16 := by + show ((List.range 16).map _).length = 16 + simp + have h_chunks_get : ∀ k : Nat, (hk : k < 16) → + chunks_arr.val[k]'(by rw [h_chunks_len]; exact hk) + = lift_chunk (r.2.coefficients.val[k]!) := by + intro k hk + show ((List.range 16).map (fun k => + Spec.chunk_inv_ntt_layer_3_step_pure (Spec.chunk_at (lift_poly re) k) + (Spec.zeta_at (zeta_i.val - k - 1))))[k]'_ = _ + rw [List.getElem_map, List.getElem_range] + rw [chunk_at_lift_poly_fc re k hk] + exact (h_done k hk).symm + have h_final := flatten_chunks_eq_lift_poly_fc r.2 chunks_arr h_chunks_len h_chunks_get + exact h_final.symm + · -- Per-lane output bound from the loop invariant's strengthened conjunct. + exact h_done_bnd + · -- Step lemma application: dispatch invert_ntt_at_layer_3_step_lemma_fc. + intro acc k _h_ge h_le hinv + have h_step := invert_ntt_at_layer_3_step_lemma_fc zeta_i re h_bnd h_zeta_lo h_zeta_hi + acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer3FC.step_post zeta_i re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer3FC.step_post] using hP + · have hP : Layer3FC.step_post zeta_i re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer3FC.step_post] using hP + +/-! ### L3i.4 — `inv_ntt_layer_int_vec_step_reduce` helper FC. + + Cross-chunk INVERSE NTT (Gentleman-Sande) butterfly between coefficient + chunks at positions `a` and `b` (where `b = a + step_vec`). Mirrors the + impl `invert_ntt.inv_ntt_layer_int_vec_step_reduce`: + + ``` + scratch1 := coefs[a]; t := coefs[b] + scratch3 := barrett(scratch1 + t) -- new coefs[a] = canonical(a + b) + coefs1[a] := scratch3 + scratch4 := −scratch3 + scratch7 := mont(scratch4 + 2*t) zeta_r -- new coefs[b] = (b − a) * z + coefs2[b] := scratch7 + ``` + + Used by the layer-4+ driver. The FC theorem exposes: + 1. lift_chunk equation on coefs[a] via `chunk_inv_pair_butterfly_a_pure`. + 2. lift_chunk equation on coefs[b] via `chunk_inv_pair_butterfly_b_pure`. + 3. Unchanged-chunk preservation for c ≠ a, c ≠ b. + 4. Output bound on both touched chunks (≤ 3328 since both go through + barrett/mont reduction). -/ +set_option maxHeartbeats 16000000 in +@[spec] +theorem inv_ntt_layer_int_vec_step_reduce_fc + (coefficients : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + 16#usize) + (a b : Std.Usize) (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta_r : Std.I16) + (h_a : a.val < 16) (h_b : b.val < 16) (h_ne : a.val ≠ b.val) + (hzeta : zeta_r.val.natAbs ≤ 1664) + (h_chunk_a : ∀ k : Nat, k < 16 → + ((coefficients.val[a.val]!).elements.val[k]!).val.natAbs ≤ 3328) + (h_chunk_b : ∀ k : Nat, k < 16 → + ((coefficients.val[b.val]!).elements.val[k]!).val.natAbs ≤ 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.invert_ntt.inv_ntt_layer_int_vec_step_reduce + portable_ops_inst coefficients a b scratch zeta_r + ⦃ ⇓ p => ⌜ lift_chunk (p.1.val[a.val]!) + = Spec.chunk_inv_pair_butterfly_a_pure + (lift_chunk (coefficients.val[a.val]!)) + (lift_chunk (coefficients.val[b.val]!)) + ∧ lift_chunk (p.1.val[b.val]!) + = Spec.chunk_inv_pair_butterfly_b_pure + (lift_chunk (coefficients.val[a.val]!)) + (lift_chunk (coefficients.val[b.val]!)) + (lift_fe_mont zeta_r) + ∧ (∀ c : Nat, c < 16 → c ≠ a.val → c ≠ b.val → + p.1.val[c]! = coefficients.val[c]!) + ∧ (∀ k : Nat, k < 16 → + ((p.1.val[a.val]!).elements.val[k]!).val.natAbs ≤ 3328) + ∧ (∀ k : Nat, k < 16 → + ((p.1.val[b.val]!).elements.val[k]!).val.natAbs ≤ 3328) ⌝ ⦄ := by + -- Setup: lengths. + have h_coef_len : coefficients.length = 16 := Std.Array.length_eq _ + -- Bind shorthand for the two source chunks. + set chunk_a : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + coefficients.val[a.val]! with hca_def + set chunk_b : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + coefficients.val[b.val]! with hcb_def + have h_chunk_a_len : chunk_a.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length chunk_a + have h_chunk_b_len : chunk_b.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length chunk_b + unfold libcrux_iot_ml_kem.invert_ntt.inv_ntt_layer_int_vec_step_reduce + -- (1) Read scratch1 = coefs[a]. + have h_idx_a : Aeneas.Std.Array.index_usize coefficients a = .ok chunk_a := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq coefficients a + (by rw [h_coef_len]; exact h_a) + -- (2) Read t = coefs[b]. + have h_idx_b : Aeneas.Std.Array.index_usize coefficients b = .ok chunk_b := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq coefficients b + (by rw [h_coef_len]; exact h_b) + -- (3) scratch2 = add(chunk_a, chunk_b). Pre: |a[ℓ] + b[ℓ]| ≤ 6656 < 32767 ✓. + have h_add_pre1 : ∀ ℓ : Nat, ℓ < 16 → + ((chunk_a.elements.val[ℓ]!).val + (chunk_b.elements.val[ℓ]!).val : Int).natAbs + ≤ 2^15 - 1 := by + intro ℓ hℓ + have hba := h_chunk_a ℓ hℓ + have hbb := h_chunk_b ℓ hℓ + have h_tri : ((chunk_a.elements.val[ℓ]!).val + (chunk_b.elements.val[ℓ]!).val : Int).natAbs + ≤ ((chunk_a.elements.val[ℓ]!).val : Int).natAbs + + ((chunk_b.elements.val[ℓ]!).val : Int).natAbs := + Int.natAbs_add_le _ _ + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2]; omega + obtain ⟨scratch2, h_s2_eq, _h_s2_lift⟩ := + triple_exists_ok_fc (add_fc chunk_a chunk_b h_add_pre1) + have h_s2_legacy := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.add_spec chunk_a chunk_b h_add_pre1 + obtain ⟨scratch2', h_s2_eq', h_s2_per⟩ := triple_exists_ok_fc h_s2_legacy + have h_s2_same : scratch2 = scratch2' := by + have := h_s2_eq.symm.trans h_s2_eq'; cases this; rfl + subst h_s2_same + have h_s2_val : ∀ ℓ : Nat, ℓ < 16 → + (scratch2.elements.val[ℓ]!).val + = (chunk_a.elements.val[ℓ]!).val + (chunk_b.elements.val[ℓ]!).val := by + intro ℓ hℓ; exact (h_s2_per ℓ hℓ).1 + have h_s2_bnd : ∀ ℓ : Nat, ℓ < 16 → + (scratch2.elements.val[ℓ]!).val.natAbs ≤ 6656 := by + intro ℓ hℓ + rw [h_s2_val ℓ hℓ] + have hba := h_chunk_a ℓ hℓ + have hbb := h_chunk_b ℓ hℓ + have h_tri : ((chunk_a.elements.val[ℓ]!).val + (chunk_b.elements.val[ℓ]!).val : Int).natAbs + ≤ ((chunk_a.elements.val[ℓ]!).val : Int).natAbs + + ((chunk_b.elements.val[ℓ]!).val : Int).natAbs := + Int.natAbs_add_le _ _ + omega + -- (4) scratch3 = barrett(scratch2). Pre: |scratch2[ℓ]| ≤ 32767 ✓. + have h_barrett_pre : ∀ ℓ : Nat, ℓ < 16 → + (scratch2.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ + have := h_s2_bnd ℓ hℓ; omega + obtain ⟨scratch3, h_s3_eq, h_s3_bnd, _h_s3_lift⟩ := + triple_exists_ok_fc (barrett_reduce_fc scratch2 h_barrett_pre) + have h_s3_legacy := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.barrett_reduce_spec scratch2 h_barrett_pre + obtain ⟨scratch3', h_s3_eq', h_s3_per⟩ := triple_exists_ok_fc h_s3_legacy + have h_s3_same : scratch3 = scratch3' := by + have := h_s3_eq.symm.trans h_s3_eq'; cases this; rfl + subst h_s3_same + have h_s3_modq : ∀ ℓ : Nat, ℓ < 16 → + libcrux_iot_ml_kem.Spec.ModularArith.modq_eq (scratch3.elements.val[ℓ]!).val + (scratch2.elements.val[ℓ]!).val 3329 := + fun ℓ hℓ => (h_s3_per ℓ hℓ).1 + -- (5) coefficients1 = coefficients.set a scratch3. + set c1 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + coefficients.set a scratch3 with hc1_def + have h_upd_a : Aeneas.Std.Array.update coefficients a scratch3 = .ok c1 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_update_ok_eq coefficients a scratch3 + (by rw [h_coef_len]; exact h_a) + have h_c1_len : c1.length = 16 := by simp [hc1_def, h_coef_len] + -- (6) scratch4 = negate(scratch3). Pre: |scratch3[ℓ]| ≤ 3328 ≤ 2^15-1 ✓. + have h_neg_pre : ∀ ℓ : Nat, ℓ < 16 → + (scratch3.elements.val[ℓ]!).val.natAbs ≤ 2^15 - 1 := by + intro ℓ hℓ + have hb := h_s3_bnd ℓ hℓ + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2]; omega + obtain ⟨scratch4, h_s4_eq, _h_s4_lift⟩ := + triple_exists_ok_fc (negate_fc scratch3 h_neg_pre) + have h_s4_legacy := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.negate_spec scratch3 + obtain ⟨scratch4', h_s4_eq', h_s4_per⟩ := triple_exists_ok_fc h_s4_legacy + have h_s4_same : scratch4 = scratch4' := by + have := h_s4_eq.symm.trans h_s4_eq'; cases this; rfl + subst h_s4_same + -- Convert per-lane BV equality to value equality via the same dance as `negate_fc`. + have h_s4_val : ∀ ℓ : Nat, ℓ < 16 → + (scratch4.elements.val[ℓ]!).val = -(scratch3.elements.val[ℓ]!).val := by + intro ℓ hℓ + set xi : Std.I16 := scratch3.elements.val[ℓ]! with hxi + set ri : Std.I16 := scratch4.elements.val[ℓ]! with hri + have h_bv : ri.bv = -xi.bv := h_s4_per ℓ hℓ + have h_wsub_bv : + (Aeneas.Std.I16.wrapping_sub (0#i16) xi).bv = -xi.bv := by + rw [Aeneas.Std.I16.wrapping_sub_bv_eq] + simp only [show (0#i16 : Std.I16).bv = (0 : BitVec 16) from rfl] + exact BitVec.zero_sub xi.bv + have h_step1 : ri.val = (Aeneas.Std.I16.wrapping_sub (0#i16) xi).val := by + have h_toInt : (ri.bv).toInt + = (Aeneas.Std.I16.wrapping_sub (0#i16) xi).bv.toInt := by + rw [h_bv, h_wsub_bv] + have h_lhs : (ri.bv).toInt = ri.val := Aeneas.Std.I16.bv_toInt_eq ri + have h_rhs : (Aeneas.Std.I16.wrapping_sub (0#i16) xi).bv.toInt + = (Aeneas.Std.I16.wrapping_sub (0#i16) xi).val := + Aeneas.Std.I16.bv_toInt_eq _ + rw [h_lhs, h_rhs] at h_toInt + exact h_toInt + rw [h_step1, Aeneas.Std.I16.wrapping_sub_val_eq] + have h0 : (0#i16 : Std.I16).val = 0 := by decide + rw [h0] + have h_diff : (0 : Int) - xi.val = -xi.val := by ring + rw [h_diff] + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_abs : xi.val.natAbs ≤ 2^15 - 1 := h_neg_pre ℓ hℓ + have h_pow : -((2 : Int) ^ (16 - 1)) = -(2^15 : Int) := by decide + rw [h_pow]; omega + · have h_abs : xi.val.natAbs ≤ 2^15 - 1 := h_neg_pre ℓ hℓ + have h_pow : ((2 : Int) ^ (16 - 1)) = (2^15 : Int) := by decide + rw [h_pow]; omega + have h_s4_bnd : ∀ ℓ : Nat, ℓ < 16 → + (scratch4.elements.val[ℓ]!).val.natAbs ≤ 3328 := by + intro ℓ hℓ + rw [h_s4_val ℓ hℓ, Int.natAbs_neg]; exact h_s3_bnd ℓ hℓ + -- (7) t1 = c1[b] (= chunk_b since a ≠ b). + have h_c1_b : c1.val[b.val]! = chunk_b := by + show (coefficients.set a scratch3).val[b.val]! = chunk_b + have h_ne_ab : a.val ≠ b.val := h_ne + have h_step : (coefficients.set a scratch3).val[b.val]! = coefficients.val[b.val]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne coefficients a b.val scratch3 h_ne_ab + rw [h_step] + have h_idx_b1 : Aeneas.Std.Array.index_usize c1 b = .ok chunk_b := by + have h_idx : Aeneas.Std.Array.index_usize c1 b = .ok (c1.val[b.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq c1 b (by rw [h_c1_len]; exact h_b) + rw [h_idx, h_c1_b] + -- (8) scratch5 = add(scratch4, chunk_b). |scratch4| ≤ 3328, |chunk_b| ≤ 3328, sum ≤ 6656 ✓. + have h_add_pre2 : ∀ ℓ : Nat, ℓ < 16 → + ((scratch4.elements.val[ℓ]!).val + (chunk_b.elements.val[ℓ]!).val : Int).natAbs + ≤ 2^15 - 1 := by + intro ℓ hℓ + have hb4 := h_s4_bnd ℓ hℓ + have hbb := h_chunk_b ℓ hℓ + have h_tri : ((scratch4.elements.val[ℓ]!).val + (chunk_b.elements.val[ℓ]!).val : Int).natAbs + ≤ ((scratch4.elements.val[ℓ]!).val : Int).natAbs + + ((chunk_b.elements.val[ℓ]!).val : Int).natAbs := + Int.natAbs_add_le _ _ + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2]; omega + obtain ⟨scratch5, h_s5_eq, _h_s5_lift⟩ := + triple_exists_ok_fc (add_fc scratch4 chunk_b h_add_pre2) + have h_s5_legacy := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.add_spec scratch4 chunk_b h_add_pre2 + obtain ⟨scratch5', h_s5_eq', h_s5_per⟩ := triple_exists_ok_fc h_s5_legacy + have h_s5_same : scratch5 = scratch5' := by + have := h_s5_eq.symm.trans h_s5_eq'; cases this; rfl + subst h_s5_same + have h_s5_val : ∀ ℓ : Nat, ℓ < 16 → + (scratch5.elements.val[ℓ]!).val + = (scratch4.elements.val[ℓ]!).val + (chunk_b.elements.val[ℓ]!).val := by + intro ℓ hℓ; exact (h_s5_per ℓ hℓ).1 + have h_s5_bnd : ∀ ℓ : Nat, ℓ < 16 → + (scratch5.elements.val[ℓ]!).val.natAbs ≤ 6656 := by + intro ℓ hℓ + rw [h_s5_val ℓ hℓ] + have hb4 := h_s4_bnd ℓ hℓ + have hbb := h_chunk_b ℓ hℓ + have h_tri : ((scratch4.elements.val[ℓ]!).val + (chunk_b.elements.val[ℓ]!).val : Int).natAbs + ≤ ((scratch4.elements.val[ℓ]!).val : Int).natAbs + + ((chunk_b.elements.val[ℓ]!).val : Int).natAbs := + Int.natAbs_add_le _ _ + omega + -- (9) scratch6 = add(scratch5, chunk_b). |scratch5| ≤ 6656, |chunk_b| ≤ 3328, sum ≤ 9984 ✓. + have h_add_pre3 : ∀ ℓ : Nat, ℓ < 16 → + ((scratch5.elements.val[ℓ]!).val + (chunk_b.elements.val[ℓ]!).val : Int).natAbs + ≤ 2^15 - 1 := by + intro ℓ hℓ + have hb5 := h_s5_bnd ℓ hℓ + have hbb := h_chunk_b ℓ hℓ + have h_tri : ((scratch5.elements.val[ℓ]!).val + (chunk_b.elements.val[ℓ]!).val : Int).natAbs + ≤ ((scratch5.elements.val[ℓ]!).val : Int).natAbs + + ((chunk_b.elements.val[ℓ]!).val : Int).natAbs := + Int.natAbs_add_le _ _ + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2]; omega + obtain ⟨scratch6, h_s6_eq, _h_s6_lift⟩ := + triple_exists_ok_fc (add_fc scratch5 chunk_b h_add_pre3) + have h_s6_legacy := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.add_spec scratch5 chunk_b h_add_pre3 + obtain ⟨scratch6', h_s6_eq', h_s6_per⟩ := triple_exists_ok_fc h_s6_legacy + have h_s6_same : scratch6 = scratch6' := by + have := h_s6_eq.symm.trans h_s6_eq'; cases this; rfl + subst h_s6_same + have h_s6_val : ∀ ℓ : Nat, ℓ < 16 → + (scratch6.elements.val[ℓ]!).val + = (scratch5.elements.val[ℓ]!).val + (chunk_b.elements.val[ℓ]!).val := by + intro ℓ hℓ; exact (h_s6_per ℓ hℓ).1 + have h_s6_bnd : ∀ ℓ : Nat, ℓ < 16 → + (scratch6.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ + rw [h_s6_val ℓ hℓ] + have hb5 := h_s5_bnd ℓ hℓ + have hbb := h_chunk_b ℓ hℓ + have h_tri : ((scratch5.elements.val[ℓ]!).val + (chunk_b.elements.val[ℓ]!).val : Int).natAbs + ≤ ((scratch5.elements.val[ℓ]!).val : Int).natAbs + + ((chunk_b.elements.val[ℓ]!).val : Int).natAbs := + Int.natAbs_add_le _ _ + omega + -- (10) classify zeta_r = zeta_r (Public->Secret blanket identity on I16). + have h_classify_zeta : + libcrux_secrets.traits.Classify.Blanket.classify zeta_r = .ok zeta_r := + ntt_step_fc.classify_ok_eq zeta_r + -- (11) scratch7 = mont_mult_by_const(scratch6, zeta_r). Pre: |scratch6| ≤ 32767, |zeta_r| ≤ 1664. + obtain ⟨scratch7, h_s7_eq, _h_s7_lift⟩ := + triple_exists_ok_fc (montgomery_multiply_by_constant_fc scratch6 zeta_r h_s6_bnd hzeta) + have h_s7_legacy := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.montgomery_multiply_by_constant_spec scratch6 zeta_r hzeta + obtain ⟨scratch7', h_s7_eq', h_s7_per⟩ := triple_exists_ok_fc h_s7_legacy + have h_s7_same : scratch7 = scratch7' := by + have := h_s7_eq.symm.trans h_s7_eq'; cases this; rfl + subst h_s7_same + have h_s7_modq : ∀ ℓ : Nat, ℓ < 16 → + ((scratch7.elements.val[ℓ]!).val * (2 ^ 16 : Int)) % 3329 + = ((scratch6.elements.val[ℓ]!).val * zeta_r.val) % 3329 := + fun ℓ hℓ => (h_s7_per ℓ hℓ).2 + -- (12) coefficients2 = c1.set b scratch7. + set c2 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + c1.set b scratch7 with hc2_def + have h_upd_b : Aeneas.Std.Array.update c1 b scratch7 = .ok c2 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_update_ok_eq c1 b scratch7 + (by rw [h_c1_len]; exact h_b) + -- Compose the body equation. + have h_body : + libcrux_iot_ml_kem.invert_ntt.inv_ntt_layer_int_vec_step_reduce + portable_ops_inst coefficients a b scratch zeta_r + = .ok (c2, scratch7) := by + show (do + let scratch1' ← Aeneas.Std.Array.index_usize coefficients a + let t' ← Aeneas.Std.Array.index_usize coefficients b + let scratch2' ← portable_ops_inst.add scratch1' t' + let scratch3' ← portable_ops_inst.barrett_reduce scratch2' + let coefficients1' ← Aeneas.Std.Array.update coefficients a scratch3' + let scratch4' ← portable_ops_inst.negate scratch3' + let t1' ← Aeneas.Std.Array.index_usize coefficients1' b + let scratch5' ← portable_ops_inst.add scratch4' t1' + let scratch6' ← portable_ops_inst.add scratch5' t1' + let scratch7' ← + libcrux_iot_ml_kem.vector.traits.montgomery_multiply_fe + portable_ops_inst scratch6' zeta_r + let coefficients2' ← Aeneas.Std.Array.update coefficients1' b scratch7' + .ok (coefficients2', scratch7')) = _ + -- Trait method calls reduce to vector.portable.arithmetic.* via reducibility. + show (do + let scratch1' ← Aeneas.Std.Array.index_usize coefficients a + let t' ← Aeneas.Std.Array.index_usize coefficients b + let scratch2' ← libcrux_iot_ml_kem.vector.portable.arithmetic.add scratch1' t' + let scratch3' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce scratch2' + let coefficients1' ← Aeneas.Std.Array.update coefficients a scratch3' + let scratch4' ← libcrux_iot_ml_kem.vector.portable.arithmetic.negate scratch3' + let t1' ← Aeneas.Std.Array.index_usize coefficients1' b + let scratch5' ← libcrux_iot_ml_kem.vector.portable.arithmetic.add scratch4' t1' + let scratch6' ← libcrux_iot_ml_kem.vector.portable.arithmetic.add scratch5' t1' + let scratch7' ← do + let i ← libcrux_secrets.traits.Classify.Blanket.classify zeta_r + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_by_constant + scratch6' i + let coefficients2' ← Aeneas.Std.Array.update coefficients1' b scratch7' + .ok (coefficients2', scratch7')) = _ + rw [h_idx_a]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_b]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_s2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_s3_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_upd_a]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_s4_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_b1]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_s5_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_s6_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_classify_zeta]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_s7_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_upd_b] + simp only [Aeneas.Std.bind_tc_ok] + apply triple_of_ok_fc h_body + -- Now prove the 5-conjunct post. + refine ⟨?_, ?_, ?_, ?_, ?_⟩ + · -- (a) lift_chunk c2[a] = chunk_inv_pair_butterfly_a_pure (lift_chunk chunk_a) (lift_chunk chunk_b). + -- c2 = c1.set b scratch7; at index a, since a ≠ b, c2[a] = c1[a] = scratch3. + show lift_chunk (c2.val[a.val]!) = _ + have h_ne_ba : b.val ≠ a.val := fun h => h_ne h.symm + have h_c2_a : c2.val[a.val]! = scratch3 := by + show (c1.set b scratch7).val[a.val]! = scratch3 + have h_step1 : (c1.set b scratch7).val[a.val]! = c1.val[a.val]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne c1 b a.val scratch7 h_ne_ba + have h_step2 : c1.val[a.val]! = scratch3 := by + show (coefficients.set a scratch3).val[a.val]! = scratch3 + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq coefficients a a.val scratch3 + ⟨rfl, by rw [h_coef_len]; exact h_a⟩ + rw [h_step1, h_step2] + rw [h_c2_a] + -- Goal: lift_chunk scratch3 = chunk_inv_pair_butterfly_a_pure (lift_chunk chunk_a) (lift_chunk chunk_b). + -- Per-lane: lift_fe scratch3[ℓ] = add_pure (lift_fe chunk_a[ℓ]) (lift_fe chunk_b[ℓ]). + -- We have h_s3_modq : modq_eq scratch3[ℓ].val scratch2[ℓ].val 3329, and + -- h_s2_val : scratch2[ℓ].val = chunk_a[ℓ].val + chunk_b[ℓ].val. + have h_s3_lane_modq : ∀ ℓ : Nat, ℓ < 16 → + libcrux_iot_ml_kem.Spec.ModularArith.modq_eq (scratch3.elements.val[ℓ]!).val + ((chunk_a.elements.val[ℓ]!).val + (chunk_b.elements.val[ℓ]!).val) 3329 := by + intro ℓ hℓ + have h_m := h_s3_modq ℓ hℓ + have h_v := h_s2_val ℓ hℓ + unfold libcrux_iot_ml_kem.Spec.ModularArith.modq_eq at h_m ⊢ + rw [← h_v]; exact h_m + -- Now unfold and prove lane-by-lane. + unfold lift_chunk Spec.chunk_inv_pair_butterfly_a_pure + apply Subtype.ext + have h_s3_len : scratch3.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length scratch3 + show scratch3.elements.val.map lift_fe + = (List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Std.Array.make 16#usize (chunk_a.elements.val.map lift_fe) + (by simp)).val[ℓ]!) + ((Std.Array.make 16#usize (chunk_b.elements.val.map lift_fe) + (by simp)).val[ℓ]!)) + apply List.ext_getElem + · simp [List.length_map, List.length_range, h_s3_len] + · intro ℓ hℓ1 _ + have hℓ : ℓ < 16 := by + have : ℓ < (scratch3.elements.val.map lift_fe).length := hℓ1 + simp [List.length_map, h_s3_len] at this; exact this + rw [List.getElem_map, List.getElem_map, List.getElem_range] + -- LHS: lift_fe scratch3.elements.val[ℓ] + -- RHS: add_pure (chunk_a.lift)[ℓ]! (chunk_b.lift)[ℓ]! + have h_s3_get : scratch3.elements.val[ℓ] = scratch3.elements.val[ℓ]! := by + have hi : ℓ < scratch3.elements.val.length := by rw [h_s3_len]; exact hℓ + rw [getElem!_pos scratch3.elements.val ℓ hi] + rw [h_s3_get] + have h_lift_a_idx : + (Std.Array.make 16#usize (chunk_a.elements.val.map lift_fe) + (by simp)).val[ℓ]! = lift_fe (chunk_a.elements.val[ℓ]!) := by + show (chunk_a.elements.val.map lift_fe)[ℓ]! = _ + have hL : (chunk_a.elements.val.map lift_fe).length = 16 := by + simp [List.length_map, h_chunk_a_len] + rw [getElem!_pos _ ℓ (by rw [hL]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos chunk_a.elements.val ℓ (by rw [h_chunk_a_len]; exact hℓ)] + have h_lift_b_idx : + (Std.Array.make 16#usize (chunk_b.elements.val.map lift_fe) + (by simp)).val[ℓ]! = lift_fe (chunk_b.elements.val[ℓ]!) := by + show (chunk_b.elements.val.map lift_fe)[ℓ]! = _ + have hL : (chunk_b.elements.val.map lift_fe).length = 16 := by + simp [List.length_map, h_chunk_b_len] + rw [getElem!_pos _ ℓ (by rw [hL]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos chunk_b.elements.val ℓ (by rw [h_chunk_b_len]; exact hℓ)] + rw [h_lift_a_idx, h_lift_b_idx] + -- Goal: lift_fe scratch3.elements.val[ℓ]! + -- = add_pure (lift_fe chunk_a.elements.val[ℓ]!) (lift_fe chunk_b.elements.val[ℓ]!). + -- We have h_s3_lane_modq ℓ hℓ : modq_eq scratch3[ℓ].val (a[ℓ].val + b[ℓ].val) 3329. + -- Manufacture a synthetic i16 r_a := wrapping_add chunk_a[ℓ] chunk_b[ℓ]; + -- since |a| + |b| ≤ 6656 ≤ 29439 + 3328, r_a.val = a.val + b.val (no overflow). + -- Then lift_fe scratch3[ℓ] = lift_fe r_a (via modq), and + -- lift_fe r_a = add_pure (lift_fe a[ℓ]) (lift_fe b[ℓ]) via lift_fe_add_pure_eq. + set xa : Std.I16 := chunk_a.elements.val[ℓ]! with hxa_def + set xb : Std.I16 := chunk_b.elements.val[ℓ]! with hxb_def + set ra : Std.I16 := Std.I16.wrapping_add xa xb with hra_def + have h_xa_bnd : xa.val.natAbs ≤ 3328 := h_chunk_a ℓ hℓ + have h_xb_bnd : xb.val.natAbs ≤ 3328 := h_chunk_b ℓ hℓ + have h_ra_val : ra.val = xa.val + xb.val := + ntt_step_fc.add_no_overflow_value xa xb 3328 h_xa_bnd h_xb_bnd (by decide) + have h_lift_ra : lift_fe ra + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe xa) (lift_fe xb) := + lift_fe_add_pure_eq xa xb ra h_ra_val + -- From h_s3_lane_modq combined with h_ra_val: modq_eq scratch3.val ra.val 3329. + have h_s3_ra_modq : + libcrux_iot_ml_kem.Spec.ModularArith.modq_eq (scratch3.elements.val[ℓ]!).val ra.val 3329 := by + have h_m := h_s3_lane_modq ℓ hℓ + unfold libcrux_iot_ml_kem.Spec.ModularArith.modq_eq at h_m ⊢ + rw [h_ra_val]; exact h_m + have h_lift_eq : lift_fe (scratch3.elements.val[ℓ]!) = lift_fe ra := + lift_fe_eq_of_modq _ _ h_s3_ra_modq + rw [h_lift_eq, h_lift_ra] + · -- (b) lift_chunk c2[b] = chunk_inv_pair_butterfly_b_pure (lift_chunk chunk_a) (lift_chunk chunk_b) (lift_fe_mont zeta_r). + -- c2[b] = (c1.set b scratch7)[b] = scratch7. + show lift_chunk (c2.val[b.val]!) = _ + have h_c2_b : c2.val[b.val]! = scratch7 := by + show (c1.set b scratch7).val[b.val]! = scratch7 + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq c1 b b.val scratch7 + ⟨rfl, by rw [h_c1_len]; exact h_b⟩ + rw [h_c2_b] + -- Goal: lift_chunk scratch7 = chunk_inv_pair_butterfly_b_pure ... + -- Per-lane: lift_fe_mont scratch7[ℓ]? NO -- chunk_inv_pair_butterfly_b_pure produces + -- `mul_pure (sub_pure b[ℓ] a[ℓ]) z` which is a PLAIN-domain FE, not a Mont FE. + -- So we need `lift_fe scratch7[ℓ] = mul_pure (sub_pure (lift_fe b) (lift_fe a)) (lift_fe_mont z)`. + -- The modq fact: scratch7[ℓ].val * 2^16 ≡ scratch6[ℓ].val * zeta_r.val (mod q). + -- We need to chain: scratch6[ℓ].val = (b[ℓ].val - a[ℓ].val) (mod q) [from the s3,s4,s5,s6 chain]. + -- Step (i): derive scratch6[ℓ].val ≡ b[ℓ].val - a[ℓ].val (mod q). + have h_s6_lane_modq : ∀ ℓ : Nat, ℓ < 16 → + libcrux_iot_ml_kem.Spec.ModularArith.modq_eq (scratch6.elements.val[ℓ]!).val + ((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val) 3329 := by + intro ℓ hℓ + -- scratch6[ℓ].val = scratch5[ℓ].val + chunk_b[ℓ].val + -- = (scratch4[ℓ].val + chunk_b[ℓ].val) + chunk_b[ℓ].val + -- = (-scratch3[ℓ].val + chunk_b[ℓ].val) + chunk_b[ℓ].val + -- = -scratch3[ℓ].val + 2*chunk_b[ℓ].val + -- scratch3[ℓ].val ≡ chunk_a[ℓ].val + chunk_b[ℓ].val (mod q) + -- so scratch6[ℓ].val ≡ -(a + b) + 2b = b - a (mod q). + have h_v6 := h_s6_val ℓ hℓ + have h_v5 := h_s5_val ℓ hℓ + have h_v4 := h_s4_val ℓ hℓ + have h_v3 := h_s3_modq ℓ hℓ + have h_v2 := h_s2_val ℓ hℓ + -- Combine: + have h_chain : (scratch6.elements.val[ℓ]!).val + = -(scratch3.elements.val[ℓ]!).val + 2 * (chunk_b.elements.val[ℓ]!).val := by + rw [h_v6, h_v5, h_v4]; ring + -- Now modq: -scratch3 ≡ -(a+b), 2b - (a+b) = b - a (mod q). + unfold libcrux_iot_ml_kem.Spec.ModularArith.modq_eq at h_v3 ⊢ + -- h_v3 : (scratch3.val - scratch2.val) % 3329 = 0. + -- Goal: (scratch6.val - (b.val - a.val)) % 3329 = 0. + -- scratch6.val - (b.val - a.val) = -scratch3.val + 2b - b + a + -- = -scratch3.val + a + b + -- = -(scratch3.val - (a + b)) + -- = -(scratch3.val - scratch2.val) [using h_v2] + have h_eq : (scratch6.elements.val[ℓ]!).val + - ((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val) + = -((scratch3.elements.val[ℓ]!).val - (scratch2.elements.val[ℓ]!).val) := by + rw [h_chain, h_v2]; ring + rw [h_eq, Int.neg_emod] + omega + -- Step (ii): derive scratch7[ℓ].val ≡ scratch6[ℓ].val * zeta_r.val * 169 (mod q), + -- using 2^16 * 169 ≡ 1 (mod q). + have h_s7_lane_modq_pre : ∀ ℓ : Nat, ℓ < 16 → + libcrux_iot_ml_kem.Spec.ModularArith.modq_eq (scratch7.elements.val[ℓ]!).val + ((scratch6.elements.val[ℓ]!).val * zeta_r.val * 169) 3329 := by + intro ℓ hℓ + have h_per := h_s7_modq ℓ hℓ + unfold libcrux_iot_ml_kem.Spec.ModularArith.modq_eq + have h_169 : ((2^16 : Int) * 169) % 3329 = 1 := by decide + have h_rmul : ((scratch7.elements.val[ℓ]!).val * (2^16 : Int) * 169) % 3329 + = ((scratch6.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 := by + have h1 : ((scratch7.elements.val[ℓ]!).val * (2^16 : Int) * 169) % 3329 + = ((scratch7.elements.val[ℓ]!).val * (2^16 : Int)) % 3329 * 169 % 3329 := by + rw [Int.mul_emod]; simp + have h2 : ((scratch6.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 + = ((scratch6.elements.val[ℓ]!).val * zeta_r.val) % 3329 * 169 % 3329 := by + rw [Int.mul_emod]; simp + rw [h1, h2, h_per] + have h_lhs : ((scratch7.elements.val[ℓ]!).val * (2^16 : Int) * 169) % 3329 + = (scratch7.elements.val[ℓ]!).val % 3329 := by + have h_mul_assoc : (scratch7.elements.val[ℓ]!).val * (2^16 : Int) * 169 + = (scratch7.elements.val[ℓ]!).val * ((2^16 : Int) * 169) := by ring + rw [h_mul_assoc, Int.mul_emod, h_169]; simp + have h_zsub : + ((scratch7.elements.val[ℓ]!).val + - (scratch6.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 = 0 := by + have h_sub_emod : ((scratch7.elements.val[ℓ]!).val + - (scratch6.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 + = ((scratch7.elements.val[ℓ]!).val % 3329 + - ((scratch6.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329) % 3329 := by + rw [Int.sub_emod] + rw [h_sub_emod, ← h_lhs, h_rmul]; simp + exact h_zsub + -- Step (iii): combine: scratch7[ℓ].val ≡ (b - a) * zeta_r * 169 (mod q), + -- which is what `lift_fe_mul_pure_mont_eq` needs (the first arg is `chunk_b - chunk_a`, + -- in the role of the "a" of mul_pure_mont). + -- But the existing `lift_fe_mul_pure_mont_eq` takes a single `a : Std.I16` whose .val + -- is multiplied with `c`. We don't have a single i16 carrying `b - a`. We need an + -- intermediate bridge. Construct the desired equation manually. + -- The b-side bridge: from modq_eq scratch7.val ((b - a) * zeta_r * 169) 3329, + -- show lift_fe scratch7[ℓ] = mul_pure (sub_pure (lift_fe b[ℓ]) (lift_fe a[ℓ])) (lift_fe_mont zeta_r). + have h_s7_lane_modq : ∀ ℓ : Nat, ℓ < 16 → + libcrux_iot_ml_kem.Spec.ModularArith.modq_eq (scratch7.elements.val[ℓ]!).val + (((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val) + * zeta_r.val * 169) 3329 := by + intro ℓ hℓ + have hpre := h_s7_lane_modq_pre ℓ hℓ + have h6 := h_s6_lane_modq ℓ hℓ + -- Compose: scratch7 ≡ scratch6 * z * 169 ≡ (b-a) * z * 169 (mod q). + unfold libcrux_iot_ml_kem.Spec.ModularArith.modq_eq at hpre h6 ⊢ + -- h6 : (scratch6 - (b - a)) % 3329 = 0. + -- We want: (scratch7 - (b - a) * z * 169) % 3329 = 0. + -- We have hpre : (scratch7 - scratch6 * z * 169) % 3329 = 0. + -- And scratch6 % 3329 = (b - a) % 3329 (from h6). + -- So scratch6 * z * 169 ≡ (b - a) * z * 169 (mod q). + have h_scratch6_zmod : + (scratch6.elements.val[ℓ]!).val % 3329 + = ((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val) % 3329 := by + have h_dvd : (3329 : Int) ∣ ((scratch6.elements.val[ℓ]!).val + - ((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val)) := + Int.dvd_of_emod_eq_zero h6 + have h_sub : (scratch6.elements.val[ℓ]!).val + - ((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val) + = (scratch6.elements.val[ℓ]!).val + - ((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val) := rfl + omega + have h_mul_zmod : + ((scratch6.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 + = (((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val) + * zeta_r.val * 169) % 3329 := by + have h1 : ((scratch6.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 + = ((scratch6.elements.val[ℓ]!).val % 3329) * (zeta_r.val * 169 % 3329) % 3329 := by + conv_lhs => rw [show (scratch6.elements.val[ℓ]!).val * zeta_r.val * 169 + = (scratch6.elements.val[ℓ]!).val * (zeta_r.val * 169) from by ring] + rw [Int.mul_emod] + have h2 : (((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val) + * zeta_r.val * 169) % 3329 + = (((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val) % 3329) + * (zeta_r.val * 169 % 3329) % 3329 := by + conv_lhs => rw [show ((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val) + * zeta_r.val * 169 + = ((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val) + * (zeta_r.val * 169) from by ring] + rw [Int.mul_emod] + rw [h1, h2, h_scratch6_zmod] + -- Now combine: scratch7 - (b - a)*z*169 ≡ scratch7 - scratch6*z*169 (mod q). + have h_link : + ((scratch7.elements.val[ℓ]!).val + - ((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val) + * zeta_r.val * 169) % 3329 + = ((scratch7.elements.val[ℓ]!).val + - (scratch6.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 := by + have h_sub1 : ((scratch7.elements.val[ℓ]!).val + - ((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val) + * zeta_r.val * 169) % 3329 + = ((scratch7.elements.val[ℓ]!).val % 3329 + - (((chunk_b.elements.val[ℓ]!).val - (chunk_a.elements.val[ℓ]!).val) + * zeta_r.val * 169) % 3329) % 3329 := by rw [Int.sub_emod] + have h_sub2 : ((scratch7.elements.val[ℓ]!).val + - (scratch6.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 + = ((scratch7.elements.val[ℓ]!).val % 3329 + - ((scratch6.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329) % 3329 := by + rw [Int.sub_emod] + rw [h_sub1, h_sub2, h_mul_zmod] + rw [h_link]; exact hpre + -- Now reduce the chunk goal to per-lane. + unfold lift_chunk Spec.chunk_inv_pair_butterfly_b_pure + apply Subtype.ext + have h_s7_len : scratch7.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length scratch7 + show scratch7.elements.val.map lift_fe + = (List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + ((Std.Array.make 16#usize (chunk_b.elements.val.map lift_fe) + (by simp)).val[ℓ]!) + ((Std.Array.make 16#usize (chunk_a.elements.val.map lift_fe) + (by simp)).val[ℓ]!)) + (lift_fe_mont zeta_r)) + apply List.ext_getElem + · simp [List.length_map, List.length_range, h_s7_len] + · intro ℓ hℓ1 _ + have hℓ : ℓ < 16 := by + have : ℓ < (scratch7.elements.val.map lift_fe).length := hℓ1 + simp [List.length_map, h_s7_len] at this; exact this + rw [List.getElem_map, List.getElem_map, List.getElem_range] + have h_s7_get : scratch7.elements.val[ℓ] = scratch7.elements.val[ℓ]! := by + have hi : ℓ < scratch7.elements.val.length := by rw [h_s7_len]; exact hℓ + rw [getElem!_pos scratch7.elements.val ℓ hi] + rw [h_s7_get] + have h_lift_a_idx : + (Std.Array.make 16#usize (chunk_a.elements.val.map lift_fe) + (by simp)).val[ℓ]! = lift_fe (chunk_a.elements.val[ℓ]!) := by + show (chunk_a.elements.val.map lift_fe)[ℓ]! = _ + have hL : (chunk_a.elements.val.map lift_fe).length = 16 := by + simp [List.length_map, h_chunk_a_len] + rw [getElem!_pos _ ℓ (by rw [hL]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos chunk_a.elements.val ℓ (by rw [h_chunk_a_len]; exact hℓ)] + have h_lift_b_idx : + (Std.Array.make 16#usize (chunk_b.elements.val.map lift_fe) + (by simp)).val[ℓ]! = lift_fe (chunk_b.elements.val[ℓ]!) := by + show (chunk_b.elements.val.map lift_fe)[ℓ]! = _ + have hL : (chunk_b.elements.val.map lift_fe).length = 16 := by + simp [List.length_map, h_chunk_b_len] + rw [getElem!_pos _ ℓ (by rw [hL]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos chunk_b.elements.val ℓ (by rw [h_chunk_b_len]; exact hℓ)] + rw [h_lift_a_idx, h_lift_b_idx] + -- Goal: lift_fe scratch7[ℓ]! + -- = mul_pure (sub_pure (lift_fe b[ℓ]) (lift_fe a[ℓ])) (lift_fe_mont zeta_r). + -- Manufacture rb := wrapping_sub chunk_b[ℓ] chunk_a[ℓ]; since |b| + |a| ≤ 6656, + -- rb.val = b.val - a.val (no overflow). Then: + -- lift_fe rb = sub_pure (lift_fe b) (lift_fe a) via lift_fe_sub_pure_eq. + -- The shape of `lift_fe_mul_pure_mont_eq` matches once we substitute rb for the LHS + -- of the multiplication. + set xa : Std.I16 := chunk_a.elements.val[ℓ]! with hxa_def + set xb : Std.I16 := chunk_b.elements.val[ℓ]! with hxb_def + set rb : Std.I16 := Std.I16.wrapping_sub xb xa with hrb_def + have h_xa_bnd : xa.val.natAbs ≤ 3328 := h_chunk_a ℓ hℓ + have h_xb_bnd : xb.val.natAbs ≤ 3328 := h_chunk_b ℓ hℓ + have h_rb_val : rb.val = xb.val - xa.val := + ntt_step_fc.sub_no_overflow_value xb xa 3328 h_xb_bnd h_xa_bnd (by decide) + -- lift_fe rb = sub_pure (lift_fe xb) (lift_fe xa). + have h_lift_rb : lift_fe rb + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (lift_fe xb) (lift_fe xa) := + lift_fe_sub_pure_eq xb xa rb h_rb_val + -- Build the modq fact in terms of rb: scratch7.val ≡ rb.val * zeta_r.val * 169 (mod q). + have h_s7_rb_modq : + libcrux_iot_ml_kem.Spec.ModularArith.modq_eq (scratch7.elements.val[ℓ]!).val + (rb.val * zeta_r.val * 169) 3329 := by + have h_m := h_s7_lane_modq ℓ hℓ + unfold libcrux_iot_ml_kem.Spec.ModularArith.modq_eq at h_m ⊢ + rw [h_rb_val]; exact h_m + -- Apply the existing bridge `lift_fe_mul_pure_mont_eq` with first arg rb. + have h_lift_s7 : lift_fe (scratch7.elements.val[ℓ]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe rb) (lift_fe_mont zeta_r) := + lift_fe_mul_pure_mont_eq rb zeta_r (scratch7.elements.val[ℓ]!) h_s7_rb_modq + rw [h_lift_s7, h_lift_rb] + · -- (c) Preservation: for c ≠ a, c ≠ b: c2.val[c]! = coefficients.val[c]!. + intro c hc hca hcb + show c2.val[c]! = coefficients.val[c]! + have h_step1 : c2.val[c]! = c1.val[c]! := by + show (c1.set b scratch7).val[c]! = c1.val[c]! + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne c1 b c scratch7 (fun h => hcb h.symm) + have h_step2 : c1.val[c]! = coefficients.val[c]! := by + show (coefficients.set a scratch3).val[c]! = coefficients.val[c]! + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne coefficients a c scratch3 (fun h => hca h.symm) + rw [h_step1, h_step2] + · -- (d) Chunk a bound: c2[a] = scratch3 (Barrett-reduced) → ≤ 3328. + intro k hk + show ((c2.val[a.val]!).elements.val[k]!).val.natAbs ≤ 3328 + have h_ne_ba : b.val ≠ a.val := fun h => h_ne h.symm + have h_c2_a : c2.val[a.val]! = scratch3 := by + show (c1.set b scratch7).val[a.val]! = scratch3 + have h_step1 : (c1.set b scratch7).val[a.val]! = c1.val[a.val]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne c1 b a.val scratch7 h_ne_ba + have h_step2 : c1.val[a.val]! = scratch3 := by + show (coefficients.set a scratch3).val[a.val]! = scratch3 + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq coefficients a a.val scratch3 + ⟨rfl, by rw [h_coef_len]; exact h_a⟩ + rw [h_step1, h_step2] + rw [h_c2_a]; exact h_s3_bnd k hk + · -- (e) Chunk b bound: c2[b] = scratch7 (Mont-multiplied by zeta ≤ 1664) → ≤ 3328. + intro k hk + show ((c2.val[b.val]!).elements.val[k]!).val.natAbs ≤ 3328 + have h_c2_b : c2.val[b.val]! = scratch7 := by + show (c1.set b scratch7).val[b.val]! = scratch7 + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq c1 b b.val scratch7 + ⟨rfl, by rw [h_c1_len]; exact h_b⟩ + rw [h_c2_b]; exact (h_s7_per k hk).1 + +/-! ### L3i.5 — `invert_ntt_at_layer_4_plus` driver (Task H.1). + + Nested-loop driver for the inverse-NTT cross-chunk butterflies at + layers 4-7. Mirror of forward `ntt_at_layer_4_plus_portable_fc` with: + - INVERSE butterfly direction (uses `chunk_inv_pair_butterfly_{a,b}_pure` + instead of forward's `chunk_pair_butterfly_{a,b}_pure`). + - zeta_i DECREMENTS (zeta_fn group := `Spec.zeta_at (zeta_i - 1 - group)`). + - Dispatches to closed `inv_ntt_layer_int_vec_step_reduce_fc` + (Task H.0) for each inner butterfly. -/ + +namespace Layer4PlusInnerFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Inner loop accumulator: (re, scratch). -/ +abbrev Acc := + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector × + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- Inverse inner loop invariant (mirror of forward `Layer4PlusInnerFC.inv` + but with inverse butterflies and no `z` on a-side). -/ +def inv + (re0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (a_offset b_offset : Std.Usize) + (zeta : hacspec_ml_kem.parameters.FieldElement) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + (∀ j' : Nat, j' < k.val → + lift_chunk (acc.1.coefficients.val[a_offset.val + j']!) + = Spec.chunk_inv_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!))) + ∧ (∀ j' : Nat, j' < k.val → + lift_chunk (acc.1.coefficients.val[b_offset.val + j']!) + = Spec.chunk_inv_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + zeta) + ∧ (∀ k' : Nat, k' < 16 → + (∀ j' : Nat, j' < k.val → k' ≠ a_offset.val + j' ∧ k' ≠ b_offset.val + j') → + acc.1.coefficients.val[k']! = re0.coefficients.val[k']!) + ∧ (∀ k' : Nat, k' < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[k']!).elements.val[ℓ]!).val.natAbs ≤ 3328)) + +/-- Step-post for the inverse inner loop. -/ +def step_post + (re0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (a_offset b_offset step_vec : Std.Usize) + (zeta : hacspec_ml_kem.parameters.FieldElement) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < step_vec.val ∧ iter'.«end» = step_vec + ∧ iter'.start.val = k.val + 1 + ∧ (inv re0 a_offset b_offset zeta iter'.start acc').holds + | .done y => (inv re0 a_offset b_offset zeta step_vec y).holds + +end Layer4PlusInnerFC + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for the inverse inner loop. -/ +theorem invert_ntt_at_layer_4_plus_inner_step_lemma_fc + (re0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (a_offset b_offset step_vec : Std.Usize) (zeta_i1 : Std.Usize) + (h_zi1_lt : zeta_i1.val < 128) + (h_step_vec_pos : 1 ≤ step_vec.val) + (h_a_offset_b : a_offset.val + step_vec.val ≤ 16) + (h_b_offset_b : b_offset.val + step_vec.val ≤ 16) + (h_disjoint : a_offset.val + step_vec.val ≤ b_offset.val) + (h_pre_a : ∀ j : Nat, j < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((re0.coefficients.val[a_offset.val + j]!).elements.val[ℓ]!).val.natAbs ≤ 3328) + (h_pre_b : ∀ j : Nat, j < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((re0.coefficients.val[b_offset.val + j]!).elements.val[ℓ]!).val.natAbs ≤ 3328) + (acc : Layer4PlusInnerFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ step_vec.val) + (h_inv : (Layer4PlusInnerFC.inv re0 a_offset b_offset + (Spec.zeta_at zeta_i1.val) k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) + zeta_i1 a_offset b_offset + { start := k, «end» := step_vec } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer4PlusInnerFC.step_post re0 a_offset b_offset step_vec + (Spec.zeta_at zeta_i1.val) k r ⌝ ⦄ := by + have h_coef_len : acc.1.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_acc_a, h_acc_b, h_acc_undone, h_acc_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0_loop0.body + by_cases h_lt : k.val < step_vec.val + · -- Some j = k branch. + obtain ⟨s, hs_val, h_iter_some⟩ := + Layer4PlusFC.iter_next_some_eq_gen k step_vec h_lt + -- (1) i ← a_offset + k. + have h_a_max : a_offset.val + k.val ≤ Std.Usize.max := by + have h_ab_b : a_offset.val + k.val ≤ 16 := by omega + scalar_tac + obtain ⟨i, h_i_eq, h_i_val⟩ := + Layer4PlusFC.usize_add_ok_eq a_offset k h_a_max + -- (2) i1 ← b_offset + k. + have h_b_max : b_offset.val + k.val ≤ Std.Usize.max := by + have h_bb_b : b_offset.val + k.val ≤ 16 := by omega + scalar_tac + obtain ⟨i1, h_i1_eq, h_i1_val⟩ := + Layer4PlusFC.usize_add_ok_eq b_offset k h_b_max + -- (3) zeta lookup. + obtain ⟨z, h_z_eq, h_z_v, h_z_bd, h_z_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zeta_i1 h_zi1_lt) + have h_i_lt_16 : i.val < 16 := by rw [h_i_val]; omega + have h_i1_lt_16 : i1.val < 16 := by rw [h_i1_val]; omega + have h_i_ne_i1 : i.val ≠ i1.val := by + rw [h_i_val, h_i1_val] + have : a_offset.val + k.val < b_offset.val + k.val := by omega + omega + -- Bounds at i and i1 via h_acc_undone. + have h_acc_i_undone : acc.1.coefficients.val[i.val]! = re0.coefficients.val[i.val]! := by + apply h_acc_undone i.val h_i_lt_16 + intro j' hj' + refine ⟨?_, ?_⟩ + · rw [h_i_val]; omega + · rw [h_i_val]; omega + have h_acc_i1_undone : acc.1.coefficients.val[i1.val]! = re0.coefficients.val[i1.val]! := by + apply h_acc_undone i1.val h_i1_lt_16 + intro j' hj' + refine ⟨?_, ?_⟩ + · rw [h_i1_val]; omega + · rw [h_i1_val]; omega + -- Per-lane bounds at acc.1.coefs[i] and [i1] via the bound conjunct in h_inv. + have h_acc_at_i_bnd : ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[i.val]!).elements.val[ℓ]!).val.natAbs ≤ 3328 := + fun ℓ hℓ => h_acc_bnd i.val h_i_lt_16 ℓ hℓ + have h_acc_at_i1_bnd : ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[i1.val]!).elements.val[ℓ]!).val.natAbs ≤ 3328 := + fun ℓ hℓ => h_acc_bnd i1.val h_i1_lt_16 ℓ hℓ + have h_zeta_bnd : z.val.natAbs ≤ 1664 := h_z_bd + -- Apply the H.0 keystone. + obtain ⟨r_pair, h_r_eq, h_r_a, h_r_b, h_r_undone, h_r_bnd_a, h_r_bnd_b⟩ := + triple_exists_ok_fc (inv_ntt_layer_int_vec_step_reduce_fc + acc.1.coefficients i i1 acc.2 z h_i_lt_16 h_i1_lt_16 h_i_ne_i1 + h_zeta_bnd h_acc_at_i_bnd h_acc_at_i1_bnd) + -- Build the new accumulator. + set acc' : Layer4PlusInnerFC.Acc := + (({ coefficients := r_pair.1 } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector), + r_pair.2) with hacc'_def + -- Compose body. + have h_body : + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) + zeta_i1 a_offset b_offset + { start := k, «end» := step_vec } acc.1 acc.2 + = .ok (ControlFlow.cont (({ start := s, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := step_vec } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [Aeneas.Std.bind_tc_ok] + show (do + let i ← a_offset + k + let i1 ← b_offset + k + let i2 ← libcrux_iot_ml_kem.polynomial.zeta zeta_i1 + let (a, scratch1) ← + libcrux_iot_ml_kem.invert_ntt.inv_ntt_layer_int_vec_step_reduce portable_ops_inst + acc.1.coefficients i i1 acc.2 i2 + Result.ok (ControlFlow.cont (({ start := s, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize), + ({ coefficients := a } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector), + scratch1))) = _ + rw [h_i_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_i1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_z_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_r_eq]; rfl + apply triple_of_ok_fc h_body + show Layer4PlusInnerFC.step_post re0 a_offset b_offset step_vec + (Spec.zeta_at zeta_i1.val) k + (.cont (({ start := s, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold Layer4PlusInnerFC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (Layer4PlusInnerFC.inv re0 a_offset b_offset + (Spec.zeta_at zeta_i1.val) s acc').holds + have h_inv_pure : + (∀ j' : Nat, j' < s.val → + lift_chunk (acc'.1.coefficients.val[a_offset.val + j']!) + = Spec.chunk_inv_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!))) + ∧ (∀ j' : Nat, j' < s.val → + lift_chunk (acc'.1.coefficients.val[b_offset.val + j']!) + = Spec.chunk_inv_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + (Spec.zeta_at zeta_i1.val)) + ∧ (∀ k' : Nat, k' < 16 → + (∀ j' : Nat, j' < s.val → k' ≠ a_offset.val + j' ∧ k' ≠ b_offset.val + j') → + acc'.1.coefficients.val[k']! = re0.coefficients.val[k']!) + ∧ (∀ k' : Nat, k' < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((acc'.1.coefficients.val[k']!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + refine ⟨?_, ?_, ?_, ?_⟩ + · -- (a) a-side butterfly for j' < s.val. + intro j' hj' + rw [hs_val] at hj' + rcases Nat.lt_succ_iff_lt_or_eq.mp hj' with hj'_lt | hj'_eq + · have h_ne_i : a_offset.val + j' ≠ i.val := by rw [h_i_val]; omega + have h_ne_i1 : a_offset.val + j' ≠ i1.val := by rw [h_i1_val]; omega + have h_pos : a_offset.val + j' < 16 := by omega + have h_unchanged : r_pair.1.val[a_offset.val + j']! + = acc.1.coefficients.val[a_offset.val + j']! := + h_r_undone (a_offset.val + j') h_pos h_ne_i h_ne_i1 + show lift_chunk (acc'.1.coefficients.val[a_offset.val + j']!) = _ + show lift_chunk (r_pair.1.val[a_offset.val + j']!) = _ + rw [h_unchanged] + exact h_acc_a j' hj'_lt + · subst hj'_eq + show lift_chunk (acc'.1.coefficients.val[a_offset.val + k.val]!) = _ + show lift_chunk (r_pair.1.val[a_offset.val + k.val]!) = _ + have h_eq_i : a_offset.val + k.val = i.val := by rw [h_i_val] + rw [h_eq_i] + rw [h_r_a] + rw [h_acc_i_undone, h_acc_i1_undone] + rw [h_i_val, h_i1_val] + · -- (b) b-side butterfly for j' < s.val. + intro j' hj' + rw [hs_val] at hj' + rcases Nat.lt_succ_iff_lt_or_eq.mp hj' with hj'_lt | hj'_eq + · have h_ne_i : b_offset.val + j' ≠ i.val := by rw [h_i_val]; omega + have h_ne_i1 : b_offset.val + j' ≠ i1.val := by rw [h_i1_val]; omega + have h_pos : b_offset.val + j' < 16 := by omega + have h_unchanged : r_pair.1.val[b_offset.val + j']! + = acc.1.coefficients.val[b_offset.val + j']! := + h_r_undone (b_offset.val + j') h_pos h_ne_i h_ne_i1 + show lift_chunk (acc'.1.coefficients.val[b_offset.val + j']!) = _ + show lift_chunk (r_pair.1.val[b_offset.val + j']!) = _ + rw [h_unchanged] + exact h_acc_b j' hj'_lt + · subst hj'_eq + show lift_chunk (acc'.1.coefficients.val[b_offset.val + k.val]!) = _ + show lift_chunk (r_pair.1.val[b_offset.val + k.val]!) = _ + have h_eq_i1 : b_offset.val + k.val = i1.val := by rw [h_i1_val] + rw [h_eq_i1] + rw [h_r_b] + rw [h_acc_i_undone, h_acc_i1_undone] + rw [h_i_val, h_i1_val] + rw [h_z_lift] + · -- (c) Other positions unchanged from re0. + intro k' hk' h_not_touched + have hk'_ne_i : k' ≠ i.val := by + have h_at_k : k.val < s.val := by rw [hs_val]; omega + have := (h_not_touched k.val h_at_k).1 + rw [h_i_val]; exact this + have hk'_ne_i1 : k' ≠ i1.val := by + have h_at_k : k.val < s.val := by rw [hs_val]; omega + have := (h_not_touched k.val h_at_k).2 + rw [h_i1_val]; exact this + show acc'.1.coefficients.val[k']! = re0.coefficients.val[k']! + show r_pair.1.val[k']! = re0.coefficients.val[k']! + have h_unchanged := h_r_undone k' hk' hk'_ne_i hk'_ne_i1 + rw [h_unchanged] + apply h_acc_undone k' hk' + intro j' hj' + have h_at_j' : j' < s.val := by rw [hs_val]; omega + exact h_not_touched j' h_at_j' + · -- (d) Per-lane output bound at every chunk. + intro k' hk' ℓ hℓ + show ((acc'.1.coefficients.val[k']!).elements.val[ℓ]!).val.natAbs ≤ 3328 + show ((r_pair.1.val[k']!).elements.val[ℓ]!).val.natAbs ≤ 3328 + by_cases h_ki : k' = i.val + · rw [h_ki]; exact h_r_bnd_a ℓ hℓ + · by_cases h_ki1 : k' = i1.val + · rw [h_ki1]; exact h_r_bnd_b ℓ hℓ + · have h_unchanged : r_pair.1.val[k']! = acc.1.coefficients.val[k']! := + h_r_undone k' hk' h_ki h_ki1 + rw [h_unchanged] + exact h_acc_bnd k' hk' ℓ hℓ + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- None branch: k ≥ step_vec, done. + have hk_ge : k.val ≥ step_vec.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = step_vec.val := by omega + have h_iter_none := Layer4PlusFC.iter_next_none_eq_gen k step_vec hk_ge + have h_body : + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) + zeta_i1 a_offset b_offset + { start := k, «end» := step_vec } acc.1 acc.2 + = .ok (ControlFlow.done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := step_vec } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_fc h_body + show Layer4PlusInnerFC.step_post re0 a_offset b_offset step_vec + (Spec.zeta_at zeta_i1.val) k (.done acc) + unfold Layer4PlusInnerFC.step_post + show (Layer4PlusInnerFC.inv re0 a_offset b_offset + (Spec.zeta_at zeta_i1.val) step_vec acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc.1.coefficients.val[a_offset.val + j']!) + = Spec.chunk_inv_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!))) + ∧ (∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc.1.coefficients.val[b_offset.val + j']!) + = Spec.chunk_inv_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + (Spec.zeta_at zeta_i1.val)) + ∧ (∀ k' : Nat, k' < 16 → + (∀ j' : Nat, j' < step_vec.val → k' ≠ a_offset.val + j' ∧ k' ≠ b_offset.val + j') → + acc.1.coefficients.val[k']! = re0.coefficients.val[k']!) + ∧ (∀ k' : Nat, k' < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[k']!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + refine ⟨?_, ?_, ?_, ?_⟩ + · intro j' hj'; rw [← hk_eq] at hj'; exact h_acc_a j' hj' + · intro j' hj'; rw [← hk_eq] at hj'; exact h_acc_b j' hj' + · intro k' hk' h_not_touched + apply h_acc_undone k' hk' + intro j' hj' + have h_at_j' : j' < step_vec.val := by rw [← hk_eq]; exact hj' + exact h_not_touched j' h_at_j' + · exact h_acc_bnd + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +/-! ### L3i.5 — Outer loop scaffolding. -/ + +namespace Layer4PlusOuterFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Outer loop accumulator: (zeta_i, re, scratch). -/ +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector × + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- Inverse outer loop invariant. zeta_i DECREMENTS by 1 per outer round. + Per-round zeta = `Spec.zeta_at (zeta_i_0.val - round' - 1)`. -/ +def inv + (re0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta_i_0 step_vec : Std.Usize) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + acc.1.val = zeta_i_0.val - k.val + ∧ (∀ round' : Nat, round' < k.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc.2.1.coefficients.val[2 * round' * step_vec.val + j']!) + = Spec.chunk_inv_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!))) + ∧ (∀ round' : Nat, round' < k.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc.2.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!) + = Spec.chunk_inv_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!)) + (Spec.zeta_at (zeta_i_0.val - round' - 1))) + ∧ (∀ c : Nat, c < 16 → + (∀ round' : Nat, round' < k.val → + ∀ j' : Nat, j' < step_vec.val → + c ≠ 2 * round' * step_vec.val + j' + ∧ c ≠ 2 * round' * step_vec.val + step_vec.val + j') → + acc.2.1.coefficients.val[c]! = re0.coefficients.val[c]!) + ∧ (∀ c : Nat, c < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.1.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328)) + +/-- Step-post for the inverse outer loop. -/ +def step_post + (re0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta_i_0 step_vec i_end : Std.Usize) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < i_end.val ∧ iter'.«end» = i_end + ∧ iter'.start.val = k.val + 1 + ∧ (inv re0 zeta_i_0 step_vec iter'.start acc').holds + | .done y => (inv re0 zeta_i_0 step_vec i_end y).holds + +end Layer4PlusOuterFC + +/-- Inverse a-side outer helper: chunks lifted via `re0` at index + `2*k*step_vec + j'` are exactly the original re0 chunks (since these + positions have not yet been touched by outer iter `round' < k`). -/ +theorem outer_acc_inv_a_chunk_eq_re0 + (re0 acc : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) (step_vec : Std.Usize) + (h_undone : ∀ c : Nat, c < 16 → + (∀ round' : Nat, round' < k.val → + ∀ j' : Nat, j' < step_vec.val → + c ≠ 2 * round' * step_vec.val + j' + ∧ c ≠ 2 * round' * step_vec.val + step_vec.val + j') → + acc.coefficients.val[c]! = re0.coefficients.val[c]!) + (h_kbound : 2 * k.val * step_vec.val + 2 * step_vec.val ≤ 16) + (j' : Nat) (hj' : j' < step_vec.val) : + acc.coefficients.val[2 * k.val * step_vec.val + j']! + = re0.coefficients.val[2 * k.val * step_vec.val + j']! := by + apply h_undone (2 * k.val * step_vec.val + j') (by omega) + intro round' hround' j'' hj'' + refine ⟨?_, ?_⟩ + · have h1 : 2 * round' * step_vec.val + j'' < 2 * k.val * step_vec.val := by + have : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + omega + · have h1 : 2 * round' * step_vec.val + step_vec.val + j'' < 2 * k.val * step_vec.val := by + have : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + omega + +/-- Inverse b-side outer helper. -/ +theorem outer_acc_inv_b_chunk_eq_re0 + (re0 acc : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) (step_vec : Std.Usize) + (h_undone : ∀ c : Nat, c < 16 → + (∀ round' : Nat, round' < k.val → + ∀ j' : Nat, j' < step_vec.val → + c ≠ 2 * round' * step_vec.val + j' + ∧ c ≠ 2 * round' * step_vec.val + step_vec.val + j') → + acc.coefficients.val[c]! = re0.coefficients.val[c]!) + (h_kbound : 2 * k.val * step_vec.val + 2 * step_vec.val ≤ 16) + (h_step_vec_pos : 1 ≤ step_vec.val) + (j' : Nat) (hj' : j' < step_vec.val) : + acc.coefficients.val[2 * k.val * step_vec.val + step_vec.val + j']! + = re0.coefficients.val[2 * k.val * step_vec.val + step_vec.val + j']! := by + apply h_undone (2 * k.val * step_vec.val + step_vec.val + j') (by omega) + intro round' hround' j'' hj'' + refine ⟨?_, ?_⟩ + · have h1 : 2 * round' * step_vec.val + j'' < 2 * k.val * step_vec.val := by + have : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + omega + · have h1 : 2 * round' * step_vec.val + step_vec.val + j'' < 2 * k.val * step_vec.val := by + have : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + omega + +set_option maxHeartbeats 16000000 in +/-- Inverse inner loop spec wrapper: dispatches `loop_range_spec_usize` for the + inner loop, returning the final FC equations on the post poly. -/ +theorem invert_ntt_at_layer_4_plus_inner_loop_fc + (re0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (a_offset b_offset step_vec : Std.Usize) (zeta_i1 : Std.Usize) + (h_zi1_lt : zeta_i1.val < 128) + (h_step_vec_pos : 1 ≤ step_vec.val) + (h_a_offset_b : a_offset.val + step_vec.val ≤ 16) + (h_b_offset_b : b_offset.val + step_vec.val ≤ 16) + (h_disjoint : a_offset.val + step_vec.val ≤ b_offset.val) + (h_pre_a : ∀ j : Nat, j < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((re0.coefficients.val[a_offset.val + j]!).elements.val[ℓ]!).val.natAbs ≤ 3328) + (h_pre_b : ∀ j : Nat, j < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((re0.coefficients.val[b_offset.val + j]!).elements.val[ℓ]!).val.natAbs ≤ 3328) + (h_pre_all : ∀ k' : Nat, k' < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((re0.coefficients.val[k']!).elements.val[ℓ]!).val.natAbs ≤ 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0_loop0 + (vectortraitsOperationsInst := portable_ops_inst) + { start := 0#usize, «end» := step_vec } zeta_i1 re0 scratch a_offset b_offset + ⦃ ⇓ r => ⌜ + (∀ j' : Nat, j' < step_vec.val → + lift_chunk (r.1.coefficients.val[a_offset.val + j']!) + = Spec.chunk_inv_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!))) + ∧ (∀ j' : Nat, j' < step_vec.val → + lift_chunk (r.1.coefficients.val[b_offset.val + j']!) + = Spec.chunk_inv_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + (Spec.zeta_at zeta_i1.val)) + ∧ (∀ k' : Nat, k' < 16 → + (∀ j' : Nat, j' < step_vec.val → k' ≠ a_offset.val + j' ∧ k' ≠ b_offset.val + j') → + r.1.coefficients.val[k']! = re0.coefficients.val[k']!) + ∧ (∀ k' : Nat, k' < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((r.1.coefficients.val[k']!).elements.val[ℓ]!).val.natAbs ≤ 3328) + ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0_loop0 + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) + zeta_i1 a_offset b_offset iter1 acc1.1 acc1.2) + (β := Layer4PlusInnerFC.Acc) + (re0, scratch) + 0#usize step_vec + (Layer4PlusInnerFC.inv re0 a_offset b_offset (Spec.zeta_at zeta_i1.val)) + (by + have : (0#usize : Std.Usize).val = 0 := rfl + omega) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_, ?_⟩ + · intro j' hj'; exact absurd hj' (Nat.not_lt_zero j') + · intro j' hj'; exact absurd hj' (Nat.not_lt_zero j') + · intro k' _ _; trivial + · -- Initial bound holds via h_pre_a/h_pre_b for touched indices, + -- and trivially for untouched indices via the input precondition + -- on re0 chunks. Need a global bound on re0. + -- At k=0, acc.1 = re0. We need ((re0.coef.val[k']!).elem.val[ℓ]!).natAbs ≤ 3328. + -- This is NOT covered by h_pre_a/h_pre_b alone (those are only for + -- a_offset+j and b_offset+j). + -- We use a global precondition added below: h_pre_all. + intro k' hk' ℓ hℓ + exact h_pre_all k' hk' ℓ hℓ) + ?_) + · -- Post entailment. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : + (Layer4PlusInnerFC.inv re0 a_offset b_offset + (Spec.zeta_at zeta_i1.val) step_vec r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + (∀ j' : Nat, j' < step_vec.val → + lift_chunk (r.1.coefficients.val[a_offset.val + j']!) + = Spec.chunk_inv_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!))) + ∧ (∀ j' : Nat, j' < step_vec.val → + lift_chunk (r.1.coefficients.val[b_offset.val + j']!) + = Spec.chunk_inv_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + (Spec.zeta_at zeta_i1.val)) + ∧ (∀ k' : Nat, k' < 16 → + (∀ j' : Nat, j' < step_vec.val → k' ≠ a_offset.val + j' ∧ k' ≠ b_offset.val + j') → + r.1.coefficients.val[k']! = re0.coefficients.val[k']!) + ∧ (∀ k' : Nat, k' < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((r.1.coefficients.val[k']!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + Layer4PlusInnerFC.inv] using h_inv_holds + exact h_inv + · -- Step lemma dispatch. + intro acc k _h_ge h_le hinv + have h_step := invert_ntt_at_layer_4_plus_inner_step_lemma_fc re0 a_offset b_offset step_vec + zeta_i1 h_zi1_lt h_step_vec_pos h_a_offset_b h_b_offset_b h_disjoint h_pre_a h_pre_b + acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer4PlusInnerFC.step_post re0 a_offset b_offset step_vec + (Spec.zeta_at zeta_i1.val) k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer4PlusInnerFC.step_post] using hP + · have hP : Layer4PlusInnerFC.step_post re0 a_offset b_offset step_vec + (Spec.zeta_at zeta_i1.val) k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer4PlusInnerFC.step_post] using hP + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for the inverse outer loop. -/ +theorem invert_ntt_at_layer_4_plus_outer_step_lemma_fc + (re0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta_i_0 step_vec i_end : Std.Usize) + (h_pre : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((re0.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 3328) + (h_step_vec_pos : 1 ≤ step_vec.val) + (h_step_vec_dvd : 2 * i_end.val * step_vec.val = 16) + (h_zeta_lo : i_end.val ≤ zeta_i_0.val) + (h_zeta_hi : zeta_i_0.val ≤ 128) + (acc : Layer4PlusOuterFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ i_end.val) + (h_inv : (Layer4PlusOuterFC.inv re0 zeta_i_0 step_vec k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) + step_vec { start := k, «end» := i_end } acc.1 acc.2.1 acc.2.2 + ⦃ ⇓ r => ⌜ Layer4PlusOuterFC.step_post re0 zeta_i_0 step_vec i_end k r ⌝ ⦄ := by + obtain ⟨h_zeta_acc, h_acc_a, h_acc_b, h_acc_undone, h_acc_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0.body + by_cases h_lt : k.val < i_end.val + · -- Some round = k branch. + obtain ⟨s, hs_val, h_iter_some⟩ := + Layer4PlusFC.iter_next_some_eq_gen k i_end h_lt + -- (1) zeta_i1 ← acc.1 - 1. + have h_um : (1#usize : Std.Usize).val = 1 := rfl + have h_acc1_ge_1 : 1 ≤ acc.1.val := by + rw [h_zeta_acc] + have : k.val < i_end.val := h_lt + have : k.val + 1 ≤ i_end.val := by omega + omega + have h_z_ge : (1#usize : Std.Usize).val ≤ acc.1.val := by rw [h_um]; omega + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := + Layer2FC.usize_sub_ok_eq acc.1 1#usize h_z_ge + have h_zi1_arith : zi1.val = zeta_i_0.val - k.val - 1 := by + rw [h_zi1_val, h_um, h_zeta_acc] + have h_zi1_lt_128 : zi1.val < 128 := by + rw [h_zi1_arith]; omega + -- (2) i ← round * 2. + have h_um2 : (2#usize : Std.Usize).val = 2 := rfl + have h_i_end_le_16 : i_end.val ≤ 16 := by + have : i_end.val * step_vec.val * 2 = 16 := by rw [Nat.mul_assoc] at h_step_vec_dvd; nlinarith + nlinarith + have h_i_max : k.val * (2#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um2] + have h_k_b : k.val * 2 ≤ 16 := by + have : k.val ≤ 8 := by + have : i_end.val ≤ 8 := by + have : i_end.val * step_vec.val * 2 = 16 := by rw [Nat.mul_assoc] at h_step_vec_dvd; nlinarith + nlinarith + omega + omega + scalar_tac + obtain ⟨ii, h_ii_eq, h_ii_val⟩ := + Layer4PlusFC.usize_mul_ok_eq k 2#usize h_i_max + have h_ii_arith : ii.val = 2 * k.val := by rw [h_ii_val, h_um2, Nat.mul_comm] + -- (3) a_offset ← ii * step_vec. + have h_a_max : ii.val * step_vec.val ≤ Std.Usize.max := by + rw [h_ii_arith] + have h_b : 2 * k.val * step_vec.val ≤ 16 := by + have : (k.val + 1) * (2 * step_vec.val) ≤ i_end.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + scalar_tac + obtain ⟨ao, h_ao_eq, h_ao_val⟩ := + Layer4PlusFC.usize_mul_ok_eq ii step_vec h_a_max + have h_ao_arith : ao.val = 2 * k.val * step_vec.val := by + rw [h_ao_val, h_ii_arith] + -- (4) b_offset ← a_offset + step_vec. + have h_b_max : ao.val + step_vec.val ≤ Std.Usize.max := by + rw [h_ao_arith] + have h_b : 2 * k.val * step_vec.val + step_vec.val ≤ 16 := by + have : (k.val + 1) * (2 * step_vec.val) ≤ i_end.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + scalar_tac + obtain ⟨bo, h_bo_eq, h_bo_val⟩ := + Layer4PlusFC.usize_add_ok_eq ao step_vec h_b_max + have h_bo_arith : bo.val = 2 * k.val * step_vec.val + step_vec.val := by + rw [h_bo_val, h_ao_arith] + have h_a_offset_b : ao.val + step_vec.val ≤ 16 := by + rw [h_ao_arith] + have : (k.val + 1) * (2 * step_vec.val) ≤ i_end.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + have h_b_offset_b : bo.val + step_vec.val ≤ 16 := by + rw [h_bo_arith] + have : (k.val + 1) * (2 * step_vec.val) ≤ i_end.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + have h_disjoint : ao.val + step_vec.val ≤ bo.val := by + rw [h_ao_arith, h_bo_arith] + have h_2kstep_bnd : 2 * k.val * step_vec.val + 2 * step_vec.val ≤ 16 := by + have : (k.val + 1) * (2 * step_vec.val) ≤ i_end.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + have h_acc_a_eq : ∀ j : Nat, j < step_vec.val → + acc.2.1.coefficients.val[ao.val + j]! = re0.coefficients.val[ao.val + j]! := by + intro j hj + rw [h_ao_arith] + exact outer_acc_inv_a_chunk_eq_re0 re0 acc.2.1 k step_vec h_acc_undone + h_2kstep_bnd j hj + have h_acc_b_eq : ∀ j : Nat, j < step_vec.val → + acc.2.1.coefficients.val[bo.val + j]! = re0.coefficients.val[bo.val + j]! := by + intro j hj + rw [h_bo_arith] + exact outer_acc_inv_b_chunk_eq_re0 re0 acc.2.1 k step_vec h_acc_undone + h_2kstep_bnd h_step_vec_pos j hj + have h_pre_a : ∀ j : Nat, j < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.1.coefficients.val[ao.val + j]!).elements.val[ℓ]!).val.natAbs ≤ 3328 := by + intro j hj ℓ hℓ + rw [h_acc_a_eq j hj] + apply h_pre _ _ ℓ hℓ + rw [h_ao_arith]; omega + have h_pre_b : ∀ j : Nat, j < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.1.coefficients.val[bo.val + j]!).elements.val[ℓ]!).val.natAbs ≤ 3328 := by + intro j hj ℓ hℓ + rw [h_acc_b_eq j hj] + apply h_pre _ _ ℓ hℓ + rw [h_bo_arith]; omega + -- Dispatch inner loop. + have h_inner := invert_ntt_at_layer_4_plus_inner_loop_fc acc.2.1 acc.2.2 ao bo step_vec zi1 + h_zi1_lt_128 h_step_vec_pos h_a_offset_b h_b_offset_b h_disjoint h_pre_a h_pre_b h_acc_bnd + obtain ⟨r_pair, h_r_eq, h_r_a, h_r_b, h_r_undone, h_r_bnd⟩ := + triple_exists_ok_fc h_inner + set acc' : Layer4PlusOuterFC.Acc := + (zi1, r_pair.1, r_pair.2) with hacc'_def + have h_body : + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) + step_vec { start := k, «end» := i_end } acc.1 acc.2.1 acc.2.2 + = .ok (ControlFlow.cont (({ start := s, «end» := i_end } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := i_end } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := i_end } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [Aeneas.Std.bind_tc_ok] + show (do + let zi1' ← acc.1 - 1#usize + let ii' ← k * 2#usize + let ao' ← ii' * step_vec + let bo' ← ao' + step_vec + let (re1, scratch1) ← + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0_loop0 + (vectortraitsOperationsInst := portable_ops_inst) + { start := 0#usize, «end» := step_vec } zi1' acc.2.1 acc.2.2 ao' bo' + .ok (ControlFlow.cont (({ start := s, «end» := i_end } + : CoreModels.core.ops.range.Range Std.Usize), + zi1', re1, scratch1))) = _ + rw [h_zi1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_ii_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_ao_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_bo_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_r_eq]; rfl + apply triple_of_ok_fc h_body + show Layer4PlusOuterFC.step_post re0 zeta_i_0 step_vec i_end k + (.cont (({ start := s, «end» := i_end } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold Layer4PlusOuterFC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (Layer4PlusOuterFC.inv re0 zeta_i_0 step_vec s acc').holds + have h_inv_pure : + acc'.1.val = zeta_i_0.val - s.val + ∧ (∀ round' : Nat, round' < s.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc'.2.1.coefficients.val[2 * round' * step_vec.val + j']!) + = Spec.chunk_inv_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!))) + ∧ (∀ round' : Nat, round' < s.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc'.2.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!) + = Spec.chunk_inv_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!)) + (Spec.zeta_at (zeta_i_0.val - round' - 1))) + ∧ (∀ c : Nat, c < 16 → + (∀ round' : Nat, round' < s.val → + ∀ j' : Nat, j' < step_vec.val → + c ≠ 2 * round' * step_vec.val + j' + ∧ c ≠ 2 * round' * step_vec.val + step_vec.val + j') → + acc'.2.1.coefficients.val[c]! = re0.coefficients.val[c]!) + ∧ (∀ c : Nat, c < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((acc'.2.1.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + refine ⟨?_, ?_, ?_, ?_, ?_⟩ + · -- zeta thread: zi1.val = zeta_i_0.val - (k.val + 1) = zeta_i_0.val - s.val. + show zi1.val = zeta_i_0.val - s.val + rw [h_zi1_arith, hs_val]; omega + · -- a-side butterflies for round' < s.val. + intro round' hround' j' hj' + rw [hs_val] at hround' + rcases Nat.lt_succ_iff_lt_or_eq.mp hround' with hround'_lt | hround'_eq + · have h_pos : 2 * round' * step_vec.val + j' < 16 := by + have h_rb : 2 * round' * step_vec.val + 2 * step_vec.val + ≤ 2 * k.val * step_vec.val := by + have h_pos : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + have h_ne_a : ∀ j : Nat, j < step_vec.val → + 2 * round' * step_vec.val + j' ≠ ao.val + j := by + intro j hj + rw [h_ao_arith] + have h1 : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have h_pos : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + have h_ne_b : ∀ j : Nat, j < step_vec.val → + 2 * round' * step_vec.val + j' ≠ bo.val + j := by + intro j hj + rw [h_bo_arith] + have h1 : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have h_pos : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + have h_step_unc : r_pair.1.coefficients.val[2 * round' * step_vec.val + j']! + = acc.2.1.coefficients.val[2 * round' * step_vec.val + j']! := + h_r_undone (2 * round' * step_vec.val + j') h_pos + (fun j hj => ⟨h_ne_a j hj, h_ne_b j hj⟩) + show lift_chunk (acc'.2.1.coefficients.val[2 * round' * step_vec.val + j']!) = _ + show lift_chunk (r_pair.1.coefficients.val[2 * round' * step_vec.val + j']!) = _ + rw [h_step_unc] + exact h_acc_a round' hround'_lt j' hj' + · subst hround'_eq + show lift_chunk (acc'.2.1.coefficients.val[2 * k.val * step_vec.val + j']!) = _ + show lift_chunk (r_pair.1.coefficients.val[2 * k.val * step_vec.val + j']!) = _ + rw [show (2 * k.val * step_vec.val + j' : Nat) = ao.val + j' from by rw [h_ao_arith]] + rw [h_r_a j' hj'] + rw [h_acc_a_eq j' hj', h_acc_b_eq j' hj'] + rw [show (ao.val + j' : Nat) = 2 * k.val * step_vec.val + j' from by rw [h_ao_arith]] + rw [show (bo.val + j' : Nat) = 2 * k.val * step_vec.val + step_vec.val + j' from by rw [h_bo_arith]] + · -- b-side butterflies for round' < s.val. + intro round' hround' j' hj' + rw [hs_val] at hround' + rcases Nat.lt_succ_iff_lt_or_eq.mp hround' with hround'_lt | hround'_eq + · have h_pos : 2 * round' * step_vec.val + step_vec.val + j' < 16 := by + have h_rb : 2 * round' * step_vec.val + 2 * step_vec.val + ≤ 2 * k.val * step_vec.val := by + have h_pos : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + have h_ne_a : ∀ j : Nat, j < step_vec.val → + 2 * round' * step_vec.val + step_vec.val + j' ≠ ao.val + j := by + intro j hj + rw [h_ao_arith] + have h1 : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have h_pos : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + have h_ne_b : ∀ j : Nat, j < step_vec.val → + 2 * round' * step_vec.val + step_vec.val + j' ≠ bo.val + j := by + intro j hj + rw [h_bo_arith] + have h1 : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have h_pos : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + have h_step_unc : + r_pair.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']! + = acc.2.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']! := + h_r_undone (2 * round' * step_vec.val + step_vec.val + j') h_pos + (fun j hj => ⟨h_ne_a j hj, h_ne_b j hj⟩) + show lift_chunk (acc'.2.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!) = _ + show lift_chunk (r_pair.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!) = _ + rw [h_step_unc] + exact h_acc_b round' hround'_lt j' hj' + · subst hround'_eq + show lift_chunk (acc'.2.1.coefficients.val[2 * k.val * step_vec.val + step_vec.val + j']!) = _ + show lift_chunk (r_pair.1.coefficients.val[2 * k.val * step_vec.val + step_vec.val + j']!) = _ + rw [show (2 * k.val * step_vec.val + step_vec.val + j' : Nat) = bo.val + j' from by rw [h_bo_arith]] + rw [h_r_b j' hj'] + rw [h_acc_a_eq j' hj', h_acc_b_eq j' hj'] + rw [show (ao.val + j' : Nat) = 2 * k.val * step_vec.val + j' from by rw [h_ao_arith]] + rw [show (bo.val + j' : Nat) = 2 * k.val * step_vec.val + step_vec.val + j' from by rw [h_bo_arith]] + rw [show zi1.val = zeta_i_0.val - k.val - 1 from h_zi1_arith] + · -- Untouched chunks. + intro c hc h_not_touched + show acc'.2.1.coefficients.val[c]! = re0.coefficients.val[c]! + show r_pair.1.coefficients.val[c]! = re0.coefficients.val[c]! + have h_at_k : k.val < s.val := by rw [hs_val]; omega + have h_ne_a_k : ∀ j : Nat, j < step_vec.val → c ≠ ao.val + j := by + intro j hj; rw [h_ao_arith] + exact (h_not_touched k.val h_at_k j hj).1 + have h_ne_b_k : ∀ j : Nat, j < step_vec.val → c ≠ bo.val + j := by + intro j hj; rw [h_bo_arith] + exact (h_not_touched k.val h_at_k j hj).2 + have h_step_unc : r_pair.1.coefficients.val[c]! = acc.2.1.coefficients.val[c]! := + h_r_undone c hc (fun j hj => ⟨h_ne_a_k j hj, h_ne_b_k j hj⟩) + rw [h_step_unc] + apply h_acc_undone c hc + intro round' hround' j' hj' + have h_at_r : round' < s.val := by rw [hs_val]; omega + exact h_not_touched round' h_at_r j' hj' + · -- Per-lane output bound at every chunk: inner loop result already + -- has the bound on r_pair.1 = acc'.2.1. + intro c hc ℓ hℓ + show ((acc'.2.1.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328 + show ((r_pair.1.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328 + exact h_r_bnd c hc ℓ hℓ + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- None branch: k ≥ i_end, done. + have hk_ge : k.val ≥ i_end.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = i_end.val := by omega + have h_iter_none := Layer4PlusFC.iter_next_none_eq_gen k i_end hk_ge + have h_body : + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) + step_vec { start := k, «end» := i_end } acc.1 acc.2.1 acc.2.2 + = .ok (ControlFlow.done (acc.1, acc.2.1, acc.2.2)) := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := i_end } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := i_end } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2.1, acc.2.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_fc h_body + show Layer4PlusOuterFC.step_post re0 zeta_i_0 step_vec i_end k (.done acc) + unfold Layer4PlusOuterFC.step_post + show (Layer4PlusOuterFC.inv re0 zeta_i_0 step_vec i_end acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + acc.1.val = zeta_i_0.val - i_end.val + ∧ (∀ round' : Nat, round' < i_end.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc.2.1.coefficients.val[2 * round' * step_vec.val + j']!) + = Spec.chunk_inv_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!))) + ∧ (∀ round' : Nat, round' < i_end.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc.2.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!) + = Spec.chunk_inv_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!)) + (Spec.zeta_at (zeta_i_0.val - round' - 1))) + ∧ (∀ c : Nat, c < 16 → + (∀ round' : Nat, round' < i_end.val → + ∀ j' : Nat, j' < step_vec.val → + c ≠ 2 * round' * step_vec.val + j' + ∧ c ≠ 2 * round' * step_vec.val + step_vec.val + j') → + acc.2.1.coefficients.val[c]! = re0.coefficients.val[c]!) + ∧ (∀ c : Nat, c < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.1.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + refine ⟨?_, ?_, ?_, ?_, ?_⟩ + · rw [h_zeta_acc, hk_eq] + · intro round' hround'; rw [← hk_eq] at hround'; exact h_acc_a round' hround' + · intro round' hround'; rw [← hk_eq] at hround'; exact h_acc_b round' hround' + · intro c hc h_nt + apply h_acc_undone c hc + intro round' hround' j' hj' + have : round' < i_end.val := by rw [← hk_eq]; exact hround' + exact h_nt round' this j' hj' + · exact h_acc_bnd + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +@[spec high] +theorem invert_ntt_at_layer_4_plus_portable_fc + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (layer : Std.Usize) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_layer : 4 ≤ layer.val ∧ layer.val ≤ 7) + (h_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 3328) + (h_zeta : (128 >>> layer.val) ≤ zeta_i.val ∧ zeta_i.val ≤ 128) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus + (vectortraitsOperationsInst := portable_ops_inst) + zeta_i re layer scratch + ⦃ ⇓ p => ⌜ p.1.val = zeta_i.val - 128 >>> layer.val + ∧ lift_poly p.2.1 = Spec.invert_ntt_layer_4_plus_pure (lift_poly re) zeta_i layer + ∧ (∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.1.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3328) ⌝ ⦄ := by + obtain ⟨h_layer_lo, h_layer_hi⟩ := h_layer + obtain ⟨h_zeta_lo, h_zeta_hi⟩ := h_zeta + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus + -- Resolve step ← 1 <<< layer. + have h_usize_bits : (Aeneas.Std.UScalarTy.Usize.numBits : Nat) = System.Platform.numBits := rfl + have h_layer_bits : layer.val < Aeneas.Std.UScalarTy.Usize.numBits := by + have h_p := System.Platform.numBits_eq + rcases h_p with h32 | h64 + · rw [h_usize_bits, h32]; omega + · rw [h_usize_bits, h64]; omega + have h_size_eq : Aeneas.Std.UScalar.size Aeneas.Std.UScalarTy.Usize = 2 ^ System.Platform.numBits := by + simp [Std.Usize.size, Usize.numBits] + have h_one_shl_pow : ((1#usize : Std.Usize).val <<< layer.val) < 2 ^ System.Platform.numBits := by + have h_one_eq : (1#usize : Std.Usize).val = 1 := rfl + rw [h_one_eq, Nat.shiftLeft_eq, Nat.one_mul] + have h_p := System.Platform.numBits_eq + rcases h_p with h32 | h64 + · rw [h32]; exact Nat.pow_lt_pow_right (by decide) (by omega) + · rw [h64]; exact Nat.pow_lt_pow_right (by decide) (by omega) + have h_step_ex : ∃ step : Std.Usize, + ((1#usize : Std.Usize) <<< layer : Result Std.Usize) = .ok step + ∧ step.val = 1 <<< layer.val := by + have hT := Aeneas.Std.UScalar.ShiftLeft_spec (1#usize : Std.Usize) layer + (Aeneas.Std.UScalar.size Aeneas.Std.UScalarTy.Usize) h_layer_bits rfl + obtain ⟨z, h_eq, h_v_mod, _h_bv⟩ := Std.WP.spec_imp_exists hT + refine ⟨z, h_eq, ?_⟩ + have h_one_eq : (1#usize : Std.Usize).val = 1 := rfl + rw [h_v_mod, h_one_eq, h_size_eq, Nat.mod_eq_of_lt] + rw [h_one_eq] at h_one_shl_pow + exact h_one_shl_pow + obtain ⟨step, h_step_eq, h_step_val⟩ := h_step_ex + rw [h_step_eq] + simp only [Aeneas.Std.bind_tc_ok] + -- Unfold FIELD_ELEMENTS_IN_VECTOR (= 16#usize) so we can use UScalar.div_spec. + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + -- Resolve step_vec ← step / 16. + have h_16_nz : ((16#usize : Std.Usize).val : Nat) ≠ 0 := by decide + have h_step_pos : 1 ≤ step.val := by + rw [h_step_val, Nat.shiftLeft_eq, Nat.one_mul] + exact Nat.one_le_pow _ _ (by decide : (0:Nat) < 2) + obtain ⟨step_vec, h_step_vec_eq, h_step_vec_val⟩ := + Aeneas.Std.UScalar.div_spec step h_16_nz + rw [h_step_vec_eq] + simp only [Aeneas.Std.bind_tc_ok] + have h_step_vec_arith : step_vec.val = (1 <<< layer.val) / 16 := by + have h_16_eq : (16#usize : Std.Usize).val = 16 := rfl + rw [h_step_vec_val, h_step_val, h_16_eq] + -- Resolve i_end ← 128 >>> layer. + obtain ⟨i_end, h_i_end_eq, h_i_end_val, _h_i_end_bv⟩ := + Std.WP.spec_imp_exists (Aeneas.Std.UScalar.ShiftRight_spec (128#usize : Std.Usize) layer + h_layer_bits) + rw [h_i_end_eq] + have h_i_end_arith : i_end.val = 128 >>> layer.val := h_i_end_val + have h_step_vec_pos : 1 ≤ step_vec.val := by + rw [h_step_vec_arith] + interval_cases layer.val <;> decide + have h_step_vec_dvd : 2 * i_end.val * step_vec.val = 16 := by + rw [h_i_end_arith, h_step_vec_arith] + interval_cases layer.val <;> decide + have h_i_end_pos : 1 ≤ i_end.val := by + rw [h_i_end_arith] + interval_cases layer.val <;> decide + have h_zeta_lo' : i_end.val ≤ zeta_i.val := by + rw [h_i_end_arith]; exact h_zeta_lo + -- Unfold outer loop and apply loop_range_spec_usize. + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0 + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.invert_ntt.invert_ntt_at_layer_4_plus_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) step_vec + iter1 acc1.1 acc1.2.1 acc1.2.2) + (β := Layer4PlusOuterFC.Acc) + (zeta_i, re, scratch) + 0#usize i_end + (Layer4PlusOuterFC.inv re zeta_i step_vec) + (by + have h_zero : (0#usize : Std.Usize).val = 0 := rfl + omega) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_, ?_, ?_⟩ + · show zeta_i.val = zeta_i.val - (0#usize : Std.Usize).val + show zeta_i.val = zeta_i.val - 0 + omega + · intro round' hround' _ _ + exact absurd hround' (Nat.not_lt_zero round') + · intro round' hround' _ _ + exact absurd hround' (Nat.not_lt_zero round') + · intro _ _ _; trivial + · -- Initial bound: acc.2.1 = re at k=0, so bound from h_bnd. + intro c hc ℓ hℓ + exact h_bnd c hc ℓ hℓ) + ?_) + · -- Post entailment: at k = i_end, build chunks_arr matching Spec. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (Layer4PlusOuterFC.inv re zeta_i step_vec i_end r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + r.1.val = zeta_i.val - i_end.val + ∧ (∀ round' : Nat, round' < i_end.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (r.2.1.coefficients.val[2 * round' * step_vec.val + j']!) + = Spec.chunk_inv_pair_butterfly_a_pure + (lift_chunk (re.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!))) + ∧ (∀ round' : Nat, round' < i_end.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (r.2.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!) + = Spec.chunk_inv_pair_butterfly_b_pure + (lift_chunk (re.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!)) + (Spec.zeta_at (zeta_i.val - round' - 1))) + ∧ (∀ c : Nat, c < 16 → + (∀ round' : Nat, round' < i_end.val → + ∀ j' : Nat, j' < step_vec.val → + c ≠ 2 * round' * step_vec.val + j' + ∧ c ≠ 2 * round' * step_vec.val + step_vec.val + j') → + r.2.1.coefficients.val[c]! = re.coefficients.val[c]!) + ∧ (∀ c : Nat, c < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((r.2.1.coefficients.val[c]!).elements.val[ℓ]!).val.natAbs ≤ 3328) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + Layer4PlusOuterFC.inv] using h_inv_holds + obtain ⟨h_zeta_done, h_done_a, h_done_b, _h_done_undone, h_done_bnd⟩ := h_inv + -- Build chunks_arr matching the Spec layout. + unfold Spec.invert_ntt_layer_4_plus_pure + set chunks_arr : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun c => + Spec.chunk_inv_at_layer_4_plus_pure + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at (lift_poly re))) + (by simp)) + layer + (fun group => Spec.zeta_at (zeta_i.val - 1 - group)) + c)) + (by simp) with hchunks_def + have h_chunks_len : chunks_arr.val.length = 16 := by + show ((List.range 16).map _).length = 16; simp + have h_chunks0_at : ∀ k : Nat, k < 16 → + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at (lift_poly re))) + (by simp)).val[k]! + = lift_chunk (re.coefficients.val[k]!) := by + intro k hk + have h_len_map : ((List.range 16).map (Spec.chunk_at (lift_poly re))).length = 16 := by simp + show ((List.range 16).map (Spec.chunk_at (lift_poly re)))[k]! = _ + rw [getElem!_pos _ k (by rw [h_len_map]; exact hk)] + rw [List.getElem_map, List.getElem_range] + exact chunk_at_lift_poly_fc re k hk + have h_chunks_get : ∀ c : Nat, (hc : c < 16) → + chunks_arr.val[c]'(by rw [h_chunks_len]; exact hc) + = lift_chunk (r.2.1.coefficients.val[c]!) := by + intro c hc + show ((List.range 16).map (fun c => + Spec.chunk_inv_at_layer_4_plus_pure + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at (lift_poly re))) + (by simp)) + layer + (fun group => Spec.zeta_at (zeta_i.val - 1 - group)) + c))[c]'_ = _ + rw [List.getElem_map, List.getElem_range] + unfold Spec.chunk_inv_at_layer_4_plus_pure + set sv := (1 <<< layer.val) / 16 with hsv_def + have hsv_eq : sv = step_vec.val := by rw [hsv_def, h_step_vec_arith] + simp only [] + set group := c / (2 * sv) + set offset := c % (2 * sv) + have h_2sv_pos : 0 < 2 * sv := by rw [hsv_eq]; omega + have h_c_eq : 2 * sv * group + offset = c := by + show 2 * sv * (c / (2 * sv)) + c % (2 * sv) = c + exact Nat.div_add_mod c (2 * sv) + have h_off_lt : offset < 2 * sv := Nat.mod_lt _ h_2sv_pos + have h_16_eq : 2 * i_end.val * sv = 16 := by + rw [hsv_eq]; exact h_step_vec_dvd + have h_group_lt : group < i_end.val := by + by_contra h_ge + push Not at h_ge + have h_ge2 : 2 * sv * i_end.val ≤ 2 * sv * group := Nat.mul_le_mul_left _ h_ge + have h_c_ge : c ≥ 2 * sv * i_end.val := by + have : 2 * sv * group ≤ c := by omega + omega + have h_rw : 2 * sv * i_end.val = 16 := by + have h : 2 * i_end.val * sv = 2 * sv * i_end.val := by ring + omega + omega + by_cases h_off_lt_sv : offset < sv + · -- a-side. + simp only [if_pos h_off_lt_sv] + have h_c_lt_16 : c < 16 := hc + have h_c_plus_sv_lt_16 : c + sv < 16 := by + have h_succ : 2 * sv * (group + 1) ≤ 2 * sv * i_end.val := Nat.mul_le_mul_left _ h_group_lt + have h_split : 2 * sv * (group + 1) = 2 * sv * group + 2 * sv := by ring + have h_eq_16 : 2 * sv * i_end.val = 16 := by + have : 2 * i_end.val * sv = 2 * sv * i_end.val := by ring + omega + omega + rw [h_chunks0_at c h_c_lt_16, h_chunks0_at (c + sv) h_c_plus_sv_lt_16] + have h_c_eq_a : c = 2 * group * step_vec.val + offset := by + rw [← hsv_eq] + calc c = 2 * sv * group + offset := h_c_eq.symm + _ = 2 * group * sv + offset := by ring_nf + have h_csv_eq_a : c + sv = 2 * group * step_vec.val + step_vec.val + offset := by + rw [h_c_eq_a]; rw [hsv_eq]; ring + have h_off_lt_sv' : offset < step_vec.val := by rw [← hsv_eq]; exact h_off_lt_sv + have h_done := h_done_a group h_group_lt offset h_off_lt_sv' + rw [h_csv_eq_a, h_c_eq_a] + exact h_done.symm + · -- b-side. + simp only [if_neg h_off_lt_sv] + push Not at h_off_lt_sv + set j' := offset - sv with hj'_def + have hj'_lt_sv : j' < sv := by + show offset - sv < sv; omega + have h_off_eq : offset = sv + j' := by + show offset = sv + (offset - sv); omega + have h_c_lt_16 : c < 16 := hc + have h_c_minus_sv_lt_16 : c - sv < 16 := by omega + have h_c_eq_b : c = 2 * group * step_vec.val + step_vec.val + j' := by + rw [← hsv_eq] + have : c = 2 * sv * group + (sv + j') := by rw [← h_off_eq]; exact h_c_eq.symm + calc c = 2 * sv * group + (sv + j') := this + _ = 2 * group * sv + sv + j' := by ring + have h_cmsv_eq_b : c - sv = 2 * group * step_vec.val + j' := by + have h_sv_le_c : sv ≤ c := by + calc sv ≤ sv + j' := Nat.le_add_right _ _ + _ = offset := h_off_eq.symm + _ ≤ 2 * sv * group + offset := Nat.le_add_left _ _ + _ = c := h_c_eq + rw [← hsv_eq] + have h_full : c - sv = (2 * sv * group + (sv + j')) - sv := by rw [← h_off_eq, h_c_eq] + rw [h_full] + have h_simp : 2 * sv * group + (sv + j') - sv = 2 * sv * group + j' := by omega + rw [h_simp]; ring + rw [h_chunks0_at (c - sv) h_c_minus_sv_lt_16, h_chunks0_at c h_c_lt_16] + have h_j'_lt : j' < step_vec.val := by rw [← hsv_eq]; exact hj'_lt_sv + have h_done := h_done_b group h_group_lt j' h_j'_lt + rw [h_cmsv_eq_b, h_c_eq_b] + -- The zeta expression: spec uses `zeta_i.val - 1 - group`, + -- invariant uses `zeta_i.val - group - 1`. These are equal. + rw [show zeta_i.val - 1 - group = zeta_i.val - group - 1 by omega] + exact h_done.symm + have h_final := flatten_chunks_eq_lift_poly_fc r.2.1 chunks_arr h_chunks_len h_chunks_get + exact ⟨by rw [h_zeta_done, h_i_end_arith], h_final.symm, h_done_bnd⟩ + · -- Step lemma dispatch. + intro acc k _h_ge h_le hinv + have h_step := invert_ntt_at_layer_4_plus_outer_step_lemma_fc re zeta_i step_vec i_end + h_bnd h_step_vec_pos h_step_vec_dvd h_zeta_lo' h_zeta_hi acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer4PlusOuterFC.step_post re zeta_i step_vec i_end k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer4PlusOuterFC.step_post] using hP + · have hP : Layer4PlusOuterFC.step_post re zeta_i step_vec i_end k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer4PlusOuterFC.step_post] using hP + +/-! ### L3i.6 — `invert_ntt_montgomery` composer (Task I). + + Top-level inverse-NTT composer. 7-step bind chain through the closed + layer FC equations (F + G.2 + G.3 + H.1 instantiated at 4 different + layer values). Closest yardstick: forward `ntt_binomially_sampled_ring_element_fc` (~247 LOC). + + Impl: + ``` + let zeta_i := 256 / 2 = 128 + let (zeta_i1, re1) := invert_ntt_at_layer_1 portable zeta_i re + let (zeta_i2, re2) := invert_ntt_at_layer_2 portable zeta_i1 re1 + let (zeta_i3, re3) := invert_ntt_at_layer_3 portable zeta_i2 re2 + let (zeta_i4, re4, sc1) := invert_ntt_at_layer_4_plus portable zeta_i3 re3 4 sc + let (zeta_i5, re5, sc2) := invert_ntt_at_layer_4_plus portable zeta_i4 re4 5 sc1 + let (zeta_i6, re6, sc3) := invert_ntt_at_layer_4_plus portable zeta_i5 re5 6 sc2 + let (_, re7, sc4) := invert_ntt_at_layer_4_plus portable zeta_i6 re6 7 sc3 + ok (re7, sc4) + ``` + + zeta_i thread: 128 → 64 → 32 → 16 → 8 → 4 → 2 → 1. -/ +set_option maxHeartbeats 16000000 in +@[spec] +theorem invert_ntt_montgomery_fc + {K : Std.Usize} + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 13312) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.invert_ntt.invert_ntt_montgomery + K (vectortraitsOperationsInst := portable_ops_inst) re scratch + ⦃ ⇓ p => ⌜ lift_poly p.1 = Spec.invert_ntt_montgomery_pure (lift_poly re) + ∧ (∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.1.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3328) ⌝ ⦄ := by + -- `h_bnd` is already ≤ 13312; keep alias for readability at layer-1 call site. + have h_bnd_loose : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 13312 := h_bnd + -- ============================================================= + -- Step 0: resolve `let zeta_i ← constants.COEFFICIENTS_IN_RING_ELEMENT / 2` + -- = `256#usize / 2#usize = .ok 128#usize`. + -- ============================================================= + have h_div : (libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + / (2#usize : Std.Usize) : Result Std.Usize) + = .ok (128#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + have h_2_nz : ((2#usize : Std.Usize).val : Nat) ≠ 0 := by decide + obtain ⟨z, hz_eq, hz_v⟩ := + Aeneas.Std.UScalar.div_spec (256#usize : Std.Usize) h_2_nz + have hz_val : (↑z : Nat) = 128 := by rw [hz_v]; decide + have hz_eq128 : z = (128#usize : Std.Usize) := by + apply Aeneas.Std.UScalar.eq_of_val_eq + show z.val = (128#usize : Std.Usize).val + rw [hz_val]; decide + rw [hz_eq, hz_eq128] + -- ============================================================= + -- Step 1: invert_ntt_at_layer_1. zeta_i = 128 → 64. bound stays ≤ 3328. + -- ============================================================= + obtain ⟨⟨zeta_i1, re1⟩, h1_eq, h1_zout, h1_fc, h1_bnd⟩ := + triple_exists_ok_fc + (invert_ntt_at_layer_1_portable_fc (128#usize : Std.Usize) re + h_bnd_loose (by decide) (by decide)) + dsimp only at h1_zout h1_fc h1_bnd + have h_zeta_i1 : zeta_i1.val = 64 := by rw [h1_zout]; decide + -- ============================================================= + -- Step 2: invert_ntt_at_layer_2. zeta_i = 64 → 32. bound stays ≤ 3328. + -- ============================================================= + have h_re1_loose : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re1.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 13312 := by + intro chunk hc k hk + have := h1_bnd chunk hc k hk + omega + obtain ⟨⟨zeta_i2, re2⟩, h2_eq, h2_zout, h2_fc, h2_bnd⟩ := + triple_exists_ok_fc + (invert_ntt_at_layer_2_portable_fc zeta_i1 re1 + h_re1_loose h1_bnd (by rw [h_zeta_i1]; decide) + (by rw [h_zeta_i1]; decide)) + dsimp only at h2_zout h2_fc h2_bnd + have h_zeta_i2 : zeta_i2.val = 32 := by rw [h2_zout, h_zeta_i1] + -- ============================================================= + -- Step 3: invert_ntt_at_layer_3. zeta_i = 32 → 16. bound stays ≤ 3328. + -- ============================================================= + have h_re2_loose : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re2.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 13312 := by + intro chunk hc k hk + have := h2_bnd chunk hc k hk + omega + obtain ⟨⟨zeta_i3, re3⟩, h3_eq, h3_zout, h3_fc, h3_bnd⟩ := + triple_exists_ok_fc + (invert_ntt_at_layer_3_portable_fc zeta_i2 re2 + h_re2_loose h2_bnd (by rw [h_zeta_i2]; decide) + (by rw [h_zeta_i2]; decide)) + dsimp only at h3_zout h3_fc h3_bnd + have h_zeta_i3 : zeta_i3.val = 16 := by rw [h3_zout, h_zeta_i2] + -- ============================================================= + -- Step 4: invert_ntt_at_layer_4_plus (layer = 4). zeta_i = 16 → 8. + -- 128 >>> 4 = 8. + -- ============================================================= + obtain ⟨⟨zeta_i4, re4, scratch1⟩, h4_eq, h4_zout, h4_fc, h4_bnd⟩ := + triple_exists_ok_fc + (invert_ntt_at_layer_4_plus_portable_fc zeta_i3 re3 (4#usize : Std.Usize) scratch + (by decide) h3_bnd + (by refine ⟨?_, ?_⟩ <;> · rw [h_zeta_i3]; decide)) + dsimp only at h4_zout h4_fc h4_bnd + have h_zeta_i4 : zeta_i4.val = 8 := by + rw [h4_zout, h_zeta_i3]; decide + -- ============================================================= + -- Step 5: invert_ntt_at_layer_4_plus (layer = 5). zeta_i = 8 → 4. + -- 128 >>> 5 = 4. + -- ============================================================= + obtain ⟨⟨zeta_i5, re5, scratch2⟩, h5_eq, h5_zout, h5_fc, h5_bnd⟩ := + triple_exists_ok_fc + (invert_ntt_at_layer_4_plus_portable_fc zeta_i4 re4 (5#usize : Std.Usize) scratch1 + (by decide) h4_bnd + (by refine ⟨?_, ?_⟩ <;> · rw [h_zeta_i4]; decide)) + dsimp only at h5_zout h5_fc h5_bnd + have h_zeta_i5 : zeta_i5.val = 4 := by + rw [h5_zout, h_zeta_i4]; decide + -- ============================================================= + -- Step 6: invert_ntt_at_layer_4_plus (layer = 6). zeta_i = 4 → 2. + -- 128 >>> 6 = 2. + -- ============================================================= + obtain ⟨⟨zeta_i6, re6, scratch3⟩, h6_eq, h6_zout, h6_fc, h6_bnd⟩ := + triple_exists_ok_fc + (invert_ntt_at_layer_4_plus_portable_fc zeta_i5 re5 (6#usize : Std.Usize) scratch2 + (by decide) h5_bnd + (by refine ⟨?_, ?_⟩ <;> · rw [h_zeta_i5]; decide)) + dsimp only at h6_zout h6_fc h6_bnd + have h_zeta_i6 : zeta_i6.val = 2 := by + rw [h6_zout, h_zeta_i5]; decide + -- ============================================================= + -- Step 7: invert_ntt_at_layer_4_plus (layer = 7). zeta_i = 2 → 1. + -- 128 >>> 7 = 1. + -- ============================================================= + obtain ⟨⟨_zeta_i7, re7, scratch4⟩, h7_eq, _h7_zout, h7_fc, h7_bnd⟩ := + triple_exists_ok_fc + (invert_ntt_at_layer_4_plus_portable_fc zeta_i6 re6 (7#usize : Std.Usize) scratch3 + (by decide) h6_bnd + (by refine ⟨?_, ?_⟩ <;> · rw [h_zeta_i6]; decide)) + dsimp only at h7_fc h7_bnd + -- ============================================================= + -- Compose: derive the full impl `do`-block equation by simp-folding + -- all step equations into the unfolded body. + -- ============================================================= + have h_body : + libcrux_iot_ml_kem.invert_ntt.invert_ntt_montgomery K + (vectortraitsOperationsInst := portable_ops_inst) re scratch + = .ok (re7, scratch4) := by + unfold libcrux_iot_ml_kem.invert_ntt.invert_ntt_montgomery + simp [h_div, h1_eq, h2_eq, h3_eq, h4_eq, h5_eq, h6_eq, h7_eq] + apply triple_of_ok_fc h_body + -- POST is now a conjunction: equality (proved below) ∧ per-lane bound + -- (≤ 3328, exactly the layer-7 output bound `h7_bnd`, since `p.1 = re7`). + refine ⟨?_, h7_bnd⟩ + -- ============================================================= + -- Prove lift_poly equation by chaining FC equations through + -- `Spec.invert_ntt_montgomery_pure`. + -- ============================================================= + show lift_poly re7 = Spec.invert_ntt_montgomery_pure (lift_poly re) + unfold Spec.invert_ntt_montgomery_pure + -- Identify each zeta_i with the spec's literal value. + have h_zeta_eq1 : zeta_i1 = (64#usize : Std.Usize) := by + apply Aeneas.Std.UScalar.eq_of_val_eq + rw [h_zeta_i1]; decide + have h_zeta_eq2 : zeta_i2 = (32#usize : Std.Usize) := by + apply Aeneas.Std.UScalar.eq_of_val_eq + rw [h_zeta_i2]; decide + have h_zeta_eq3 : zeta_i3 = (16#usize : Std.Usize) := by + apply Aeneas.Std.UScalar.eq_of_val_eq + rw [h_zeta_i3]; decide + have h_zeta_eq4 : zeta_i4 = (8#usize : Std.Usize) := by + apply Aeneas.Std.UScalar.eq_of_val_eq + rw [h_zeta_i4]; decide + have h_zeta_eq5 : zeta_i5 = (4#usize : Std.Usize) := by + apply Aeneas.Std.UScalar.eq_of_val_eq + rw [h_zeta_i5]; decide + have h_zeta_eq6 : zeta_i6 = (2#usize : Std.Usize) := by + apply Aeneas.Std.UScalar.eq_of_val_eq + rw [h_zeta_i6]; decide + rw [h7_fc, h6_fc, h5_fc, h4_fc, h3_fc, h2_fc, h1_fc, + h_zeta_eq1, h_zeta_eq2, h_zeta_eq3, h_zeta_eq4, h_zeta_eq5, h_zeta_eq6] + + +/-- L3.4 — `ntt_vector_u` driver (4 layer_4_plus calls + 3 dedicated layers + + barrett reduce, used for the encryption "u" vector NTT). Note that the + impl's first call is `ntt_at_layer_4_plus(0, layer=7)` (Mont multiply + through `ZETAS_TIMES_MONTGOMERY_R[1]`), not the dedicated + `ntt_at_layer_7` (plain multiply with `-1600`). The two paths produce + the same field element in `ZMod 3329` (see `Spec.zeta_at_one_eq_layer_7`) + but differ structurally; we target the spec actually computed by the + impl, `Spec.ntt_pure_vec_u`. -/ +@[spec] +theorem ntt_vector_u_fc + (VECTOR_U_COMPRESSION_FACTOR : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_vector_u + VECTOR_U_COMPRESSION_FACTOR + (vectortraitsOperationsInst := portable_ops_inst) re scratch + ⦃ ⇓ p => ⌜ lift_poly p.1 = Spec.ntt_pure_vec_u (lift_poly re) ⌝ ⦄ := by + -- Strategy: mirror L3.3, but use `ntt_at_layer_4_plus_portable_fc_strong` + -- with `zeta_i = 0, layer = 7` for the FIRST step (impl uses layer_4_plus + -- at layer 7, not the dedicated layer_7), and target `Spec.ntt_pure_vec_u`. + -- ============================================================= + -- Step 1: layer_4_plus (zeta_i=0, layer=7, bnd=3328). ≤ 3328 → ≤ 6656. + -- ============================================================= + obtain ⟨⟨zeta_i1, re1, scratch1⟩, h1_eq, h1_fc, h1_zout, h1_bnd⟩ := + triple_exists_ok_fc + (ntt_at_layer_4_plus_portable_fc_strong 0#usize re 7#usize scratch 3328#usize + (by decide) (by decide) (by decide) h_bnd) + dsimp only at h1_fc h1_zout h1_bnd + have h_zeta_i1 : zeta_i1.val = 1 := by rw [h1_zout]; decide + -- ============================================================= + -- Step 2: usize_mul 2 * 3328 = 6656. + -- ============================================================= + obtain ⟨i6656, hi6656_eq, hi6656_val⟩ := + usize_mul_ok_eq_fc 2#usize 3328#usize (by scalar_tac) + -- ============================================================= + -- Step 3: layer_4_plus (zeta_i1=1, layer=6, bnd=6656). ≤ 6656 → ≤ 9984. + -- ============================================================= + have h_re1_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re1.coefficients.val[i]!).elements.val[j]!).val.natAbs + ≤ i6656.val := by + intro i hi j hj + have hb := h1_bnd i hi j hj + have h3328 : (3328#usize : Std.Usize).val = 3328 := by decide + have h2 : (2#usize : Std.Usize).val = 2 := by decide + rw [h3328] at hb + rw [hi6656_val, h2, h3328] + omega + have h_i6656_bnd : i6656.val ≤ 8 * 3328 := by + have h := hi6656_val + have h2 : (2#usize : Std.Usize).val = 2 := by decide + have h3328 : (3328#usize : Std.Usize).val = 3328 := by decide + rw [h2, h3328] at h + omega + obtain ⟨⟨zeta_i2, re2, scratch2⟩, h3_eq, h3_fc, h3_zout, h3_bnd⟩ := + triple_exists_ok_fc + (ntt_at_layer_4_plus_portable_fc_strong zeta_i1 re1 6#usize scratch1 i6656 + (by decide) h_i6656_bnd + (by rw [h_zeta_i1]; decide) h_re1_loose) + dsimp only at h3_fc h3_zout h3_bnd + have h_zeta_i2 : zeta_i2.val = 3 := by + rw [h3_zout, h_zeta_i1]; decide + -- ============================================================= + -- Step 4: usize_mul 3 * 3328 = 9984. + -- ============================================================= + obtain ⟨i9984, hi9984_eq, hi9984_val⟩ := + usize_mul_ok_eq_fc 3#usize 3328#usize (by scalar_tac) + -- ============================================================= + -- Step 5: layer_4_plus (zeta_i2=3, layer=5, bnd=9984). ≤ 9984 → ≤ 13312. + -- ============================================================= + have h_re2_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re2.coefficients.val[i]!).elements.val[j]!).val.natAbs + ≤ i9984.val := by + intro i hi j hj + have hb := h3_bnd i hi j hj + have h3328 : (3328#usize : Std.Usize).val = 3328 := by decide + have h2 : (2#usize : Std.Usize).val = 2 := by decide + have h3 : (3#usize : Std.Usize).val = 3 := by decide + rw [hi6656_val, h2, h3328] at hb + rw [hi9984_val, h3, h3328] + omega + have h_i9984_bnd : i9984.val ≤ 8 * 3328 := by + have h := hi9984_val + have h3 : (3#usize : Std.Usize).val = 3 := by decide + have h3328 : (3328#usize : Std.Usize).val = 3328 := by decide + rw [h3, h3328] at h + omega + obtain ⟨⟨zeta_i3, re3, scratch3⟩, h5_eq, h5_fc, h5_zout, h5_bnd⟩ := + triple_exists_ok_fc + (ntt_at_layer_4_plus_portable_fc_strong zeta_i2 re2 5#usize scratch2 i9984 + (by decide) h_i9984_bnd + (by rw [h_zeta_i2]; decide) h_re2_loose) + dsimp only at h5_fc h5_zout h5_bnd + have h_zeta_i3 : zeta_i3.val = 7 := by + rw [h5_zout, h_zeta_i2]; decide + -- ============================================================= + -- Step 6: usize_mul 4 * 3328 = 13312. + -- ============================================================= + obtain ⟨i13312, hi13312_eq, hi13312_val⟩ := + usize_mul_ok_eq_fc 4#usize 3328#usize (by scalar_tac) + -- ============================================================= + -- Step 7: layer_4_plus (zeta_i3=7, layer=4, bnd=13312). ≤ 13312 → ≤ 16640. + -- ============================================================= + have h_re3_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re3.coefficients.val[i]!).elements.val[j]!).val.natAbs + ≤ i13312.val := by + intro i hi j hj + have hb := h5_bnd i hi j hj + have h3328 : (3328#usize : Std.Usize).val = 3328 := by decide + have h3 : (3#usize : Std.Usize).val = 3 := by decide + have h4 : (4#usize : Std.Usize).val = 4 := by decide + rw [hi9984_val, h3, h3328] at hb + rw [hi13312_val, h4, h3328] + omega + have h_i13312_bnd : i13312.val ≤ 8 * 3328 := by + have h := hi13312_val + have h4 : (4#usize : Std.Usize).val = 4 := by decide + have h3328 : (3328#usize : Std.Usize).val = 3328 := by decide + rw [h4, h3328] at h + omega + obtain ⟨⟨zeta_i4, re4, scratch4⟩, h7_eq, h7_fc, h7_zout, h7_bnd⟩ := + triple_exists_ok_fc + (ntt_at_layer_4_plus_portable_fc_strong zeta_i3 re3 4#usize scratch3 i13312 + (by decide) h_i13312_bnd + (by rw [h_zeta_i3]; decide) h_re3_loose) + dsimp only at h7_fc h7_zout h7_bnd + have h_zeta_i4 : zeta_i4.val = 15 := by + rw [h7_zout, h_zeta_i3]; decide + -- ============================================================= + -- Step 8: usize_mul 5 * 3328 = 16640. + -- ============================================================= + obtain ⟨i16640, hi16640_eq, hi16640_val⟩ := + usize_mul_ok_eq_fc 5#usize 3328#usize (by scalar_tac) + -- ============================================================= + -- Step 9: layer_3 (zeta_i4=15, bnd=16640 Nat). → ≤ 19968. zeta_out=31. + -- ============================================================= + have h_re4_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re4.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 16640 := by + intro i hi j hj + have hb := h7_bnd i hi j hj + have h3328 : (3328#usize : Std.Usize).val = 3328 := by decide + have h4 : (4#usize : Std.Usize).val = 4 := by decide + rw [hi13312_val, h4, h3328] at hb + omega + obtain ⟨⟨zeta_i5, re5⟩, h9_eq, h9_fc, h9_zout, h9_bnd⟩ := + triple_exists_ok_fc + (ntt_at_layer_3_portable_fc_strong zeta_i4 re4 i16640 16640 + (by decide) h_zeta_i4 h_re4_loose) + dsimp only at h9_fc h9_zout h9_bnd + -- ============================================================= + -- Step 10: usize_mul 6 * 3328 = 19968. + -- ============================================================= + obtain ⟨i19968, hi19968_eq, hi19968_val⟩ := + usize_mul_ok_eq_fc 6#usize 3328#usize (by scalar_tac) + -- ============================================================= + -- Step 11: layer_2 (zeta_i5=31, bnd=19968 Nat). → ≤ 23296. zeta_out=63. + -- ============================================================= + have h_re5_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re5.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 19968 := by + intro i hi j hj + have hb := h9_bnd i hi j hj + omega + obtain ⟨⟨zeta_i6, re6⟩, h11_eq, h11_fc, h11_zout, h11_bnd⟩ := + triple_exists_ok_fc + (ntt_at_layer_2_portable_fc_strong zeta_i5 re5 i19968 19968 + (by decide) h9_zout h_re5_loose) + dsimp only at h11_fc h11_zout h11_bnd + -- ============================================================= + -- Step 12: usize_mul 7 * 3328 = 23296. + -- ============================================================= + obtain ⟨i23296, hi23296_eq, hi23296_val⟩ := + usize_mul_ok_eq_fc 7#usize 3328#usize (by scalar_tac) + -- ============================================================= + -- Step 13: layer_1 (zeta_i6=63, bnd=23296 Nat). → ≤ 26624. zeta_out=127. + -- ============================================================= + have h_re6_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re6.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 23296 := by + intro i hi j hj + have hb := h11_bnd i hi j hj + omega + obtain ⟨⟨_zeta_i7, re7⟩, h13_eq, h13_fc, _h13_zout, h13_bnd⟩ := + triple_exists_ok_fc + (ntt_at_layer_1_portable_fc_strong zeta_i6 re6 i23296 23296 + (by decide) h11_zout h_re6_loose) + dsimp only at h13_fc h13_bnd + -- ============================================================= + -- Step 14: poly_barrett_reduce. ≤ 26624 ≤ 32767 → canonical residue. + -- ============================================================= + have h_re7_loose : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((re7.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro chunk hc ℓ hℓ + have hb := h13_bnd chunk hc ℓ hℓ + omega + obtain ⟨re8, h14_eq, h14_fc⟩ := + triple_exists_ok_fc (poly_barrett_reduce_fc re7 h_re7_loose) + -- ============================================================= + -- Compose: derive the full impl `do`-block equation by simp-folding + -- all step equations into the unfolded body. + -- ============================================================= + have h_body : + libcrux_iot_ml_kem.ntt.ntt_vector_u + VECTOR_U_COMPRESSION_FACTOR + (vectortraitsOperationsInst := portable_ops_inst) re scratch + = .ok (re8, scratch4) := by + unfold libcrux_iot_ml_kem.ntt.ntt_vector_u + simp [h1_eq, h3_eq, h5_eq, h7_eq, h9_eq, h11_eq, h13_eq, h14_eq, + hi6656_eq, hi9984_eq, hi13312_eq, + hi16640_eq, hi19968_eq, hi23296_eq] + apply triple_of_ok_fc h_body + -- ============================================================= + -- Prove lift_poly equation by chaining FC equations through Spec.ntt_pure_vec_u. + -- ============================================================= + show lift_poly re8 = Spec.ntt_pure_vec_u (lift_poly re) + unfold Spec.ntt_pure_vec_u + -- Bridge barrett: h14_fc : poly_barrett_reduce (lift_poly re7) = .ok (lift_poly re8). + have hB_bridge : + hacspec_ml_kem.polynomial.poly_barrett_reduce (lift_poly re7) + = .ok (Spec.Pure.polynomial.poly_barrett_reduce_pure (lift_poly re7)) := + Spec.Pure.polynomial.poly_barrett_reduce_eq_ok (lift_poly re7) + rw [hB_bridge] at h14_fc + have h_re8_eq : lift_poly re8 + = Spec.Pure.polynomial.poly_barrett_reduce_pure (lift_poly re7) := by + have h := h14_fc + exact (Aeneas.Std.Result.ok.injEq _ _).mp h.symm + -- zeta_i identifications: substitute zeta values into the spec chain via .val. + have h_zeta_eq1 : zeta_i1 = 1#usize := by + have := h_zeta_i1; scalar_tac + have h_zeta_eq2 : zeta_i2 = 3#usize := by + have := h_zeta_i2; scalar_tac + have h_zeta_eq3 : zeta_i3 = 7#usize := by + have := h_zeta_i3; scalar_tac + have h_zeta_eq4 : zeta_i4 = 15#usize := by + have := h_zeta_i4; scalar_tac + have h_zeta_eq5 : zeta_i5 = 31#usize := by + have := h9_zout; scalar_tac + have h_zeta_eq6 : zeta_i6 = 63#usize := by + have := h11_zout; scalar_tac + rw [h_re8_eq, h13_fc, h11_fc, h9_fc, h7_fc, h5_fc, h3_fc, h1_fc, + h_zeta_eq1, h_zeta_eq2, h_zeta_eq3, h_zeta_eq4, h_zeta_eq5, h_zeta_eq6] + +/-! ### L6.2.A — Loop scaffolding for `subtract_reduce_fc`. + + Strengthened FC invariant for the 16-iter chunk-loop. Each iteration + `i ∈ 0..16` applies the fused chain + `b[i] := barrett (negate ((mont_mul b[i] 1441) - self[i]))` + leaving `b[j]` for `j ≠ i` untouched (and `self` immutable throughout). + + The chunk-level closure for chunk `i` is + `lift_chunk b'[i] = Spec.chunk_subtract_reduce_pure + (lift_chunk self[i]) (lift_chunk b[i])` + where `Spec.chunk_subtract_reduce_pure` (defined in §0.5) is the + 16-lane version of `self - b * lift_fe_mont(1441)`. -/ + + +end libcrux_iot_ml_kem.InvertNtt diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/Common.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/Common.lean new file mode 100644 index 00000000..38694e16 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/Common.lean @@ -0,0 +1,419 @@ +/- + # `Matrix/Common.lean` — shared L7.4 scaffolding. + + Holds the small shared definitions used by the L7.4 `compute_message` + proof and (prospectively) reused by L7.2/L7.3: + + * `Impl.compute_message_zero` — the all-zero canonical-domain poly used + as the accumulator fold seed (mirrors the impl's `accumulator1 := + Array.repeat 256 (classify 0)` re-zero at `matrix.rs:96`). + + These live above the `Impl.*_pure` mirror (defined in + `L7/Impl/ComputeMessage.lean`) and the bridge lemmas (in + `Matrix/ComputeMessage/Hacspec.lean`). + + SKELETON — no proofs beyond what is needed for these defs to + elaborate. The named obligations live in the Impl/Hacspec/FC files. +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc + +namespace libcrux_iot_ml_kem.Matrix.Common +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt + +/-- The all-zero canonical-domain ring element (256 lanes, each + `FieldElement.val = 0`). This is the fold seed for + `Impl.compute_message_acc_pure`, mirroring the impl's explicit + accumulator re-zero (`matrix.rs:96`, modeled in `` as + `Array.repeat 256#usize (classify 0#i32)`). -/ +noncomputable def Impl.compute_message_zero : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Std.Array.make 256#usize + ((List.range 256).map (fun _ => ({ val := 0#u16 } : hacspec_ml_kem.parameters.FieldElement))) + (by rw [List.length_map, List.length_range]; rfl) + +/-! ## Mont→canonical BRIDGE for the `reducing_from_i32_array` step. + + `poly_reducing_from_i32_array_fc` characterizes its output `result1` + in the `lift_poly_mont` domain (`lift_poly_mont result1 = …pure`), but + `invert_ntt_montgomery_fc` consumes `result1` in the `lift_poly` domain + (`lift_poly result1`). The two differ by one Montgomery factor `R` per + lane. The impl's `montgomery_multiply_by_constant 1353` (= `R²` mod q) + convention used by the L6.3a finalizer means the canonical lane value + is `mul_pure (mont-lane) (lift_fe_mont 1353)` (since + `1353 ≡ R² (mod q)` and `lift_fe_mont` carries an `R⁻¹`, the product is + `(a·R⁻¹)·(R²·R⁻¹) = a`). + + `Impl.mont_strip_pure` is the poly-level BRIDGE; the bridge lemma + `Impl.mont_strip_lift_poly_mont_eq_lift_poly` re-derives FCTargets' + `private lift_poly_mont_to_lift_poly` from public primitives so the L7 + files (which cannot see the private original) can apply it. -/ + +/-- Local copy of `Spec.Pure.uscalar_rem_ok_U32` (private there); the L7 + files re-derive it from `BitVec.umod` to reprove `mul_pure_val_eq`. -/ +private theorem Impl.uscalar_rem_ok_U32 (z m : Std.U32) (hm : m.val ≠ 0) : + ∃ w : Std.U32, (z % m : Result Std.U32) = .ok w ∧ w.val = z.val % m.val := by + have heq : (z % m : Result Std.U32) = Std.UScalar.rem z m := rfl + unfold Std.UScalar.rem at heq + simp [hm] at heq + refine ⟨_, heq, ?_⟩ + show (BitVec.umod z.bv m.bv).toNat = z.val % m.val + unfold BitVec.umod + simp only [BitVec.toNat_ofNatLT] + rfl + +/-- Local copy of FCTargets' `private mul_pure_val_eq`: + `(mul_pure a b).val.val = (a.val.val * b.val.val) % 3329`, + unconditional (the U32 widening keeps the product `< 2^32`). -/ +private theorem Impl.mul_pure_val_eq + (a b : hacspec_ml_kem.parameters.FieldElement) : + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a b).val.val + = (a.val.val * b.val.val) % 3329 := by + have hmul : + hacspec_ml_kem.parameters.FieldElement.mul a b + = .ok (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a b) := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_eq_ok a b + unfold hacspec_ml_kem.parameters.FieldElement.mul at hmul + simp only [Aeneas.Std.lift, Aeneas.Std.bind_tc_ok] at hmul + have hA := a.val.hBounds; have hB := b.val.hBounds + simp [Aeneas.Std.UScalarTy.numBits] at hA hB + set x : Std.U32 := Std.UScalar.cast .U32 a.val + set y : Std.U32 := Std.UScalar.cast .U32 b.val + have hxval : x.val = a.val.val := Std.U16.cast_U32_val_eq a.val + have hyval : y.val = b.val.val := Std.U16.cast_U32_val_eq b.val + have hae := Std.UScalar.mul_equiv x y + have heqmul : (x * y : Result Std.U32) = Std.UScalar.mul x y := rfl + cases hxy : (x * y : Result Std.U32) with + | ok z => + rw [hxy] at hmul + rw [heqmul] at hxy; rw [hxy] at hae; simp at hae + obtain ⟨_, hzval, _⟩ := hae + simp only [Aeneas.Std.bind_tc_ok] at hmul + have hmod_val : + (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS).val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; simp + have hmod_ne : + (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS).val ≠ 0 := by + rw [hmod_val]; decide + set m : Std.U32 := Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS + obtain ⟨w, hw_eq, hwval⟩ := Impl.uscalar_rem_ok_U32 z m hmod_ne + rw [hw_eq] at hmul; simp only [Aeneas.Std.bind_tc_ok] at hmul + unfold hacspec_ml_kem.parameters.FieldElement.new at hmul + simp at hmul + have hwbnd : w.val < 3329 := by + rw [hwval, hmod_val]; exact Nat.mod_lt _ (by decide) + have hwcast : (Std.UScalar.cast .U16 w).val = w.val := by + apply Std.UScalar.cast_val_mod_pow_of_inBounds_eq + simp [Aeneas.Std.UScalarTy.numBits]; omega + rw [← hmul] + show (Std.UScalar.cast .U16 w).val = (a.val.val * b.val.val) % 3329 + rw [hwcast, hwval, hmod_val, hzval, hxval, hyval] + | fail _ => + rw [heqmul] at hxy; rw [hxy] at hae + simp only [Std.UScalar.max, Aeneas.Std.UScalarTy.numBits] at hae + rw [hxval, hyval] at hae + have : a.val.val * b.val.val < 2^32 := by + have h1 : a.val.val * b.val.val ≤ (2^16 - 1) * (2^16 - 1) := by + apply Nat.mul_le_mul <;> omega + have heq : (2^16 - 1) * (2^16 - 1) = 2^32 - 2*2^16 + 1 := by decide + omega + omega + | div => rw [heqmul] at hxy; rw [hxy] at hae; exact hae.elim + +/-- `zmodOfFE` distributes over `mul_pure` (public re-derivation of + FCTargets' `private L2_8c.zmodOfFE_mul_pure`). -/ +private theorem Impl.zmodOfFE_mul_pure + (a b : hacspec_ml_kem.parameters.FieldElement) : + zmodOfFE (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a b) + = zmodOfFE a * zmodOfFE b := by + unfold zmodOfFE + rw [Impl.mul_pure_val_eq] + rw [ZMod.natCast_mod] + push_cast + rfl + +/-- `zmodOfFE (lift_fe_mont x) = x.val · 169` (public re-derivation of + FCTargets' `private L2_8c.zmodOfFE_lift_fe_mont`). -/ +private theorem Impl.zmodOfFE_lift_fe_mont (x : Std.I16) : + zmodOfFE (lift_fe_mont x) = (x.val : ZMod 3329) * 169 := by + unfold lift_fe_mont + rw [zmodOfFE_feOfZMod] + rfl + +/-- FE-level Mont→canonical bridge: + `mul_pure (lift_fe_mont x) (lift_fe_mont 1353) = lift_fe x`. + In `ZMod 3329`: `lift_fe_mont y = y·169` and `1353·169·169 ≡ 1`, so the + product canonically round-trips to `x`. Reproves the `private` + `lift_fe_mont_mul_1353_eq_lift_fe` from public lemmas. -/ +theorem Impl.lift_fe_mont_mul_1353_eq_lift_fe (x : Std.I16) : + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont x) (lift_fe_mont (1353#i16 : Std.I16)) + = lift_fe x := by + set s : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont x) (lift_fe_mont (1353#i16 : Std.I16)) with hs_def + -- (1) `s` is canonical (`Canonical_mul_pure` is unconditional). + have h_canon : s.val.val < 3329 := by + have h_cs := libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure + (lift_fe_mont x) (lift_fe_mont (1353#i16 : Std.I16)) + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cs + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at h_cs + exact h_cs + -- (2) Canonical round-trip `feOfZMod (zmodOfFE s) = s`. + have h_round_trip : feOfZMod (zmodOfFE s) = s := by + unfold feOfZMod zmodOfFE + have hzval : ((s.val.val : ZMod 3329)).val = s.val.val := + ZMod.val_natCast_of_lt h_canon + rw [hzval] + have hsval : s.val.val < 2 ^ 16 := by + have h_p : (3329 : Nat) ≤ 2 ^ 16 := by decide + omega + have hsbv : BitVec.ofNat 16 s.val.val = s.val.bv := by + apply BitVec.eq_of_toNat_eq + rw [BitVec.toNat_ofNat] + show s.val.val % 2 ^ 16 = s.val.bv.toNat + rw [Nat.mod_eq_of_lt hsval]; rfl + show ({ val := ⟨BitVec.ofNat 16 s.val.val⟩ } : + hacspec_ml_kem.parameters.FieldElement) = s + rw [hsbv] + -- (3) `zmodOfFE s = (x.val : ZMod 3329)`. + have h_zmod_s : zmodOfFE s = ((x.val : Int) : ZMod 3329) := by + rw [hs_def, Impl.zmodOfFE_mul_pure, + Impl.zmodOfFE_lift_fe_mont, Impl.zmodOfFE_lift_fe_mont] + have h_1353 : (((1353#i16 : Std.I16).val : Int) : ZMod 3329) = 1353 := by + decide + rw [h_1353] + have h_inv : (169 : ZMod 3329) * (1353 * 169) = 1 := by decide + calc ((x.val : Int) : ZMod 3329) * 169 * (1353 * 169) + = ((x.val : Int) : ZMod 3329) * (169 * (1353 * 169)) := by ring + _ = ((x.val : Int) : ZMod 3329) * 1 := by rw [h_inv] + _ = ((x.val : Int) : ZMod 3329) := by ring + -- (4) Glue: `s = feOfZMod (zmodOfFE s) = lift_fe x`. + show s = lift_fe x + rw [← h_round_trip, h_zmod_s] + unfold lift_fe i16_to_spec_fe_plain + rfl + +/-- Poly-level Mont→canonical bridge function. Maps each of the 256 lanes + through `mul_pure · (lift_fe_mont 1353)`. -/ +noncomputable def Impl.mont_strip_pure + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Std.Array.make 256#usize + ((List.range 256).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (p.val[i]!) (lift_fe_mont (1353#i16 : Std.I16)))) + (by simp) + +/-- Poly-level Mont→canonical BRIDGE law: + `mont_strip_pure (lift_poly_mont re) = lift_poly re`. + Reproves FCTargets' `private lift_poly_mont_to_lift_poly` (poly form) + from the FE-level helper. This is the lemma that lets S2 connect the + `reducing_from_i32_array` POST (stated via `lift_poly_mont`) to the + `invert_ntt_montgomery` PRE (stated via `lift_poly`). -/ +theorem Impl.mont_strip_lift_poly_mont_eq_lift_poly + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Impl.mont_strip_pure (libcrux_iot_ml_kem.Spec.Lift.lift_poly_mont re) = lift_poly re := by + unfold Impl.mont_strip_pure + apply Subtype.ext + show (List.range 256).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((libcrux_iot_ml_kem.Spec.Lift.lift_poly_mont re).val[i]!) (lift_fe_mont (1353#i16 : Std.I16))) + = (lift_poly re).val + unfold lift_poly + show (List.range 256).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((libcrux_iot_ml_kem.Spec.Lift.lift_poly_mont re).val[i]!) (lift_fe_mont (1353#i16 : Std.I16))) + = (List.range 256).map (fun j => + lift_fe (re.coefficients.val[j / 16]!).elements.val[j % 16]!) + apply List.ext_getElem + · simp + · intro j hj1 _hj2 + have hj : j < 256 := by + have : j < ((List.range 256).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((libcrux_iot_ml_kem.Spec.Lift.lift_poly_mont re).val[i]!) (lift_fe_mont (1353#i16 : Std.I16)))).length := hj1 + simpa using this + simp only [List.getElem_map, List.getElem_range] + -- LHS lane = mul_pure (lift_fe_mont x) (lift_fe_mont 1353); RHS = lift_fe x. + set x : Std.I16 := + (re.coefficients.val[j / 16]!).elements.val[j % 16]! with hx_def + have h_mont : (libcrux_iot_ml_kem.Spec.Lift.lift_poly_mont re).val[j]! = lift_fe_mont x := by + unfold libcrux_iot_ml_kem.Spec.Lift.lift_poly_mont + show ((List.range 256).map (fun k => + lift_fe_mont (re.coefficients.val[k / 16]!).elements.val[k % 16]!))[j]! + = lift_fe_mont x + have h_len : ((List.range 256).map (fun k => + lift_fe_mont (re.coefficients.val[k / 16]!).elements.val[k % 16]!)).length = 256 := by + simp + rw [getElem!_pos _ j (by rw [h_len]; exact hj)] + rw [List.getElem_map, List.getElem_range] + rw [h_mont] + exact Impl.lift_fe_mont_mul_1353_eq_lift_fe x + +end libcrux_iot_ml_kem.Matrix.Common +/-! ### Extracted from FCTargets.lean (§matrix_entry). -/ + +namespace libcrux_iot_ml_kem.Matrix.Common +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec + +/-! ## §L6.8 — matrix per-cell accessor. + + `matrix.entry K matrix i j` is a pure indexing op on a flat K·K slice + of polynomial-ring elements. The FC equation lifts the result via + `lift_poly` and matches the (i, j)-th matrix entry, which under + `lift_matrix_from_slice`'s column-major convention is accessed as + `L.val[j.val]!.val[i.val]!` (outer = column, inner = row). Uses the + file-scoped `Inhabited` instances `instInhabitedFEPoly_fcTargets` and + `instInhabitedFEPolyVec_fcTargets` (declared next to + `instInhabitedFEChunk_fcTargets`). -/ + +/-- Pure-projection side lemma for `matrix.entry`. Reduces the impl `do`-block + to a single `Slice.index_usize` at row-major offset `i.val * K.val + j.val`, + under the canonical preconditions `matrix.length = K·K`, `i < K`, `j < K`. + Named without a `matrix.` prefix to avoid Lean's dot-notation projection + being triggered when a local variable `matrix` is in scope at the call + site (in `matrix.entry_fc`). -/ +theorem entry_eq_ok_fc_aux + (K : Std.Usize) + (matrix : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (i j : Std.Usize) + (h_len : matrix.val.length = K.val * K.val) + (h_i : i.val < K.val) (h_j : j.val < K.val) : + libcrux_iot_ml_kem.matrix.entry K portable_ops_inst matrix i j + = .ok (matrix.val[i.val * K.val + j.val]!) := by + -- Slice invariant + h_len combine to give the arithmetic bounds. + have h_slice_max : matrix.val.length ≤ Std.Usize.max := matrix.property + have h_KK_max : K.val * K.val ≤ Std.Usize.max := by rw [← h_len]; exact h_slice_max + -- K.val > 0 because i.val < K.val. + have h_K_pos : 0 < K.val := Nat.lt_of_le_of_lt (Nat.zero_le _) h_i + -- i.val * K.val ≤ (K.val - 1) * K.val < K.val * K.val. + have h_iK_lt : i.val * K.val < K.val * K.val := + (Nat.mul_lt_mul_right h_K_pos).mpr h_i + have h_iK_max : i.val * K.val ≤ Std.Usize.max := by + apply le_trans (Nat.le_of_lt h_iK_lt) h_KK_max + -- i.val * K.val + j.val < K.val * K.val. + have h_idx_lt_KK : i.val * K.val + j.val < K.val * K.val := by + have : i.val * K.val + j.val < i.val * K.val + K.val := Nat.add_lt_add_left h_j _ + have h_step : i.val * K.val + K.val ≤ K.val * K.val := by + have : (i.val + 1) * K.val ≤ K.val * K.val := + Nat.mul_le_mul_right _ h_i + have h_expand : (i.val + 1) * K.val = i.val * K.val + K.val := by ring + rw [h_expand] at this; exact this + omega + have h_idx_max : i.val * K.val + j.val ≤ Std.Usize.max := by + apply le_trans (Nat.le_of_lt h_idx_lt_KK) h_KK_max + have h_idx_lt_len : i.val * K.val + j.val < matrix.val.length := by + rw [h_len]; exact h_idx_lt_KK + -- Now reduce the do-block step by step. + unfold libcrux_iot_ml_kem.matrix.entry + -- Step 1: `core.slice.Slice.len matrix` = `.ok matrix.len`. + unfold core.slice.Slice.len + -- Step 2: `K * K` = `.ok` of a Usize with val = K.val * K.val. + obtain ⟨kk, h_kk_eq, h_kk_val⟩ := usize_mul_ok_eq_fc K K h_KK_max + -- Step 3: `i * K` = `.ok` of a Usize with val = i.val * K.val. + obtain ⟨ik, h_ik_eq, h_ik_val⟩ := usize_mul_ok_eq_fc i K h_iK_max + -- Step 4: `ik + j` = `.ok` of a Usize with val = i.val * K.val + j.val. + have h_ikj_max : ik.val + j.val ≤ Std.Usize.max := by rw [h_ik_val]; exact h_idx_max + obtain ⟨idx, h_idx_eq, h_idx_val⟩ := usize_add_ok_eq_fc ik j h_ikj_max + -- Massert preconditions. + have h_massert_len : (Aeneas.Std.Slice.len matrix : Std.Usize) = kk := by + apply Std.UScalar.eq_of_val_eq + show matrix.val.length = kk.val + rw [h_kk_val, h_len] + -- Slice.index_usize at idx returns matrix.val[idx.val]!. + have h_idx_lt_matrix : idx.val < matrix.val.length := by + rw [h_idx_val, h_ik_val]; exact h_idx_lt_len + have h_slice_idx : + Aeneas.Std.Slice.index_usize matrix idx = .ok (matrix.val[idx.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq matrix idx h_idx_lt_matrix + -- Rewrite the do-block. + simp only [pure, Pure.pure, Aeneas.Std.bind_tc_ok, h_kk_eq, h_ik_eq, h_idx_eq, + h_slice_idx, h_massert_len] + -- Discharge massert (i = kk equality), massert (i < K), massert (j < K). + unfold Aeneas.Std.massert + have h_i_K : i < K := (Std.UScalar.lt_equiv i K).mpr h_i + have h_j_K : j < K := (Std.UScalar.lt_equiv j K).mpr h_j + simp only [if_true, Aeneas.Std.bind_tc_ok, h_i_K, h_j_K] + -- Final goal: matrix.val[idx.val]! = matrix.val[i.val * K.val + j.val]!. + rw [h_idx_val, h_ik_val] + +/-- L6.8 — `matrix.entry`: row-major access of a flat K·K poly slice. + The FC equation says the impl's returned `PolynomialRingElement` + lifts (via `lift_poly`) to the `(i, j)`-th matrix entry, accessed + under `lift_matrix_from_slice`'s column-major convention as + `L.val[j.val]!.val[i.val]!` (outer = column, inner = row). -/ +@[spec] +theorem matrix.entry_fc + (K : Std.Usize) + (matrix : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (i j : Std.Usize) + (h_len : matrix.val.length = K.val * K.val) + (h_i : i.val < K.val) (h_j : j.val < K.val) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.entry K portable_ops_inst matrix i j + ⦃ ⇓ r => ⌜ lift_poly r = (lift_matrix_from_slice matrix K).val[j.val]!.val[i.val]! ⌝ ⦄ := by + apply triple_of_ok_fc (entry_eq_ok_fc_aux K matrix i j h_len h_i h_j) + -- Goal: lift_poly matrix.val[i.val * K.val + j.val]! + -- = (lift_matrix_from_slice matrix K).val[j.val]!.val[i.val]! + -- Reduce the matrix lift's nested `Std.Array.make` constructions explicitly. + -- `(Std.Array.make n init _).val = init` (definitional), so each outer + -- `.val[idx]!` collapses to a `List`-indexing on the inner `init` list. + -- Under the column-major convention, outer index = j (column), + -- inner index = i (row). + unfold lift_matrix_from_slice + -- Outer-list index reduction (outer index = column `j`). + have h_range_len : (List.range K.val).length = K.val := by simp + have h_outer_len : ((List.range K.val).map (fun j' => + Std.Array.make K ((List.range K.val).map (fun i' => + lift_poly matrix.val[i' * K.val + j']!)) (by simp))).length = K.val := by + rw [List.length_map, h_range_len] + have h_j_lt_outer : j.val < ((List.range K.val).map (fun j' => + Std.Array.make K ((List.range K.val).map (fun i' => + lift_poly matrix.val[i' * K.val + j']!)) (by simp))).length := by + rw [h_outer_len]; exact h_j + -- Use `Std.Array.make`'s definitional `.val = init` to expose the outer list, + -- then resolve the outer index via `getElem!_pos`. + show lift_poly matrix.val[i.val * K.val + j.val]! + = ((((List.range K.val).map (fun j' => + Std.Array.make K ((List.range K.val).map (fun i' => + lift_poly matrix.val[i' * K.val + j']!)) (by simp)))[j.val]!).val[i.val]!) + rw [getElem!_pos _ j.val h_j_lt_outer] + rw [List.getElem_map, List.getElem_range] + -- The outer `(fun j' => Std.Array.make K ... _) j.val` β-reduces to + -- `Std.Array.make K (...) _`; its `.val` is the inner list. + show lift_poly matrix.val[i.val * K.val + j.val]! + = ((List.range K.val).map (fun i' => + lift_poly matrix.val[i' * K.val + j.val]!))[i.val]! + have h_inner_len : ((List.range K.val).map (fun i' => + lift_poly matrix.val[i' * K.val + j.val]!)).length = K.val := by + rw [List.length_map, h_range_len] + have h_i_lt_inner : i.val < ((List.range K.val).map (fun i' => + lift_poly matrix.val[i' * K.val + j.val]!)).length := by + rw [h_inner_len]; exact h_i + rw [getElem!_pos _ i.val h_i_lt_inner] + rw [List.getElem_map, List.getElem_range] + + +end libcrux_iot_ml_kem.Matrix.Common \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeAsPlusE.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeAsPlusE.lean new file mode 100644 index 00000000..cf2d8189 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeAsPlusE.lean @@ -0,0 +1,4103 @@ +/- + # `Matrix/ComputeAsPlusE.lean` — extracted from `FCTargets.lean` §compute_as_plus_e. +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply +import LibcruxIotMlKem.Matrix.Common + +set_option mvcgen.warning false +set_option linter.unusedVariables false +set_option linter.unusedSectionVars false + +namespace libcrux_iot_ml_kem.Matrix.ComputeAsPlusE +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec + +/-! ## §L7-prep — Mont→canonical bridge. + + Single lemma `lift_poly_mont_to_lift_poly` used by L7.1's outer-loop + composition. The L6 forward-deps for matrix-row accumulation + (`accumulating_ntt_multiply_*_poly_fc` + `add_standard_error_reduce_fc`) + produce per-lane outputs in Mont form (`lift_poly_mont`), but L7.1's + POST consumes them through hacspec's `add_polynomials` after the row + finalizer, which expects canonical-domain values. The L6.3a finalizer + calls `montgomery_multiply_by_constant 1353` (= `R²` mod q) to strip + one R per lane, so the bridge takes the form of a `mul_pure` against + `lift_fe_mont 1353` collapsing to `lift_fe`. -/ + +/-- FE-level Mont→canonical bridge: `mul_pure (lift_fe_mont x) (lift_fe_mont 1353) + = lift_fe x`. In `ZMod 3329`, `lift_fe_mont x = x · 169` (where + `169 = R⁻¹ mod q`), so the LHS reduces via `zmodOfFE_mul_pure + + zmodOfFE_lift_fe_mont` to `(x · 169) · (1353 · 169) = x · (169² · 1353)`. + The keystone gives `1353 ≡ R² (mod q)` and `R⁻¹ = 169`, so + `169² · 1353 ≡ R⁻² · R² = 1`. Canonical round-trip closes. -/ +lemma lift_fe_mont_mul_1353_eq_lift_fe (x : Std.I16) : + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont x) (lift_fe_mont (1353#i16 : Std.I16)) + = lift_fe x := by + set s : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont x) (lift_fe_mont (1353#i16 : Std.I16)) with hs_def + -- (1) `s` is canonical (Canonical_mul_pure unconditional). + have h_canon : s.val.val < 3329 := by + have h_cs := libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure + (lift_fe_mont x) (lift_fe_mont (1353#i16 : Std.I16)) + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cs + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at h_cs + exact h_cs + -- (2) Canonical round-trip. + have h_round_trip : feOfZMod (zmodOfFE s) = s := + feOfZMod_zmodOfFE_of_canonical s h_canon + -- (3) `zmodOfFE s = (x.val : ZMod 3329)`. + have h_zmod_s : zmodOfFE s = ((x.val : Int) : ZMod 3329) := by + rw [hs_def, L2_8c.zmodOfFE_mul_pure, + L2_8c.zmodOfFE_lift_fe_mont, L2_8c.zmodOfFE_lift_fe_mont] + -- Goal: (x.val : ZMod 3329) * 169 * (((1353#i16).val : ZMod 3329) * 169) = (x.val : ZMod 3329) + have h_1353 : (((1353#i16 : Std.I16).val : Int) : ZMod 3329) = 1353 := by + decide + rw [h_1353] + -- Goal: (x.val : ZMod 3329) * 169 * (1353 * 169) = (x.val : ZMod 3329) + have h_inv : (169 : ZMod 3329) * (1353 * 169) = 1 := by decide + calc ((x.val : Int) : ZMod 3329) * 169 * (1353 * 169) + = ((x.val : Int) : ZMod 3329) * (169 * (1353 * 169)) := by ring + _ = ((x.val : Int) : ZMod 3329) * 1 := by rw [h_inv] + _ = ((x.val : Int) : ZMod 3329) := by ring + -- (4) Glue: `s = feOfZMod (zmodOfFE s) = feOfZMod ((x.val : ZMod 3329)) = lift_fe x`. + show s = lift_fe x + rw [← h_round_trip, h_zmod_s] + unfold lift_fe i16_to_spec_fe_plain + rfl + +/-- Poly-level Mont→canonical bridge: at every lane in [0, 256), + `mul_pure (lift_poly_mont re).val[lane]! (lift_fe_mont 1353) + = (lift_poly re).val[lane]!`. + + Reduces to the FE-level helper `lift_fe_mont_mul_1353_eq_lift_fe` + after unfolding the two `lift_poly*` getters to their underlying + `lift_fe_mont`/`lift_fe` of the same I16 lane + `(re.coefficients.val[lane/16]!).elements.val[lane%16]!`. -/ +lemma lift_poly_mont_to_lift_poly + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (lane : Nat) (h_lane : lane < 256) : + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_poly_mont re).val[lane]!) (lift_fe_mont (1353#i16 : Std.I16)) + = (lift_poly re).val[lane]! := by + -- Pin the underlying I16 lane. + set x : Std.I16 := + (re.coefficients.val[lane / 16]!).elements.val[lane % 16]! with hx_def + -- (A) `(lift_poly_mont re).val[lane]! = lift_fe_mont x`. + have h_mont : (lift_poly_mont re).val[lane]! = lift_fe_mont x := by + unfold lift_poly_mont + show ((List.range 256).map (fun j => + lift_fe_mont (re.coefficients.val[j / 16]!).elements.val[j % 16]!))[lane]! + = lift_fe_mont x + have h_len : ((List.range 256).map (fun j => + lift_fe_mont (re.coefficients.val[j / 16]!).elements.val[j % 16]!)).length = 256 := by + simp + rw [getElem!_pos _ lane (by rw [h_len]; exact h_lane)] + rw [List.getElem_map, List.getElem_range] + -- (B) `(lift_poly re).val[lane]! = lift_fe x`. + have h_plain : (lift_poly re).val[lane]! = lift_fe x := by + unfold lift_poly + show ((List.range 256).map (fun j => + lift_fe (re.coefficients.val[j / 16]!).elements.val[j % 16]!))[lane]! + = lift_fe x + have h_len : ((List.range 256).map (fun j => + lift_fe (re.coefficients.val[j / 16]!).elements.val[j % 16]!)).length = 256 := by + simp + rw [getElem!_pos _ lane (by rw [h_len]; exact h_lane)] + rw [List.getElem_map, List.getElem_range] + rw [h_mont, h_plain] + exact lift_fe_mont_mul_1353_eq_lift_fe x + +/-! ## §L7.1-loop0 — row-0 column loop scaffolding. + + Namespace `Stage1FillCacheFC` provides the invariant + step-post predicates + used to characterize `matrix.compute_As_plus_e_loop0` (the K-iteration + column loop for row 0) via `loop_range_spec_usize`. Each iteration + calls `accumulating_ntt_multiply_fill_cache` on column `j ∈ [0, K)`, + adding column j's contribution to the I32 accumulator AND populating + `s_cache.val[j]!`. The invariant tracks both effects across `k` + iterations. + + Mirrors `FillCacheFC` but at the row-axis K-scale + rather than the chunk-axis 16-scale. -/ + +namespace Stage1FillCacheFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +abbrev Acc := UseCacheFC.Acc +abbrev Poly := UseCacheFC.Poly + +/-- 4-conjunct invariant for the row-0 column loop. Tracks: + (1) accumulator characterization: for each chunk j and lane ℓ in + `[0, 16)²`, `Spec.mont_reduce_pure (lift_fe_int acc[16j+ℓ].val)` + equals init plus the canonical-form sum of column contributions + from columns `[0, k)`. + (2) accumulator bound: `|acc.val[n]| ≤ |acc_init.val[n]| + k · 2^25`. + (3) cache characterization: for each c ∈ `[0, k)`, + `cache.val[c]!.coefficients[j]!` (across all chunks j) stores the + per-chunk `ntt_multiply_cache_post` for `s_as_ntt.val[c]!.coefficients[j]!`. + (4) cache unchanged: for each c ∈ `[k, K)`, `cache.val[c]! = cache_init.val[c]!`. -/ +def row0_inv {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : Acc) + (cache_init : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) : + Std.Usize → Acc → + Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K → + Result Prop := + fun k acc cache => pure ( + -- (1) Per-(chunk j, lane ℓ) accumulator: canonical-form K-column sum. + (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = (List.range k.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (s_as_ntt.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val))) + -- (2) Accumulator bound grows by 2^25 per column iteration. + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + k.val * 2^25) + -- (3) Cache populated for columns [0, k). + ∧ (∀ c : Nat, c < k.val → + accumulating_ntt_multiply_poly_cache_post + (s_as_ntt.val[c]!) (cache.val[c]!)) + -- (4) Cache unchanged for columns [k, K). + ∧ (∀ c : Nat, k.val ≤ c → c < K.val → + cache.val[c]! = cache_init.val[c]!)) + +/-- Step-post for `loop_range_spec_usize` over (acc, cache). -/ +def row0_step_post {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : Acc) + (cache_init : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × + (Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) × + Acc) + (Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K × Acc)) : + Prop := + match r with + | .cont (iter', cache', acc') => + k.val < K.val ∧ iter'.«end» = K + ∧ iter'.start.val = k.val + 1 + ∧ (row0_inv matrix_A s_as_ntt acc_init cache_init iter'.start acc' cache').holds + | .done y => (row0_inv matrix_A s_as_ntt acc_init cache_init K y.2 y.1).holds + +end Stage1FillCacheFC + +-- Memory hygiene (rule 1 / SKILL §5.7 Idiom 2). Mirrors `L6_3c_fill_irreducible` +-- — heavy POST predicates and the per-column forward dep are +-- made locally irreducible across the step lemma + outer Triple so that +-- elaboration does not whnf-explode through the 4-conjunct `row0_inv` body or +-- the nested `∀ j : Fin 16, ∀ ℓ : Fin 16` accumulator characterization. +-- we do NOT mark +-- `Stage1FillCacheFC.row0_inv` / `row0_step_post` irreducible — keeping them reducible +-- preserves the `simpa`-based destructure of `h_inv`. +section L7_1a_irreducible +attribute [local irreducible] Spec.ntt_multiply_cache_post +attribute [local irreducible] accumulating_ntt_multiply_poly_cache_post +attribute [local irreducible] accumulating_ntt_multiply_poly_post +attribute [local irreducible] Spec.ntt_multiply_pure_no_acc +attribute [local irreducible] Spec.mont_reduce_pure + +set_option maxHeartbeats 16000000 in +set_option maxRecDepth 1000 in +/-- Per-iteration FC step lemma for the row-0 column loop. Given the + `row0_inv` invariant at step k and the strengthened PRE bounds, executing + one body iteration of `matrix.compute_As_plus_e_loop0.body` produces the + `row0_step_post` (either `.cont` advancing the invariant to k+1 or + `.done` capping at K). + + Mirrors `accumulating_ntt_multiply_fill_cache_poly_step_lemma_fc` but at the row-axis K-scale rather than the chunk-axis + 16-scale. Per-iteration step composes: + 1. `matrix.entry` reduction (via `entry_eq_ok_fc_aux`) at `(i, j) = (0, k)` + gives `matrix_A.val[0*K+k]! = matrix_A.val[k]!`. + 2. `array_index_usize_ok_eq` for `s_as_ntt[k]`. + 3. `Array.index_mut_usize` reduction for `s_cache[k]` (extract pre + + `cache.set k` setter). + 4. `accumulating_ntt_multiply_fill_cache_poly_fc` on + column k. The current accumulator's bound `≤ 2^30` follows from the + PRE budget `(acc_init[n]).val.natAbs + K·2^25 ≤ 2^30` combined with + invariant conjunct (2) `acc[n] ≤ acc_init[n] + k·2^25`. + 5. Splice the new cache chunk into `s_cache.set k _`. + 6. Re-establish `row0_inv` at k+1 (advancing the per-(j, ℓ) foldl by one + step via List.range_succ + List.foldl_append), or unchanged at k=K for + the `.done` branch. -/ +theorem compute_As_plus_e_loop0_step_lemma_fc + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : Stage1FillCacheFC.Acc) + (cache_init : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (hAlen : matrix_A.length = (K.val * K.val : Nat)) + (h_matrix_bnd : ∀ k : Fin matrix_A.length, ∀ i j : Fin 16, + ((matrix_A.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_s_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((s_as_ntt.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, + (acc_init.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30) + (acc : Stage1FillCacheFC.Acc) + (cache : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (k : Std.Usize) (h_le : k.val ≤ K.val) + (h_inv : (Stage1FillCacheFC.row0_inv matrix_A s_as_ntt acc_init cache_init k acc cache).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) matrix_A s_as_ntt + { start := k, «end» := K } cache acc + ⦃ ⇓ r => ⌜ Stage1FillCacheFC.row0_step_post matrix_A s_as_ntt acc_init cache_init k r ⌝ ⦄ := by + have h_cache_len : cache.length = K.val := Std.Array.length_eq cache + have h_cache_init_len : cache_init.length = K.val := Std.Array.length_eq cache_init + have h_s_as_ntt_len : s_as_ntt.length = K.val := Std.Array.length_eq s_as_ntt + have h_acc_len : acc.length = 256 := Std.Array.length_eq acc + have h_acc_init_len : acc_init.length = 256 := Std.Array.length_eq acc_init + -- Destructure the 4-conjunct invariant. + obtain ⟨h_inv_acc, h_inv_acc_bnd, h_inv_cache_done, h_inv_cache_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop0.body + by_cases h_lt : k.val < K.val + · -- `Some k` branch. + -- (1) IteratorRange.next reduces to .ok (some k, { start := s_iter, end := K }). + have h_iter_step : + ⦃ ⌜ True ⌝ ⦄ + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ ∃ s : Std.Usize, s.val = k.val + 1 ∧ + r = (some k, + ({ start := s, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_step + obtain ⟨s_iter, hs_iter_val, hv_iter_pair⟩ := hv_iter_post + -- (2) matrix.entry reduces to .ok matrix_A.val[k.val]! (since i = 0, j = k). + have h_0K : (0#usize : Std.Usize).val < K.val := by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + rw [h0]; omega + have h_matrix_entry : + libcrux_iot_ml_kem.matrix.entry K portable_ops_inst matrix_A 0#usize k + = .ok (matrix_A.val[(0#usize : Std.Usize).val * K.val + k.val]!) := + entry_eq_ok_fc_aux K matrix_A 0#usize k hAlen h_0K h_lt + have h_idx0 : (0#usize : Std.Usize).val * K.val + k.val = k.val := by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + rw [h0]; omega + set t_matrix : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + matrix_A.val[k.val]! with ht_matrix_def + have h_matrix_entry' : + libcrux_iot_ml_kem.matrix.entry K portable_ops_inst matrix_A 0#usize k + = .ok t_matrix := by + rw [h_matrix_entry] + congr 1 + show matrix_A.val[(0#usize : Std.Usize).val * K.val + k.val]! = matrix_A.val[k.val]! + congr 1 + -- (3) Array.index_usize s_as_ntt k reduces to .ok s_as_ntt[k.val]!. + set t_s : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + s_as_ntt.val[k.val]! with ht_s_def + have h_idx_s : Aeneas.Std.Array.index_usize s_as_ntt k = .ok t_s := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq s_as_ntt k + (by rw [h_s_as_ntt_len]; exact h_lt) + -- (4) Array.index_mut_usize s_cache k splits into (s_cache[k]!, set). + set t_cache : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + cache.val[k.val]! with ht_cache_def + have h_idx_cache : Aeneas.Std.Array.index_usize cache k = .ok t_cache := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq cache k + (by rw [h_cache_len]; exact h_lt) + have h_imt_cache : Aeneas.Std.Array.index_mut_usize cache k + = .ok (t_cache, cache.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_cache]; rfl + -- (5) Apply L6.3c per-column forward dep at column k. + -- Per-lane bounds on t_matrix and t_s (16×16 lanes). + have hK_pos : 0 < K.val := Nat.lt_of_le_of_lt (Nat.zero_le _) h_lt + have h_k_lt_KK : k.val < K.val * K.val := by + calc k.val < K.val := h_lt + _ ≤ K.val * K.val := Nat.le_mul_of_pos_left K.val hK_pos + have h_k_lt_len : k.val < matrix_A.length := by rw [hAlen]; exact h_k_lt_KK + have h_t_matrix_bnd : ∀ i : Fin 16, ∀ j : Fin 16, + ((t_matrix.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328 := + fun i j => h_matrix_bnd ⟨k.val, h_k_lt_len⟩ i j + have h_t_s_bnd : ∀ i : Fin 16, ∀ j : Fin 16, + ((t_s.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328 := + fun i j => h_s_bnd ⟨k.val, h_lt⟩ i j + -- Current acc bound ≤ 2^30: combine inv conjunct (2) with budget PRE. + have h_acc_cur_bnd : ∀ n : Fin 256, (acc.val[n.val]!).val.natAbs ≤ 2^30 := by + intro n + have hb := h_inv_acc_bnd n.val n.isLt + have hp := h_acc_bnd n + -- hb : (acc[n]).val.natAbs ≤ (acc_init[n]).val.natAbs + k.val * 2^25 + -- hp : (acc_init[n]).val.natAbs + K.val * 2^25 ≤ 2^30 + have hk_le : k.val * 2^25 ≤ K.val * 2^25 := Nat.mul_le_mul_right _ h_le + omega + obtain ⟨p_pair, h_p_eq, h_p_bnd_rel, h_p_acc_post, h_p_cache_post⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_fill_cache_poly_fc t_matrix t_s t_cache acc + h_t_matrix_bnd h_t_s_bnd h_acc_cur_bnd) + set acc1 : Stage1FillCacheFC.Acc := p_pair.1 with hacc1_def + set cache_chunk1 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + p_pair.2 with hcc1_def + -- (5') p_pair.2 expressed as cache_chunk1 (for splicing back into s_cache). + -- (6) cache1 := s_cache.set k cache_chunk1. + set cache1 : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K := + cache.set k cache_chunk1 with hcache1_def + have h_cache1_at : cache1.val[k.val]! = cache_chunk1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq cache k k.val cache_chunk1 + ⟨rfl, by rw [h_cache_len]; exact h_lt⟩ + have h_cache1_ne : ∀ j : Nat, j ≠ k.val → + cache1.val[j]! = cache.val[j]! := by + intro j hj + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne cache k j cache_chunk1 + (fun h => hj h.symm) + -- (7) Body equation. + have h_body : + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) matrix_A s_as_ntt + { start := k, «end» := K } cache acc + = .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), cache1, acc1)) := by + unfold libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_pair] at hv_iter_eq + rw [hv_iter_eq] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let pre ← libcrux_iot_ml_kem.matrix.entry K portable_ops_inst + matrix_A 0#usize k + let pre1 ← Aeneas.Std.Array.index_usize s_as_ntt k + let (pre2, index_mut_back) ← Aeneas.Std.Array.index_mut_usize cache k + let (accumulator1, pre3) ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache + portable_ops_inst pre pre1 acc pre2 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + index_mut_back pre3, accumulator1))) + : Result _) = _ + rw [h_matrix_entry'] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_s] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_cache] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let (accumulator1, pre3) ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache + portable_ops_inst t_matrix t_s acc t_cache + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + (cache.set k) pre3, accumulator1))) + : Result _) = _ + rw [h_p_eq] + simp only [Aeneas.Std.bind_tc_ok] + rfl + apply triple_of_ok_fc h_body + -- (8) Discharge the step_post. + show Stage1FillCacheFC.row0_step_post matrix_A s_as_ntt acc_init cache_init k + (.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), cache1, acc1)) + refine ⟨h_lt, rfl, hs_iter_val, ?_⟩ + -- (9) Re-establish `row0_inv` at s_iter (= k+1). + show (Stage1FillCacheFC.row0_inv matrix_A s_as_ntt acc_init cache_init s_iter acc1 cache1).holds + unfold Stage1FillCacheFC.row0_inv + have h_inv_pure : + (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = (List.range s_iter.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (s_as_ntt.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val))) + ∧ (∀ n : Nat, n < 256 → + (acc1.val[n]!).val.natAbs + ≤ (acc_init.val[n]!).val.natAbs + s_iter.val * 2^25) + ∧ (∀ c : Nat, c < s_iter.val → + accumulating_ntt_multiply_poly_cache_post + (s_as_ntt.val[c]!) (cache1.val[c]!)) + ∧ (∀ c : Nat, s_iter.val ≤ c → c < K.val → + cache1.val[c]! = cache_init.val[c]!) := by + refine ⟨?_, ?_, ?_, ?_⟩ + · -- (a) Accumulator characterization at s_iter = k+1. + intro j hj ℓ hℓ + -- Use p_acc_post to extend the foldl from k to k+1. + -- p_acc_post : accumulating_ntt_multiply_poly_post t_matrix t_s acc acc1. + -- For each (j, ℓ): + -- mont_reduce_pure (lift_fe_int acc1[16j+ℓ].val) + -- = add_pure (mont_reduce_pure (lift_fe_int acc[16j+ℓ].val)) + -- (ntt_multiply_pure_no_acc (lift_chunk_mont t_matrix.coef[j]) + -- (lift_chunk_mont t_s.coef[j]) + -- zetas...).val[ℓ]! + -- IH at k: mont_reduce_pure (lift_fe_int acc[16j+ℓ].val) = foldl over [0, k). + -- Goal: mont_reduce_pure (lift_fe_int acc1[16j+ℓ].val) = foldl over [0, k+1). + have h_step_acc : + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (t_matrix.coefficients.val[j]!)) + (lift_chunk_mont (t_s.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!) := by + have := h_p_acc_post + unfold accumulating_ntt_multiply_poly_post at this + exact this j hj ℓ hℓ + have h_ih := h_inv_acc j hj ℓ hℓ + rw [h_step_acc, h_ih] + -- LHS now: add_pure (foldl [0, k) init) (ntt_multiply ...[ℓ]!). + -- RHS: foldl [0, k+1) init = foldl ([0, k) ++ [k]) init + -- = foldl [k] (foldl [0, k) init) + -- = add_pure (foldl [0, k) init) (ntt_multiply at c=k ...[ℓ]!). + have hs_iter_eq : s_iter.val = k.val + 1 := hs_iter_val + rw [hs_iter_eq] + rw [List.range_succ, List.foldl_append] + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((List.range k.val).foldl _ _) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (t_matrix.coefficients.val[j]!)) + (lift_chunk_mont (t_s.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!) + = (List.foldl _ ((List.range k.val).foldl _ _) [k.val]) + show _ = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((List.range k.val).foldl _ _) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[k.val]!.coefficients.val[j]!)) + (lift_chunk_mont (s_as_ntt.val[k.val]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!) + -- t_matrix = matrix_A.val[k.val]! and t_s = s_as_ntt.val[k.val]!. + rfl + · -- (b) Bound: ≤ acc_init[n] + s_iter.val * 2^25. + intro n hn + have h_p_bnd_n := h_p_bnd_rel ⟨n, hn⟩ + -- h_p_bnd_n : (acc1.val[⟨n, hn⟩.val]!).val.natAbs ≤ (acc.val[⟨n, hn⟩.val]!).val.natAbs + 2^25. + -- Convert Fin .val to plain n by definitional unfold. + have h_p_bnd_n' : (acc1.val[n]!).val.natAbs ≤ (acc.val[n]!).val.natAbs + 2^25 := + h_p_bnd_n + have h_inv_n := h_inv_acc_bnd n hn + -- h_inv_n : (acc[n]).val.natAbs ≤ (acc_init[n]).val.natAbs + k.val * 2^25. + have hs_iter_eq : s_iter.val = k.val + 1 := hs_iter_val + rw [hs_iter_eq] + -- Goal: (acc1[n]).val.natAbs ≤ (acc_init[n]).val.natAbs + (k.val + 1) * 2^25. + have h_arith : (k.val + 1) * 2^25 = k.val * 2^25 + 2^25 := by ring + rw [h_arith] + linarith [h_p_bnd_n', h_inv_n] + · -- (c) Cache populated for [0, s_iter). + intro c hc + rw [hs_iter_val] at hc + rcases Nat.lt_succ_iff_lt_or_eq.mp hc with hc_lt | hc_eq + · -- c < k: cache1[c] = cache[c], use h_inv_cache_done. + have hc_ne : c ≠ k.val := by omega + rw [h_cache1_ne c hc_ne] + exact h_inv_cache_done c hc_lt + · -- c = k: cache1[k] = cache_chunk1, use h_p_cache_post. + subst hc_eq + rw [h_cache1_at] + -- h_p_cache_post : accumulating_ntt_multiply_poly_cache_post t_s p_pair.2. + -- cache_chunk1 = p_pair.2; t_s = s_as_ntt.val[k.val]!. + exact h_p_cache_post + · -- (d) Cache unchanged for [s_iter, K). + intro c hc_ge hc_lt + rw [hs_iter_val] at hc_ge + have hc_ne : c ≠ k.val := by omega + rw [h_cache1_ne c hc_ne] + have hc_ge_k : k.val ≤ c := by omega + exact h_inv_cache_undone c hc_ge_k hc_lt + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ K, done. + have hk_ge : k.val ≥ K.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = K.val := by omega + have h_iter_none : + ⦃ ⌜ True ⌝ ⦄ + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ r = ((none : Option Std.Usize), + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun hlt => absurd hlt (Nat.not_lt.mpr hk_ge)) + (fun _ => by dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_none + have h_body : + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) matrix_A s_as_ntt + { start := k, «end» := K } cache acc + = .ok (ControlFlow.done (cache, acc)) := by + unfold libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_post] at hv_iter_eq + rw [hv_iter_eq] + rfl + apply triple_of_ok_fc h_body + show Stage1FillCacheFC.row0_step_post matrix_A s_as_ntt acc_init cache_init k (.done (cache, acc)) + show (Stage1FillCacheFC.row0_inv matrix_A s_as_ntt acc_init cache_init K acc cache).holds + unfold Stage1FillCacheFC.row0_inv + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = (List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (s_as_ntt.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val))) + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs + ≤ (acc_init.val[n]!).val.natAbs + K.val * 2^25) + ∧ (∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post + (s_as_ntt.val[c]!) (cache.val[c]!)) + ∧ (∀ c : Nat, K.val ≤ c → c < K.val → + cache.val[c]! = cache_init.val[c]!) := by + refine ⟨?_, ?_, ?_, ?_⟩ + · intro j hj ℓ hℓ + have h_eq := h_inv_acc j hj ℓ hℓ + -- h_eq has (List.range k.val); rewrite to (List.range K.val) via hk_eq. + have h_rng : (List.range k.val) = (List.range K.val) := by rw [hk_eq] + rw [h_rng] at h_eq + exact h_eq + · intro n hn + have h_b := h_inv_acc_bnd n hn + have h_arith : k.val * 2^25 = K.val * 2^25 := by rw [hk_eq] + rw [h_arith] at h_b + exact h_b + · intro c hc + exact h_inv_cache_done c (by rw [hk_eq]; exact hc) + · intro c hc_ge hc_lt; omega + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +/-- L7.1 Stage 1 — `matrix.compute_As_plus_e_loop0`: the row-0 column loop. + Iterates over `j ∈ [0, K)`, accumulating column-j's contribution to the + I32 accumulator and populating `s_cache.val[j]!` with the per-column + NTT-multiply cache. + + POST: `row0_inv` holds at k = K, i.e. for all (j, ℓ) ∈ [0, 16)²: + `mont_reduce_pure (lift_fe_int acc[16j+ℓ].val)` equals the K-column + canonical-form sum of `ntt_multiply_pure_no_acc` outputs starting from + the initial accumulator's `mont_reduce_pure` lift. + + PRE: the standard 16×16 bound (3328) on `matrix_A` and `s_as_ntt`'s + entries, K·K matrix length, plus the accumulator BUDGET + `(acc_init[n]).val.natAbs + K·2^25 ≤ 2^30`. This budget is consumed by + the per-column inner forward dep (`accumulating_ntt_multiply_fill_cache_poly_fc`, + PRE `≤ 2^30`) at every iteration: the running accumulator satisfies + `acc[n] ≤ acc_init[n] + k·2^25 ≤ acc_init[n] + K·2^25 ≤ 2^30` for k ≤ K. + + Mirrors `accumulating_ntt_multiply_fill_cache_poly_fc` + but at the row-axis K-scale rather than the chunk-axis 16-scale. -/ +@[spec] +theorem compute_As_plus_e_loop0_fc + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt s_cache : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (accumulator : Std.Array Std.I32 256#usize) + (hAlen : matrix_A.length = (K.val * K.val : Nat)) + (h_matrix_bnd : ∀ k : Fin matrix_A.length, ∀ i j : Fin 16, + ((matrix_A.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_s_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((s_as_ntt.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, + (accumulator.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop0 + (vectortraitsOperationsInst := portable_ops_inst) + { start := 0#usize, «end» := K } matrix_A s_as_ntt s_cache accumulator + ⦃ ⇓ p => ⌜ (Stage1FillCacheFC.row0_inv matrix_A s_as_ntt accumulator s_cache K p.2 p.1).holds ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop0 + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, p) => + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) matrix_A s_as_ntt iter1 p.1 p.2) + (β := (Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + × Stage1FillCacheFC.Acc) + (s_cache, accumulator) + 0#usize K + (fun k p => Stage1FillCacheFC.row0_inv matrix_A s_as_ntt accumulator s_cache k p.2 p.1) + (by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + rw [h0]; exact Nat.zero_le _) + (by + -- Base case at k = 0. + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_, ?_⟩ + · intro j hj ℓ hℓ + show Spec.mont_reduce_pure _ + = (List.range (0#usize : Std.Usize).val).foldl _ _ + have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0'] + show Spec.mont_reduce_pure _ = (List.range 0).foldl _ _ + simp [List.range_zero, List.foldl_nil] + · intro n _; have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0']; omega + · intro c hc + have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0'] at hc + exact absurd hc (Nat.not_lt_zero c) + · intro c _ _; trivial) + ?_) + · -- Post entailment: the final invariant holds at K. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (Stage1FillCacheFC.row0_inv matrix_A s_as_ntt accumulator s_cache K r.2 r.1).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + show (Stage1FillCacheFC.row0_inv matrix_A s_as_ntt accumulator s_cache K r.2 r.1).holds + exact h_inv_holds + · -- Step entailment. + intro p k _h_ge h_le hinv + have h_step := compute_As_plus_e_loop0_step_lemma_fc + matrix_A s_as_ntt accumulator s_cache hAlen h_matrix_bnd h_s_bnd h_acc_bnd + p.2 p.1 k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', cache_acc⟩ | y + · have hP : Stage1FillCacheFC.row0_step_post matrix_A s_as_ntt accumulator s_cache k + (.cont (iter', cache_acc.1, cache_acc.2)) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Stage1FillCacheFC.row0_step_post] using hP + · have hP : Stage1FillCacheFC.row0_step_post matrix_A s_as_ntt accumulator s_cache k + (.done (y.1, y.2)) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Stage1FillCacheFC.row0_step_post] using hP + +end L7_1a_irreducible + +/-! ## §L7.1-loop1-loop0 — row-i (i ≥ 1) column loop scaffolding. + + Namespace `Stage2UseCacheFC` provides the invariant + step-post predicates + used to characterize `matrix.compute_As_plus_e_loop1_loop0` (the + K-iteration column loop run once per row i ∈ [1, K)) via + `loop_range_spec_usize`. Each iteration calls + `accumulating_ntt_multiply_use_cache` on column `j ∈ [0, K)`, adding + column j's contribution to the I32 accumulator. Unlike Stage 1 + (`compute_As_plus_e_loop0`), the cache is INPUT only — it was + populated by Stage 1's column loop on row 0, and is consumed + read-only here. + + Mirrors `Stage1FillCacheFC` minus the two cache-state + conjuncts (3)/(4), and with the matrix lane index parameterized by + the row index `i` (i.e. `i.val * K.val + c` rather than + `0 * K.val + c`). The per-column forward dep is + `accumulating_ntt_multiply_use_cache_poly_fc` + instead of `_fill_cache_poly_fc`. -/ + +namespace Stage2UseCacheFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +abbrev Acc := UseCacheFC.Acc +abbrev Poly := UseCacheFC.Poly + +/-- 2-conjunct invariant for the row-i (i ≥ 1) column loop. Tracks: + (1) accumulator characterization: for each (chunk j, lane ℓ) in + `[0, 16)²`, `Spec.mont_reduce_pure (lift_fe_int acc[16j+ℓ].val)` + equals init plus the canonical-form sum of column contributions + from columns `[0, k)` for the fixed row `i`. + (2) accumulator bound: `|acc.val[n]| ≤ |acc_init.val[n]| + k · 2^25`. -/ +def row_i_inv {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : Acc) (i : Std.Usize) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + -- (1) Per-(chunk j, lane ℓ) accumulator: canonical-form k-column sum. + (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = (List.range k.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[i.val * K.val + c]!.coefficients.val[j]!)) + (lift_chunk_mont (s_as_ntt.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val))) + -- (2) Accumulator bound grows by 2^25 per column iteration. + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + k.val * 2^25)) + +/-- Step-post for `loop_range_spec_usize` over the accumulator only. -/ +def row_i_step_post {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : Acc) (i : Std.Usize) (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : + Prop := + match r with + | .cont (iter', acc') => + k.val < K.val ∧ iter'.«end» = K + ∧ iter'.start.val = k.val + 1 + ∧ (row_i_inv matrix_A s_as_ntt acc_init i iter'.start acc').holds + | .done y => (row_i_inv matrix_A s_as_ntt acc_init i K y).holds + +end Stage2UseCacheFC + +-- Memory hygiene (rule 1 / SKILL §5.7 Idiom 2). Mirrors `L7_1a_irreducible` +-- — heavy POST predicates and the per-column forward dep +-- are made locally irreducible across the step lemma + outer Triple so that +-- elaboration does not whnf-explode through the 2-conjunct `row_i_inv` body or +-- the nested `∀ j : Fin 16, ∀ ℓ : Fin 16` accumulator characterization. +-- we do NOT mark +-- `Stage2UseCacheFC.row_i_inv` / `row_i_step_post` irreducible. +section L7_1b_irreducible +attribute [local irreducible] accumulating_ntt_multiply_poly_post +attribute [local irreducible] accumulating_ntt_multiply_poly_cache_post +attribute [local irreducible] Spec.ntt_multiply_pure_no_acc +attribute [local irreducible] Spec.mont_reduce_pure + +set_option maxHeartbeats 16000000 in +set_option maxRecDepth 1000 in +/-- Per-iteration FC step lemma for the row-i (i ≥ 1) column loop. Given + the `row_i_inv` invariant at step k and the strengthened PRE bounds + + the cache-post hypothesis, executing one body iteration of + `matrix.compute_As_plus_e_loop1_loop0.body` produces the + `row_i_step_post` (either `.cont` advancing the invariant to k+1 or + `.done` capping at K). + + Mirrors `compute_As_plus_e_loop0_step_lemma_fc` but + with three differences: + 1. No cache mutation: cache is INPUT only. + 2. Matrix lane uses `i.val * K.val + k.val` rather than + `0 * K.val + k.val = k.val`. + 3. Per-column forward dep is `accumulating_ntt_multiply_use_cache_poly_fc` instead of `_fill_cache_poly_fc`. This requires the + cache-post hypothesis at column k: + `accumulating_ntt_multiply_poly_cache_post (s_as_ntt[k]!) (s_cache[k]!)`. + We pass the OUTER ∀-quantified hypothesis through the step lemma so the + main theorem can hand it through unchanged. -/ +theorem compute_As_plus_e_loop1_loop0_step_lemma_fc + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt s_cache : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : Stage2UseCacheFC.Acc) + (i : Std.Usize) (hi : i.val < K.val) + (hAlen : matrix_A.length = (K.val * K.val : Nat)) + (h_matrix_bnd : ∀ k : Fin matrix_A.length, ∀ a b : Fin 16, + ((matrix_A.val[k.val]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_s_bnd : ∀ k : Fin K.val, ∀ a b : Fin 16, + ((s_as_ntt.val[k.val]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, + (acc_init.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30) + (h_cache : ∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post (s_as_ntt.val[c]!) (s_cache.val[c]!)) + (acc : Stage2UseCacheFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ K.val) + (h_inv : (Stage2UseCacheFC.row_i_inv matrix_A s_as_ntt acc_init i k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) matrix_A s_as_ntt s_cache i + { start := k, «end» := K } acc + ⦃ ⇓ r => ⌜ Stage2UseCacheFC.row_i_step_post matrix_A s_as_ntt acc_init i k r ⌝ ⦄ := by + have h_s_as_ntt_len : s_as_ntt.length = K.val := Std.Array.length_eq s_as_ntt + have h_s_cache_len : s_cache.length = K.val := Std.Array.length_eq s_cache + have h_acc_len : acc.length = 256 := Std.Array.length_eq acc + have h_acc_init_len : acc_init.length = 256 := Std.Array.length_eq acc_init + -- Destructure the 2-conjunct invariant. + obtain ⟨h_inv_acc, h_inv_acc_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1_loop0.body + by_cases h_lt : k.val < K.val + · -- `Some k` branch. + -- (1) IteratorRange.next reduces to .ok (some k, { start := s_iter, end := K }). + have h_iter_step : + ⦃ ⌜ True ⌝ ⦄ + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ ∃ s : Std.Usize, s.val = k.val + 1 ∧ + r = (some k, + ({ start := s, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_step + obtain ⟨s_iter, hs_iter_val, hv_iter_pair⟩ := hv_iter_post + -- (2) matrix.entry reduces to .ok matrix_A.val[i.val * K.val + k.val]!. + have h_matrix_entry : + libcrux_iot_ml_kem.matrix.entry K portable_ops_inst matrix_A i k + = .ok (matrix_A.val[i.val * K.val + k.val]!) := + entry_eq_ok_fc_aux K matrix_A i k hAlen hi h_lt + set t_matrix : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + matrix_A.val[i.val * K.val + k.val]! with ht_matrix_def + -- (3) Array.index_usize s_as_ntt k reduces to .ok s_as_ntt[k.val]!. + set t_s : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + s_as_ntt.val[k.val]! with ht_s_def + have h_idx_s : Aeneas.Std.Array.index_usize s_as_ntt k = .ok t_s := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq s_as_ntt k + (by rw [h_s_as_ntt_len]; exact h_lt) + -- (4) Array.index_usize s_cache k reduces to .ok s_cache[k.val]!. (read-only) + set t_cache : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + s_cache.val[k.val]! with ht_cache_def + have h_idx_cache : Aeneas.Std.Array.index_usize s_cache k = .ok t_cache := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq s_cache k + (by rw [h_s_cache_len]; exact h_lt) + -- (5) Apply L6.3c per-column forward dep at column k (use_cache flavor). + -- Per-lane bounds on t_matrix and t_s (16×16 lanes). + have hK_pos : 0 < K.val := Nat.lt_of_le_of_lt (Nat.zero_le _) h_lt + have h_iKk_lt_KK : i.val * K.val + k.val < K.val * K.val := by + have h_iK_lt : i.val * K.val + K.val ≤ K.val * K.val := by + have hstep : (i.val + 1) * K.val ≤ K.val * K.val := + Nat.mul_le_mul_right _ hi + have h_expand : (i.val + 1) * K.val = i.val * K.val + K.val := by ring + rw [h_expand] at hstep; exact hstep + omega + have h_iKk_lt_len : i.val * K.val + k.val < matrix_A.length := by + rw [hAlen]; exact h_iKk_lt_KK + have h_t_matrix_bnd : ∀ a : Fin 16, ∀ b : Fin 16, + ((t_matrix.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328 := + fun a b => h_matrix_bnd ⟨i.val * K.val + k.val, h_iKk_lt_len⟩ a b + have h_t_s_bnd : ∀ a : Fin 16, ∀ b : Fin 16, + ((t_s.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328 := + fun a b => h_s_bnd ⟨k.val, h_lt⟩ a b + -- Cache-post hypothesis at column k. + have h_cache_at_k : accumulating_ntt_multiply_poly_cache_post t_s t_cache := + h_cache k.val h_lt + -- Current acc bound ≤ 2^30: combine inv conjunct (2) with budget PRE. + have h_acc_cur_bnd : ∀ n : Fin 256, (acc.val[n.val]!).val.natAbs ≤ 2^30 := by + intro n + have hb := h_inv_acc_bnd n.val n.isLt + have hp := h_acc_bnd n + have hk_le : k.val * 2^25 ≤ K.val * 2^25 := Nat.mul_le_mul_right _ h_le + omega + obtain ⟨acc1, h_acc1_eq, h_acc1_bnd_rel, h_acc1_post⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_use_cache_poly_fc t_matrix t_s t_cache acc + h_t_matrix_bnd h_t_s_bnd h_acc_cur_bnd h_cache_at_k) + -- (6) Body equation. + have h_body : + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) matrix_A s_as_ntt s_cache i + { start := k, «end» := K } acc + = .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), acc1)) := by + unfold libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_pair] at hv_iter_eq + rw [hv_iter_eq] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let pre ← libcrux_iot_ml_kem.matrix.entry K portable_ops_inst + matrix_A i k + let pre1 ← Aeneas.Std.Array.index_usize s_as_ntt k + let pre2 ← Aeneas.Std.Array.index_usize s_cache k + let accumulator1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_use_cache + portable_ops_inst pre pre1 acc pre2 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), accumulator1))) + : Result _) = _ + rw [h_matrix_entry] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_s] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_cache] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_acc1_eq] + rfl + apply triple_of_ok_fc h_body + -- (7) Discharge the step_post. + show Stage2UseCacheFC.row_i_step_post matrix_A s_as_ntt acc_init i k + (.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), acc1)) + refine ⟨h_lt, rfl, hs_iter_val, ?_⟩ + -- (8) Re-establish `row_i_inv` at s_iter (= k+1). + show (Stage2UseCacheFC.row_i_inv matrix_A s_as_ntt acc_init i s_iter acc1).holds + unfold Stage2UseCacheFC.row_i_inv + have h_inv_pure : + (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = (List.range s_iter.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[i.val * K.val + c]!.coefficients.val[j]!)) + (lift_chunk_mont (s_as_ntt.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val))) + ∧ (∀ n : Nat, n < 256 → + (acc1.val[n]!).val.natAbs + ≤ (acc_init.val[n]!).val.natAbs + s_iter.val * 2^25) := by + refine ⟨?_, ?_⟩ + · -- (a) Accumulator characterization at s_iter = k+1. + intro j hj ℓ hℓ + have h_step_acc : + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (t_matrix.coefficients.val[j]!)) + (lift_chunk_mont (t_s.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!) := by + have := h_acc1_post + unfold accumulating_ntt_multiply_poly_post at this + exact this j hj ℓ hℓ + have h_ih := h_inv_acc j hj ℓ hℓ + rw [h_step_acc, h_ih] + have hs_iter_eq : s_iter.val = k.val + 1 := hs_iter_val + rw [hs_iter_eq] + rw [List.range_succ, List.foldl_append] + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((List.range k.val).foldl _ _) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (t_matrix.coefficients.val[j]!)) + (lift_chunk_mont (t_s.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!) + = (List.foldl _ ((List.range k.val).foldl _ _) [k.val]) + show _ = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((List.range k.val).foldl _ _) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[i.val * K.val + k.val]!.coefficients.val[j]!)) + (lift_chunk_mont (s_as_ntt.val[k.val]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!) + rfl + · -- (b) Bound: ≤ acc_init[n] + s_iter.val * 2^25. + intro n hn + have h_acc1_bnd_n := h_acc1_bnd_rel ⟨n, hn⟩ + have h_acc1_bnd_n' : (acc1.val[n]!).val.natAbs ≤ (acc.val[n]!).val.natAbs + 2^25 := + h_acc1_bnd_n + have h_inv_n := h_inv_acc_bnd n hn + have hs_iter_eq : s_iter.val = k.val + 1 := hs_iter_val + rw [hs_iter_eq] + have h_arith : (k.val + 1) * 2^25 = k.val * 2^25 + 2^25 := by ring + rw [h_arith] + linarith [h_acc1_bnd_n', h_inv_n] + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ K, done. + have hk_ge : k.val ≥ K.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = K.val := by omega + have h_iter_none : + ⦃ ⌜ True ⌝ ⦄ + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ r = ((none : Option Std.Usize), + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun hlt => absurd hlt (Nat.not_lt.mpr hk_ge)) + (fun _ => by dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_none + have h_body : + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) matrix_A s_as_ntt s_cache i + { start := k, «end» := K } acc + = .ok (ControlFlow.done acc) := by + unfold libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_post] at hv_iter_eq + rw [hv_iter_eq] + rfl + apply triple_of_ok_fc h_body + show Stage2UseCacheFC.row_i_step_post matrix_A s_as_ntt acc_init i k (.done acc) + show (Stage2UseCacheFC.row_i_inv matrix_A s_as_ntt acc_init i K acc).holds + unfold Stage2UseCacheFC.row_i_inv + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = (List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[i.val * K.val + c]!.coefficients.val[j]!)) + (lift_chunk_mont (s_as_ntt.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val))) + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs + ≤ (acc_init.val[n]!).val.natAbs + K.val * 2^25) := by + refine ⟨?_, ?_⟩ + · intro j hj ℓ hℓ + have h_eq := h_inv_acc j hj ℓ hℓ + have h_rng : (List.range k.val) = (List.range K.val) := by rw [hk_eq] + rw [h_rng] at h_eq + exact h_eq + · intro n hn + have h_b := h_inv_acc_bnd n hn + have h_arith : k.val * 2^25 = K.val * 2^25 := by rw [hk_eq] + rw [h_arith] at h_b + exact h_b + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +/-- L7.1 Stage 2 — `matrix.compute_As_plus_e_loop1_loop0`: the row-i + (i ≥ 1) column loop. Iterates over `j ∈ [0, K)`, accumulating + column-j's contribution to the I32 accumulator via + `accumulating_ntt_multiply_use_cache`. The cache is INPUT only — + populated by Stage 1's row-0 column loop and consumed read-only here. + + POST: `row_i_inv` holds at k = K, i.e. for all (j, ℓ) ∈ [0, 16)²: + `mont_reduce_pure (lift_fe_int acc[16j+ℓ].val)` equals the K-column + canonical-form sum at row `i` of `ntt_multiply_pure_no_acc` outputs + starting from the initial accumulator's `mont_reduce_pure` lift. + + PRE: standard 16×16 bound (3328) on matrix and s_as_ntt entries, the + K·K matrix-length axiom, `hK : K.val ≤ 4`, `hi : i.val < K.val`, the + additive accumulator BUDGET `(acc_init[n]).val.natAbs + K·2^25 ≤ 2^30`, + and the cache-post hypothesis `h_cache` — at every column c < K, + `s_cache.val[c]!` satisfies `accumulating_ntt_multiply_poly_cache_post` + against `s_as_ntt.val[c]!`. The latter is established by Stage 1's + final invariant (row-0 column loop populates the cache). + + Mirrors `compute_As_plus_e_loop0_fc` minus cache + threading; the cache passes through as a read-only parameter. -/ +@[spec] +theorem compute_As_plus_e_loop1_loop0_fc + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt s_cache : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (accumulator : Std.Array Std.I32 256#usize) + (i : Std.Usize) + (hi : i.val < K.val) + (hAlen : matrix_A.length = (K.val * K.val : Nat)) + (h_matrix_bnd : ∀ k : Fin matrix_A.length, ∀ a b : Fin 16, + ((matrix_A.val[k.val]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_s_bnd : ∀ k : Fin K.val, ∀ a b : Fin 16, + ((s_as_ntt.val[k.val]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, + (accumulator.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30) + (h_cache : ∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post (s_as_ntt.val[c]!) (s_cache.val[c]!)) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1_loop0 + (vectortraitsOperationsInst := portable_ops_inst) + { start := 0#usize, «end» := K } matrix_A s_as_ntt s_cache accumulator i + ⦃ ⇓ p => ⌜ (Stage2UseCacheFC.row_i_inv matrix_A s_as_ntt accumulator i K p).holds ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1_loop0 + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) matrix_A s_as_ntt s_cache i + iter1 acc1) + (β := Stage2UseCacheFC.Acc) + accumulator + 0#usize K + (fun k acc => Stage2UseCacheFC.row_i_inv matrix_A s_as_ntt accumulator i k acc) + (by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + rw [h0]; exact Nat.zero_le _) + (by + -- Base case at k = 0. + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_⟩ + · intro j hj ℓ hℓ + show Spec.mont_reduce_pure _ + = (List.range (0#usize : Std.Usize).val).foldl _ _ + have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0'] + show Spec.mont_reduce_pure _ = (List.range 0).foldl _ _ + simp [List.range_zero, List.foldl_nil] + · intro n _; have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0']; omega) + ?_) + · -- Post entailment: the final invariant holds at K. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (Stage2UseCacheFC.row_i_inv matrix_A s_as_ntt accumulator i K r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + show (Stage2UseCacheFC.row_i_inv matrix_A s_as_ntt accumulator i K r).holds + exact h_inv_holds + · -- Step entailment. + intro acc k _h_ge h_le hinv + have h_step := compute_As_plus_e_loop1_loop0_step_lemma_fc + matrix_A s_as_ntt s_cache accumulator i hi hAlen h_matrix_bnd h_s_bnd h_acc_bnd + h_cache acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Stage2UseCacheFC.row_i_step_post matrix_A s_as_ntt accumulator i k + (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Stage2UseCacheFC.row_i_step_post] using hP + · have hP : Stage2UseCacheFC.row_i_step_post matrix_A s_as_ntt accumulator i k + (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Stage2UseCacheFC.row_i_step_post] using hP + +end L7_1b_irreducible + +/-! ## §L7.1-loop1 — outer rows loop (rows i ∈ [start, K)) scaffolding. + + Namespace `Stage3MontStripFC` provides the invariant + step-post predicates + for `matrix.compute_As_plus_e_loop1` (the outer rows loop) via + `loop_range_spec_usize`. Each iteration covers one full row i: + re-zeros accumulator, calls `compute_As_plus_e_loop1_loop0_fc` + (Stage 2) for the column sum, converts via `reducing_from_i32_array` + to Mont FE form, then applies `add_standard_error_reduce` (×1353 + + add error) to produce the canonical FE row. + + Composes Stage 2 + L6.7 (poly_reducing_from_i32_array_fc) + L6.5 + (add_standard_error_reduce_fc) per row. The per-lane invariant equation + has the form `(lift_poly t_as_ntt[r]).val[lane]! = + add_pure (mul_pure (canonical_row_sum_lane ...) (lift_fe_mont 1353)) + ((lift_poly error[r]).val[lane]!)`, where + `canonical_row_sum_lane` absorbs ONE Mont→canonical bridge step + (mul_pure × lift_fe_mont 1353) over the K-fold sum-in-Mont produced + by Stage 2+L6.7; the OUTER × 1353 in the invariant comes from L6.5's + own `mul_pure self (lift_fe_mont 1353)` step. -/ + +namespace Stage3MontStripFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +abbrev TVec (K : Std.Usize) := Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K + +abbrev Acc := Std.Array Std.I32 256#usize + +/-- The canonical-form per-(row i, chunk j, lane-within-chunk q) value: + the column-K sum of `Spec.ntt_multiply_pure_no_acc` contributions + extracted at chunk j, lane q, with the OUTER Mont→canonical bridge + (`mul_pure _ (lift_fe_mont 1353)`) folded in. This equals + `(lift_poly pre1).val[16*j+q]!` after the Stage 2 + L6.7 + bridge + composition, where `pre1` is L6.7's Mont-form polynomial output. + + Per-row composition with L6.5 then gives the canonical FE row: + `(lift_poly t_as_ntt[r]!).val[lane]! + = add_pure (mul_pure (canonical_row_sum_lane ...) (lift_fe_mont 1353)) + ((lift_poly error[r]!).val[lane]!)` + where lane = 16*j + q. + + The foldl seed `Spec.mont_reduce_pure (lift_fe_int 0)` matches the + zero-init accumulator after Stage 2 — each iteration of the outer + loop re-zeros via `Array.repeat 256#usize (classify 0#i32)`, so the + `acc_init` slot in Stage 2's invariant collapses to the zero + accumulator's `mont_reduce_pure (lift_fe_int 0)` constant. -/ +noncomputable def canonical_row_sum_lane + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt : TVec K) (i : Nat) (j q : Nat) : + hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[i * K.val + c]!.coefficients.val[j]!)) + (lift_chunk_mont (s_as_ntt.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[q]!)) + (Spec.mont_reduce_pure (lift_fe_int 0))) + (lift_fe_mont (1353#i16 : Std.I16)) + +/-- 2-conjunct invariant for the outer rows loop. Tracks: + (1) Per-completed-row characterization: for each row `r ∈ [start, k)`, + and each lane `ℓ ∈ [0, 256)`, + `(lift_poly t_as_ntt.val[r]!).val[ℓ]!` + = `add_pure (mul_pure (canonical_row_sum_lane matrix_A s_as_ntt r (ℓ/16) (ℓ%16)) + (lift_fe_mont 1353)) + ((lift_poly error_as_ntt.val[r]!).val[ℓ]!)`. + (2) Unchanged rows: for each row `r ∈ [0, K)` with + `r < start.val ∨ k.val ≤ r`, + `t_as_ntt.val[r]! = t_as_ntt_init.val[r]!`. -/ +def rows_inv {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt error_as_ntt : TVec K) + (t_as_ntt_init : TVec K) (start : Std.Usize) : + Std.Usize → TVec K → Acc → Result Prop := + fun k t_as_ntt _acc => pure ( + (∀ r : Nat, start.val ≤ r → r < k.val → ∀ ℓ : Nat, ℓ < 256 → + (lift_poly t_as_ntt.val[r]!).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (canonical_row_sum_lane matrix_A s_as_ntt r (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16))) + ((lift_poly error_as_ntt.val[r]!).val[ℓ]!)) + ∧ (∀ r : Nat, r < K.val → (r < start.val ∨ k.val ≤ r) → + t_as_ntt.val[r]! = t_as_ntt_init.val[r]!)) + +/-- Step-post for `loop_range_spec_usize` over (t_as_ntt, accumulator). -/ +def rows_step_post {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt error_as_ntt : TVec K) + (t_as_ntt_init : TVec K) (start : Std.Usize) (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × TVec K × Acc) + (TVec K × Acc)) : + Prop := + match r with + | .cont (iter', t', acc') => + k.val < K.val ∧ iter'.«end» = K + ∧ iter'.start.val = k.val + 1 + ∧ (rows_inv matrix_A s_as_ntt error_as_ntt t_as_ntt_init start + iter'.start t' acc').holds + | .done y => (rows_inv matrix_A s_as_ntt error_as_ntt t_as_ntt_init start + K y.1 y.2).holds + +end Stage3MontStripFC + +/-! ## §L7.1 Stage 4 bridge lemmas — `chunk_at (lift_poly _)` ↔ `lift_chunk_mont _`. + + These two helpers connect the `Spec.multiply_ntts_pure_eq_chunked_no_acc` + side (which operates on `chunk_at (lift_poly _)` chunks, i.e. canonical + `lift_fe` lanes) to the `Stage3MontStripFC.canonical_row_sum_lane` side (which + operates on `lift_chunk_mont _` chunks, i.e. Mont-stripped `lift_fe_mont` + lanes). The relationship per lane is exactly the FE-level + `lift_fe_mont_mul_1353_eq_lift_fe` identity at. + + Both helpers are private to FCTargets and used only by the L7.1 Stage 4 + closing argument. -/ + +namespace Stage3MontStripFC + +set_option maxHeartbeats 1000000 in +/-- Per-lane bridge: `Spec.chunk_at (lift_poly p) j` interprets each I16 + lane via `lift_fe` (canonical, value = `x.val mod q`), while + `lift_chunk_mont p.coefficients.val[j]!` uses `lift_fe_mont` (Mont- + stripped, value = `x.val · R⁻¹ mod q`). The conversion factor is + `lift_fe_mont 1353 = R` (since `1353 = R² mod q`, so + `lift_fe_mont 1353 = 1353 · R⁻¹ = R²·R⁻¹ = R mod q`). Thus + `lift_fe x = (lift_fe_mont x) · R = (lift_fe_mont x) · (lift_fe_mont 1353)`, + which is exactly `lift_fe_mont_mul_1353_eq_lift_fe`. -/ +theorem chunk_at_lift_poly_lane + (p : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (j : Nat) (h_j : j < 16) (q : Nat) (h_q : q < 16) : + (Spec.chunk_at (lift_poly p) j).val[q]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont p.coefficients.val[j]!).val[q]!) + (lift_fe_mont (1353#i16 : Std.I16)) := by + -- Pin the underlying I16 lane shared by both sides. + set x : Std.I16 := + (p.coefficients.val[j]!).elements.val[q]! with hx_def + -- The elements list has length 16 (from PortableVector's invariant). + have h_elem_len : ((p.coefficients.val[j]!).elements.val).length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + -- (A) RHS factor: `(lift_chunk_mont p.coefficients.val[j]!).val[q]! = lift_fe_mont x`. + have h_mont : (lift_chunk_mont p.coefficients.val[j]!).val[q]! = lift_fe_mont x := by + unfold lift_chunk_mont + show (((p.coefficients.val[j]!).elements.val).map lift_fe_mont)[q]! + = lift_fe_mont x + have h_len : (((p.coefficients.val[j]!).elements.val).map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_elem_len + rw [getElem!_pos _ q (by rw [h_len]; exact h_q)] + rw [List.getElem_map] + -- Goal: lift_fe_mont ((p.coefficients.val[j]!).elements.val[q]) = lift_fe_mont x. + -- Convert `[q]` (with bounds proof) back to `[q]!` to match `x`'s definition. + rw [show ((p.coefficients.val[j]!).elements.val)[q] + = ((p.coefficients.val[j]!).elements.val)[q]! from + (getElem!_pos _ q (by rw [h_elem_len]; exact h_q)).symm] + -- (B) LHS: `(Spec.chunk_at (lift_poly p) j).val[q]! = lift_fe x`. + have h_plain : (Spec.chunk_at (lift_poly p) j).val[q]! = lift_fe x := by + unfold Spec.chunk_at + show ((List.range 16).map (fun j' => (lift_poly p).val[16 * j + j']!))[q]! + = lift_fe x + have h_len_outer : ((List.range 16).map + (fun j' => (lift_poly p).val[16 * j + j']!)).length = 16 := by simp + rw [getElem!_pos _ q (by rw [h_len_outer]; exact h_q)] + rw [List.getElem_map, List.getElem_range] + -- Goal: (lift_poly p).val[16 * j + q]! = lift_fe x. + have h_lane : 16 * j + q < 256 := by omega + unfold lift_poly + show ((List.range 256).map (fun n => + lift_fe (p.coefficients.val[n / 16]!).elements.val[n % 16]!))[16 * j + q]! + = lift_fe x + have h_len_inner : ((List.range 256).map (fun n => + lift_fe (p.coefficients.val[n / 16]!).elements.val[n % 16]!)).length = 256 := by simp + rw [getElem!_pos _ (16 * j + q) (by rw [h_len_inner]; exact h_lane)] + rw [List.getElem_map, List.getElem_range] + -- Goal: lift_fe (p.coefficients.val[(16*j+q)/16]!).elements.val[(16*j+q)%16]! = lift_fe x. + have h_div : (16 * j + q) / 16 = j := by omega + have h_mod : (16 * j + q) % 16 = q := by omega + rw [h_div, h_mod] + rw [h_mont, h_plain] + -- Goal: lift_fe x = mul_pure (lift_fe_mont x) (lift_fe_mont 1353). + rw [lift_fe_mont_mul_1353_eq_lift_fe] + +/-- Per-lane reduction of `Spec.ntt_multiply_pure_no_acc` projection. + + Used by `ntt_multiply_pure_no_acc_lane_scale` to give both sides a + uniform `if q%2=0 ...` shape over the four operand projections + (`.val[2·(q/2)]!`, `.val[2·(q/2)+1]!`). The reduction is `rfl` after + unfolding `Spec.ntt_multiply_pure_no_acc` and projecting through + `Std.Array.make` + `(List.range 16).map`. + + The 8-zeta list is materialized inline (no `let`) so downstream + `simp`/`rw` substitutions see explicit `FieldElement.{add,mul,neg}_pure` + head symbols for `zmodOfFE_{add,mul}_pure` simp-set rewrites. -/ +theorem ntt_multiply_pure_no_acc_val_q + (a b : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (zeta0 zeta1 zeta2 zeta3 : hacspec_ml_kem.parameters.FieldElement) + (q : Nat) (h_q : q < 16) : + (Spec.ntt_multiply_pure_no_acc a b zeta0 zeta1 zeta2 zeta3).val[q]! + = (let neg := libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure + let zeta_q : hacspec_ml_kem.parameters.FieldElement := + [zeta0, neg zeta0, zeta1, neg zeta1, + zeta2, neg zeta2, zeta3, neg zeta3][q / 2]! + if q % 2 = 0 then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + a.val[2 * (q / 2)]! b.val[2 * (q / 2)]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + a.val[2 * (q / 2) + 1]! b.val[2 * (q / 2) + 1]!) + zeta_q) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + a.val[2 * (q / 2)]! b.val[2 * (q / 2) + 1]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + a.val[2 * (q / 2) + 1]! b.val[2 * (q / 2)]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rw [show ∀ (l : List hacspec_ml_kem.parameters.FieldElement) + (h : l.length = (16#usize : Std.Usize).val), + (Std.Array.make 16#usize l h).val[q]! = l[q]! from fun _ _ => rfl, + List.getElem!_eq_getElem?_getD, List.getElem?_map, List.getElem?_range h_q, + Option.map_some, Option.getD_some] + +set_option maxHeartbeats 8000000 in +/-- **Bilinearity of `Spec.ntt_multiply_pure_no_acc` over a per-lane scalar.** + + If inputs `a, b` scale `am, bm` lane-wise by a common scalar `c`, then + each output lane scales by `c²`. Proof: + - Reduce both `.val[q]!` lookups via `ntt_multiply_pure_no_acc_val_q` + (uniform `if q%2=0 ...` form over `2·(q/2), 2·(q/2)+1`). + - Substitute `h_a, h_b` at those four positions, exposing + `mul_pure am_k c` / `mul_pure bm_k c` factors. + - Case-split `q % 2 = 0 ∨ q % 2 = 1` and project via `zmodOfFE` to + `ZMod 3329`; canonical round-trip + `ring` closes each branch. + + Used by Helper 2 (`ntt_multiply_pure_no_acc_chunk_at_lift_poly_eq`) as + a 1-line corollary with `c = lift_fe_mont 1353`. -/ +theorem ntt_multiply_pure_no_acc_lane_scale + (a am b bm : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (c : hacspec_ml_kem.parameters.FieldElement) + (h_a : ∀ k : Nat, k < 16 → a.val[k]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure am.val[k]! c) + (h_b : ∀ k : Nat, k < 16 → b.val[k]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure bm.val[k]! c) + (zeta0 zeta1 zeta2 zeta3 : hacspec_ml_kem.parameters.FieldElement) + (q : Nat) (h_q : q < 16) : + (Spec.ntt_multiply_pure_no_acc a b zeta0 zeta1 zeta2 zeta3).val[q]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((Spec.ntt_multiply_pure_no_acc am bm zeta0 zeta1 zeta2 zeta3).val[q]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure c c) := by + have h_2pi : 2 * (q / 2) < 16 := by omega + have h_2pi1 : 2 * (q / 2) + 1 < 16 := by omega + rw [ntt_multiply_pure_no_acc_val_q a b _ _ _ _ q h_q, + ntt_multiply_pure_no_acc_val_q am bm _ _ _ _ q h_q] + rw [h_a (2 * (q / 2)) h_2pi, h_a (2 * (q / 2) + 1) h_2pi1, + h_b (2 * (q / 2)) h_2pi, h_b (2 * (q / 2) + 1) h_2pi1] + -- Helper: canonical round-trip closer. + have h_close : ∀ s t : hacspec_ml_kem.parameters.FieldElement, + s.val.val < 3329 → t.val.val < 3329 → + zmodOfFE s = zmodOfFE t → s = t := by + intro s t hs ht heq + rw [← feOfZMod_zmodOfFE_of_canonical s hs, + ← feOfZMod_zmodOfFE_of_canonical t ht, heq] + -- Helper: `Canonical x → x.val.val < 3329`. + have h_canon_to_lt : ∀ x : hacspec_ml_kem.parameters.FieldElement, + libcrux_iot_ml_kem.Spec.Pure.Canonical x → x.val.val < 3329 := by + intro x hx + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at hx + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at hx + exact hx + -- Case split on q % 2. + rcases (show q % 2 = 0 ∨ q % 2 = 1 from by omega) with h_par | h_par + · -- q % 2 = 0 branch. + rw [if_pos h_par, if_pos h_par] + apply h_close + · apply h_canon_to_lt + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure _ _ + · apply h_canon_to_lt + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure _ _ + · simp only [L2_8c.zmodOfFE_add_pure, L2_8c.zmodOfFE_mul_pure] + ring + · -- q % 2 = 1 branch. + have h_par_ne : q % 2 ≠ 0 := by omega + rw [if_neg h_par_ne, if_neg h_par_ne] + apply h_close + · apply h_canon_to_lt + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure _ _ + · apply h_canon_to_lt + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure _ _ + · simp only [L2_8c.zmodOfFE_add_pure, L2_8c.zmodOfFE_mul_pure] + ring + +/-- **L7.1 Stage 4 chunked bilinearity bridge (Helper 2).** + + Connects per-chunk `Spec.chunk_at (lift_poly _)` (canonical lift, used + by .4 `Spec.multiply_ntts_pure_eq_chunked_no_acc`) to + `lift_chunk_mont _` (Mont-stripped, used by `canonical_row_sum_lane`) + via the bilinearity of `Spec.ntt_multiply_pure_no_acc`. Both inputs + differ from the Mont versions by per-lane `× lift_fe_mont 1353` (= R + in ZMod 3329), so the per-lane output differs by `(lift_fe_mont 1353)²`. + + 1-line composition: `ntt_multiply_pure_no_acc_lane_scale` with + `c = lift_fe_mont 1353` and lane hypothesis `chunk_at_lift_poly_lane`. -/ +theorem ntt_multiply_pure_no_acc_chunk_at_lift_poly_eq + (a b : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (j : Nat) (h_j : j < 16) + (zeta0 zeta1 zeta2 zeta3 : hacspec_ml_kem.parameters.FieldElement) + (q : Nat) (h_q : q < 16) : + (Spec.ntt_multiply_pure_no_acc + (Spec.chunk_at (lift_poly a) j) (Spec.chunk_at (lift_poly b) j) + zeta0 zeta1 zeta2 zeta3).val[q]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont a.coefficients.val[j]!) + (lift_chunk_mont b.coefficients.val[j]!) + zeta0 zeta1 zeta2 zeta3).val[q]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (1353#i16 : Std.I16)) + (lift_fe_mont (1353#i16 : Std.I16))) := + ntt_multiply_pure_no_acc_lane_scale + (Spec.chunk_at (lift_poly a) j) (lift_chunk_mont a.coefficients.val[j]!) + (Spec.chunk_at (lift_poly b) j) (lift_chunk_mont b.coefficients.val[j]!) + (lift_fe_mont (1353#i16 : Std.I16)) + (fun k h_k => chunk_at_lift_poly_lane a j h_j k h_k) + (fun k h_k => chunk_at_lift_poly_lane b j h_j k h_k) + zeta0 zeta1 zeta2 zeta3 q h_q + +end Stage3MontStripFC + +-- Memory hygiene (rule 1 / SKILL §5.7 Idiom 2). Heavy `accumulating_ntt_multiply_*_post` +-- predicates + the `canonical_row_sum_lane` foldl are made locally irreducible across +-- the Stage 3 step lemma + main Triple to keep elaboration tractable through the 2-conjunct +-- `rows_inv` body. We do NOT mark `Stage3MontStripFC.rows_inv` / `rows_step_post` +-- irreducible — keeping them reducible preserves the `simpa`-based +-- destructure of `h_inv`. +section L7_1c_irreducible +attribute [local irreducible] accumulating_ntt_multiply_poly_post +attribute [local irreducible] accumulating_ntt_multiply_poly_cache_post +attribute [local irreducible] Spec.ntt_multiply_pure_no_acc +attribute [local irreducible] Spec.mont_reduce_pure +attribute [local irreducible] Stage3MontStripFC.canonical_row_sum_lane + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for the outer rows loop. Given the + `Stage3MontStripFC.rows_inv` invariant at step `k` and the strengthened PRE bounds, + executing one body iteration of `matrix.compute_As_plus_e_loop1.body` + produces the `Stage3MontStripFC.rows_step_post` (either `.cont` advancing the + invariant to `k+1` or `.done` capping at `K`). + + Per-iteration step composes: + 1. `iter.next` → row index `i = k`. + 2. `classify 0#i32` → `i1 = 0#i32`; `Array.repeat 256#usize 0#i32` → + `accumulator1` (zeroed accumulator). + 3. Stage 2 (`compute_As_plus_e_loop1_loop0_fc`) on row `i` with the zeroed + accumulator. Yields `accumulator2` satisfying `Stage2UseCacheFC.row_i_inv`. + 4. `lift (Array.to_slice accumulator2)` → `s`. + 5. `Array.index_mut_usize t_as_ntt i` → `(pre, set t_as_ntt i)`. + 6. L6.7 (`poly_reducing_from_i32_array_fc`, NOW STRENGTHENED) on `s` and + `pre`. Yields Mont-form `t1` with per-lane bound `≤ 4993`. + 7. `t_as_ntt1 := set t_as_ntt i t1`. + 8. `Array.index_mut_usize t_as_ntt1 i` → `(t1, set t_as_ntt1 i)`. + 9. `Array.index_usize error_as_ntt i` → `error_as_ntt[i]`. + 10. L6.5 (`add_standard_error_reduce_fc`) on (t1, error_as_ntt[i]). Uses + `4993 ≤ 32767` to discharge the L6.5 self-bound PRE. + 11. `a := set t_as_ntt1 i pre4`. + 12. Re-establish `rows_inv` at `k+1`. -/ +theorem compute_As_plus_e_loop1_step_lemma_fc + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt error_as_ntt s_cache : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (t_as_ntt_init : Stage3MontStripFC.TVec K) + (start : Std.Usize) + (hK : K.val ≤ 4) + (hAlen : matrix_A.length = (K.val * K.val : Nat)) + (h_matrix_bnd : ∀ k : Fin matrix_A.length, ∀ a b : Fin 16, + ((matrix_A.val[k.val]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_s_bnd : ∀ k : Fin K.val, ∀ a b : Fin 16, + ((s_as_ntt.val[k.val]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_error_bnd : ∀ k : Fin K.val, ∀ a b : Fin 16, + ((error_as_ntt.val[k.val]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 29439) + (h_cache : ∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post (s_as_ntt.val[c]!) (s_cache.val[c]!)) + (t_as_ntt : Stage3MontStripFC.TVec K) (accumulator : Stage3MontStripFC.Acc) + (k : Std.Usize) (h_ge : start.val ≤ k.val) (h_le : k.val ≤ K.val) + (h_inv : (Stage3MontStripFC.rows_inv matrix_A s_as_ntt error_as_ntt t_as_ntt_init start + k t_as_ntt accumulator).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1.body + (vectortraitsOperationsInst := portable_ops_inst) matrix_A s_as_ntt error_as_ntt s_cache + { start := k, «end» := K } t_as_ntt accumulator + ⦃ ⇓ r => ⌜ Stage3MontStripFC.rows_step_post matrix_A s_as_ntt error_as_ntt t_as_ntt_init + start k r ⌝ ⦄ := by + have h_t_as_ntt_len : t_as_ntt.length = K.val := Std.Array.length_eq t_as_ntt + have h_error_len : error_as_ntt.length = K.val := Std.Array.length_eq error_as_ntt + -- Destructure the 2-conjunct invariant. + obtain ⟨h_inv_done, h_inv_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1.body + by_cases h_lt : k.val < K.val + · -- `Some k` branch (i = k). + -- (1) IteratorRange.next reduces to (some k, {start := s_iter, end := K}). + have h_iter_step : + ⦃ ⌜ True ⌝ ⦄ + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ ∃ s : Std.Usize, s.val = k.val + 1 ∧ + r = (some k, + ({ start := s, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_step + obtain ⟨s_iter, hs_iter_val, hv_iter_pair⟩ := hv_iter_post + -- (2) classify 0#i32 = .ok 0#i32. + have h_classify : libcrux_secrets.traits.Classify.Blanket.classify (0#i32 : Std.I32) + = .ok (0#i32 : Std.I32) := rfl + -- (3) Array.repeat 256 0#i32 — fresh zeroed accumulator. + set acc_zero : Stage3MontStripFC.Acc := + Aeneas.Std.Array.repeat 256#usize (0#i32 : Std.I32) with h_acc_zero_def + have h_acc_zero_val : acc_zero.val = List.replicate 256 (0#i32 : Std.I32) := by + show (Aeneas.Std.Array.repeat 256#usize (0#i32 : Std.I32)).val = _ + simp [Aeneas.Std.Array.repeat_val] + have h_acc_zero_get : ∀ n : Nat, n < 256 → + acc_zero.val[n]! = (0#i32 : Std.I32) := by + intro n hn + rw [h_acc_zero_val] + rw [getElem!_pos (List.replicate 256 (0#i32 : Std.I32)) n + (by rw [List.length_replicate]; exact hn)] + exact List.getElem_replicate _ + have h_acc_zero_bnd : ∀ n : Fin 256, + (acc_zero.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30 := by + intro n + rw [h_acc_zero_get n.val n.isLt] + have h0 : ((0#i32 : Std.I32).val).natAbs = 0 := rfl + rw [h0] + have hK4 : K.val * 2^25 ≤ 4 * 2^25 := Nat.mul_le_mul_right _ hK + have : 4 * 2^25 ≤ 2^30 := by decide + omega + -- (4) Apply Stage 2 forward dep at row i = k with the zeroed accumulator. + have h_stage2 := + compute_As_plus_e_loop1_loop0_fc matrix_A s_as_ntt s_cache acc_zero k h_lt + hAlen h_matrix_bnd h_s_bnd h_acc_zero_bnd h_cache + obtain ⟨acc_final, h_acc_final_eq, h_acc_final_inv⟩ := triple_exists_ok_fc h_stage2 + -- Destructure the Stage 2 POST into its 2 conjuncts. + obtain ⟨h_acc_final_lane, h_acc_final_bnd⟩ := by + simpa [Stage2UseCacheFC.row_i_inv, Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + using h_acc_final_inv + -- (5) `lift (Array.to_slice acc_final) = .ok acc_final.to_slice`. + set acc_slice : Slice Std.I32 := Aeneas.Std.Array.to_slice acc_final with h_acc_slice_def + have h_acc_slice_val : acc_slice.val = acc_final.val := + Aeneas.Std.Array.val_to_slice acc_final + have h_acc_slice_len : acc_slice.length = 256 := by + show (Aeneas.Std.Array.to_slice acc_final).length = 256 + rw [Aeneas.Std.Array.length_to_slice]; rfl + have h_acc_slice_len_val : acc_slice.val.length = 256 := by + show acc_slice.val.length = 256 + rw [h_acc_slice_val]; exact Std.Array.length_eq acc_final + -- (6) Array.index_mut_usize t_as_ntt k → (t_as_ntt[k.val]!, set t_as_ntt k). + set pre : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + t_as_ntt.val[k.val]! with h_pre_def + have h_idx_mut : Aeneas.Std.Array.index_mut_usize t_as_ntt k + = .ok (pre, t_as_ntt.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + have h_idx : Aeneas.Std.Array.index_usize t_as_ntt k = .ok pre := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq t_as_ntt k + (by rw [h_t_as_ntt_len]; exact h_lt) + rw [h_idx]; rfl + -- (7) Apply L6.7 (poly_reducing_from_i32_array_fc, strengthened) on acc_slice and pre. + have h_acc_zero_natAbs : ∀ n : Nat, n < 256 → + (acc_zero.val[n]!).val.natAbs = 0 := by + intro n hn + have h_eq := h_acc_zero_get n hn + rw [h_eq]; rfl + have h_acc_final_lane_bnd : ∀ n : Nat, n < 256 → + (acc_slice.val[n]!).val.natAbs ≤ 2^16 * 3328 := by + intro n hn + rw [h_acc_slice_val] + have h_b := h_acc_final_bnd n hn + -- h_b uses `[n]?.getD default`; rewrite via `List.getElem!_eq_getElem?_getD`. + simp only [← List.getElem!_eq_getElem?_getD] at h_b + have h_z := h_acc_zero_natAbs n hn + have hK4 : K.val * 2^25 ≤ 4 * 2^25 := Nat.mul_le_mul_right _ hK + have h_bnd_2 : (4 : Nat) * 2^25 ≤ 2^16 * 3328 := by decide + omega + have h_l67 := + poly_reducing_from_i32_array_fc acc_slice pre h_acc_slice_len h_acc_final_lane_bnd + obtain ⟨t1, h_t1_eq, h_t1_post⟩ := triple_exists_ok_fc h_l67 + obtain ⟨h_t1_lift, h_t1_bnd⟩ := h_t1_post + -- (8) t_as_ntt1 := set t_as_ntt k t1. + set t_as_ntt1 : Stage3MontStripFC.TVec K := t_as_ntt.set k t1 with h_t_as_ntt1_def + have h_t_as_ntt1_at : t_as_ntt1.val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq t_as_ntt k k.val t1 + ⟨rfl, by rw [h_t_as_ntt_len]; exact h_lt⟩ + have h_t_as_ntt1_ne : ∀ j : Nat, j ≠ k.val → + t_as_ntt1.val[j]! = t_as_ntt.val[j]! := by + intro j hj + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne t_as_ntt k j t1 (fun h => hj h.symm) + have h_t_as_ntt1_len : t_as_ntt1.length = K.val := Std.Array.length_eq t_as_ntt1 + -- (9) Array.index_mut_usize t_as_ntt1 k → (t1, set t_as_ntt1 k). + have h_idx_mut1 : Aeneas.Std.Array.index_mut_usize t_as_ntt1 k + = .ok (t1, t_as_ntt1.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + have h_idx : Aeneas.Std.Array.index_usize t_as_ntt1 k = .ok t1 := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq t_as_ntt1 k + (by rw [h_t_as_ntt1_len]; exact h_lt) + rw [h_t_as_ntt1_at] at this + exact this + rw [h_idx]; rfl + -- (10) Array.index_usize error_as_ntt k → error_as_ntt[k.val]!. + set pre3 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + error_as_ntt.val[k.val]! with h_pre3_def + have h_idx_err : Aeneas.Std.Array.index_usize error_as_ntt k = .ok pre3 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq error_as_ntt k + (by rw [h_error_len]; exact h_lt) + -- (11) Apply L6.5 (add_standard_error_reduce_fc) on (t1, pre3). + have h_t1_self_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((t1.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro chunk hchunk ℓ hℓ + have h_b := h_t1_bnd chunk hchunk ℓ hℓ + -- h_b : … ≤ 4993; want ≤ 32767. + omega + have h_pre3_error_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((pre3.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439 := + fun chunk hchunk ℓ hℓ => + h_error_bnd ⟨k.val, h_lt⟩ ⟨chunk, hchunk⟩ ⟨ℓ, hℓ⟩ + have h_l65 := + add_standard_error_reduce_fc t1 pre3 h_t1_self_bnd h_pre3_error_bnd + obtain ⟨pre4, h_pre4_eq, h_pre4_post⟩ := triple_exists_ok_fc h_l65 + -- h_pre4_post : lift_poly pre4 = Spec.add_standard_error_reduce_pure (lift_poly t1) (lift_poly pre3). + -- (12) t_as_ntt_new := set t_as_ntt1 k pre4. + set t_as_ntt_new : Stage3MontStripFC.TVec K := t_as_ntt1.set k pre4 with h_t_as_ntt_new_def + have h_t_as_ntt_new_at : t_as_ntt_new.val[k.val]! = pre4 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq t_as_ntt1 k k.val pre4 + ⟨rfl, by rw [h_t_as_ntt1_len]; exact h_lt⟩ + have h_t_as_ntt_new_ne : ∀ j : Nat, j ≠ k.val → + t_as_ntt_new.val[j]! = t_as_ntt.val[j]! := by + intro j hj + have h1 : t_as_ntt_new.val[j]! = t_as_ntt1.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne t_as_ntt1 k j pre4 (fun h => hj h.symm) + rw [h1] + exact h_t_as_ntt1_ne j hj + -- (13) Body equation: reduce do-block to .ok (cont (s_iter_range, t_as_ntt_new, acc_final)). + have h_body : + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1.body + (vectortraitsOperationsInst := portable_ops_inst) matrix_A s_as_ntt error_as_ntt s_cache + { start := k, «end» := K } t_as_ntt accumulator + = .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), t_as_ntt_new, acc_final)) := by + unfold libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_pair] at hv_iter_eq + rw [hv_iter_eq] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let i1 ← libcrux_secrets.traits.Classify.Blanket.classify (0#i32 : Std.I32) + let accumulator1 := Aeneas.Std.Array.repeat 256#usize i1 + let accumulator2 ← + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1_loop0 + (vectortraitsOperationsInst := portable_ops_inst) + { start := 0#usize, «end» := K } matrix_A s_as_ntt s_cache + accumulator1 k + let s ← Aeneas.Std.lift (Aeneas.Std.Array.to_slice accumulator2) + let (pre, index_mut_back) ← Aeneas.Std.Array.index_mut_usize t_as_ntt k + let pre1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst s pre + let t_as_ntt1 := index_mut_back pre1 + let (pre2, index_mut_back1) ← Aeneas.Std.Array.index_mut_usize t_as_ntt1 k + let pre3 ← Aeneas.Std.Array.index_usize error_as_ntt k + let pre4 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + portable_ops_inst pre2 pre3 + let a := index_mut_back1 pre4 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), a, accumulator2))) + : Result _) = _ + rw [h_classify] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_acc_final_eq] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let s ← Aeneas.Std.lift (Aeneas.Std.Array.to_slice acc_final) + let (pre, index_mut_back) ← Aeneas.Std.Array.index_mut_usize t_as_ntt k + let pre1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst s pre + let t_as_ntt1 := index_mut_back pre1 + let (pre2, index_mut_back1) ← Aeneas.Std.Array.index_mut_usize t_as_ntt1 k + let pre3 ← Aeneas.Std.Array.index_usize error_as_ntt k + let pre4 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + portable_ops_inst pre2 pre3 + let a := index_mut_back1 pre4 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), a, acc_final))) + : Result _) = _ + show ((do + let s := Aeneas.Std.Array.to_slice acc_final + let (pre, index_mut_back) ← Aeneas.Std.Array.index_mut_usize t_as_ntt k + let pre1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst s pre + let t_as_ntt1 := index_mut_back pre1 + let (pre2, index_mut_back1) ← Aeneas.Std.Array.index_mut_usize t_as_ntt1 k + let pre3 ← Aeneas.Std.Array.index_usize error_as_ntt k + let pre4 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + portable_ops_inst pre2 pre3 + let a := index_mut_back1 pre4 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), a, acc_final))) + : Result _) = _ + rw [h_idx_mut] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let pre1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst (Aeneas.Std.Array.to_slice acc_final) pre + let t_as_ntt1 := t_as_ntt.set k pre1 + let (pre2, index_mut_back1) ← Aeneas.Std.Array.index_mut_usize t_as_ntt1 k + let pre3 ← Aeneas.Std.Array.index_usize error_as_ntt k + let pre4 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + portable_ops_inst pre2 pre3 + let a := index_mut_back1 pre4 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), a, acc_final))) + : Result _) = _ + have h_t1_eq' : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + (vectortraitsOperationsInst := portable_ops_inst) + (Aeneas.Std.Array.to_slice acc_final) pre = .ok t1 := h_t1_eq + rw [h_t1_eq'] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_mut1] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let pre3 ← Aeneas.Std.Array.index_usize error_as_ntt k + let pre4 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + portable_ops_inst t1 pre3 + let a := t_as_ntt1.set k pre4 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), a, acc_final))) + : Result _) = _ + rw [h_idx_err] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_pre4_eq] + simp only [Aeneas.Std.bind_tc_ok] + rfl + apply triple_of_ok_fc h_body + -- (14) Discharge the step_post. + show Stage3MontStripFC.rows_step_post matrix_A s_as_ntt error_as_ntt t_as_ntt_init start k + (.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), t_as_ntt_new, acc_final)) + refine ⟨h_lt, rfl, hs_iter_val, ?_⟩ + -- (15) Re-establish `rows_inv` at s_iter (= k+1). + show (Stage3MontStripFC.rows_inv matrix_A s_as_ntt error_as_ntt t_as_ntt_init start + s_iter t_as_ntt_new acc_final).holds + unfold Stage3MontStripFC.rows_inv + show (pure _ : Result Prop).holds + have hs_iter_eq : s_iter.val = k.val + 1 := hs_iter_val + have h_inv_pure : + (∀ r : Nat, start.val ≤ r → r < s_iter.val → ∀ ℓ : Nat, ℓ < 256 → + (lift_poly t_as_ntt_new.val[r]!).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt r (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16))) + ((lift_poly error_as_ntt.val[r]!).val[ℓ]!)) + ∧ (∀ r : Nat, r < K.val → (r < start.val ∨ s_iter.val ≤ r) → + t_as_ntt_new.val[r]! = t_as_ntt_init.val[r]!) := by + refine ⟨?_, ?_⟩ + · -- Conjunct (1): per-completed-row lane characterization. + intro r hr_ge hr_lt ℓ hℓ + rw [hs_iter_eq] at hr_lt + -- Case r < k.val (already-completed row, unchanged by this iteration) + -- vs r = k.val (the row we just wrote). + rcases Nat.lt_succ_iff_lt_or_eq.mp hr_lt with hr_lt_k | hr_eq_k + · -- r < k: row unchanged; use inv (1) at k. + have hr_ne : r ≠ k.val := by omega + rw [h_t_as_ntt_new_ne r hr_ne] + -- From the rows_inv (1) at k: (lift_poly t_as_ntt.val[r]!).val[ℓ]! = .... + have h_old := h_inv_done r hr_ge hr_lt_k ℓ hℓ + exact h_old + · -- r = k: row was written by this iteration. + subst hr_eq_k + rw [h_t_as_ntt_new_at] + -- Goal: (lift_poly pre4).val[ℓ]! = add_pure (mul_pure canonical_row_sum_lane 1353) (lift_poly error_as_ntt[r]).val[ℓ]!. + -- pre4 = output of L6.5 on (t1, pre3 = error_as_ntt[r]). + -- From h_pre4_post: + -- lift_poly pre4 = Spec.add_standard_error_reduce_pure (lift_poly t1) (lift_poly pre3). + -- Expand per-lane. + have hℓ_div_lt : ℓ / 16 < 16 := Nat.div_lt_iff_lt_mul (by decide : 0 < 16) |>.mpr hℓ + have hℓ_mod_lt : ℓ % 16 < 16 := Nat.mod_lt _ (by decide : 0 < 16) + have hℓ_decomp : 16 * (ℓ / 16) + ℓ % 16 = ℓ := by + have := Nat.div_add_mod ℓ 16 + omega + -- Step A: (lift_poly pre4).val[ℓ]! = + -- (chunk_add_standard_error_reduce_pure (chunk_at (lift_poly t1) (ℓ/16)) + -- (chunk_at (lift_poly pre3) (ℓ/16))).val[ℓ%16]! + have h_lift_pre4_lane : + (lift_poly pre4).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_poly t1).val[ℓ]!) (lift_fe_mont (1353#i16 : Std.I16))) + ((lift_poly pre3).val[ℓ]!) := by + -- Use h_pre4_post (full equality) + flatten_chunks + chunk_at lane. + rw [h_pre4_post] + unfold Spec.add_standard_error_reduce_pure + unfold Spec.flatten_chunks + -- Goal: (Std.Array.make 256 ((List.range 256).map (fun j => ...)) _).val[ℓ]! = ... + show ((List.range 256).map (fun j => + ((Std.Array.make 16#usize ((List.range 16).map (fun kk => + Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at (lift_poly t1) kk) + (Spec.chunk_at (lift_poly pre3) kk))) (by simp)).val[j / 16]!).val[j % 16]!))[ℓ]! + = _ + have h_len_outer : ((List.range 256).map (fun j => + ((Std.Array.make 16#usize ((List.range 16).map (fun kk => + Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at (lift_poly t1) kk) + (Spec.chunk_at (lift_poly pre3) kk))) (by simp)).val[j / 16]!).val[j % 16]!)).length = 256 := by + simp + rw [getElem!_pos _ ℓ (by rw [h_len_outer]; exact hℓ)] + rw [List.getElem_map, List.getElem_range] + -- Now reduce the inner [ℓ/16]! lookup on the chunks list. + have h_chunks_at : + ((Std.Array.make 16#usize ((List.range 16).map (fun kk => + Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at (lift_poly t1) kk) + (Spec.chunk_at (lift_poly pre3) kk))) (by simp)).val[ℓ / 16]!) + = Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at (lift_poly t1) (ℓ / 16)) + (Spec.chunk_at (lift_poly pre3) (ℓ / 16)) := by + show ((List.range 16).map (fun kk => + Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at (lift_poly t1) kk) + (Spec.chunk_at (lift_poly pre3) kk)))[ℓ / 16]! = _ + have h_len_inner : ((List.range 16).map (fun kk => + Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at (lift_poly t1) kk) + (Spec.chunk_at (lift_poly pre3) kk))).length = 16 := by simp + rw [getElem!_pos _ (ℓ / 16) (by rw [h_len_inner]; exact hℓ_div_lt)] + rw [List.getElem_map, List.getElem_range] + rw [h_chunks_at] + -- Now reduce chunk_add_standard_error_reduce_pure ... .val[ℓ%16]!. + unfold Spec.chunk_add_standard_error_reduce_pure + show ((List.range 16).map (fun ℓ' => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((Spec.chunk_at (lift_poly t1) (ℓ / 16)).val[ℓ']!) + (lift_fe_mont (1353#i16 : Std.I16))) + ((Spec.chunk_at (lift_poly pre3) (ℓ / 16)).val[ℓ']!)))[ℓ % 16]! = _ + have h_len_chunk : ((List.range 16).map (fun ℓ' => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((Spec.chunk_at (lift_poly t1) (ℓ / 16)).val[ℓ']!) + (lift_fe_mont (1353#i16 : Std.I16))) + ((Spec.chunk_at (lift_poly pre3) (ℓ / 16)).val[ℓ']!))).length = 16 := by simp + rw [getElem!_pos _ (ℓ % 16) (by rw [h_len_chunk]; exact hℓ_mod_lt)] + rw [List.getElem_map, List.getElem_range] + -- Now we have: + -- add_pure (mul_pure (chunk_at (lift_poly t1) (ℓ/16)).val[ℓ%16]! (lift_fe_mont 1353)) + -- ((chunk_at (lift_poly pre3) (ℓ/16)).val[ℓ%16]!) + -- Need: chunk_at_lane = (lift_poly _).val[ℓ]!. + have h_t1_chunk_at : + (Spec.chunk_at (lift_poly t1) (ℓ / 16)).val[ℓ % 16]! + = (lift_poly t1).val[ℓ]! := by + unfold Spec.chunk_at + show ((List.range 16).map + (fun j => (lift_poly t1).val[16 * (ℓ / 16) + j]!))[ℓ % 16]! = _ + have h_len_chunk_at : ((List.range 16).map + (fun j => (lift_poly t1).val[16 * (ℓ / 16) + j]!)).length = 16 := by simp + rw [getElem!_pos _ (ℓ % 16) (by rw [h_len_chunk_at]; exact hℓ_mod_lt)] + rw [List.getElem_map, List.getElem_range, hℓ_decomp] + have h_pre3_chunk_at : + (Spec.chunk_at (lift_poly pre3) (ℓ / 16)).val[ℓ % 16]! + = (lift_poly pre3).val[ℓ]! := by + unfold Spec.chunk_at + show ((List.range 16).map + (fun j => (lift_poly pre3).val[16 * (ℓ / 16) + j]!))[ℓ % 16]! = _ + have h_len_chunk_at : ((List.range 16).map + (fun j => (lift_poly pre3).val[16 * (ℓ / 16) + j]!)).length = 16 := by simp + rw [getElem!_pos _ (ℓ % 16) (by rw [h_len_chunk_at]; exact hℓ_mod_lt)] + rw [List.getElem_map, List.getElem_range, hℓ_decomp] + rw [h_t1_chunk_at, h_pre3_chunk_at] + rw [h_lift_pre4_lane] + -- Step B: bridge `mul_pure ((lift_poly t1).val[ℓ]!) (lift_fe_mont 1353) = + -- mul_pure (canonical_row_sum_lane ...) (lift_fe_mont 1353)`. + -- Need: (lift_poly t1).val[ℓ]! = canonical_row_sum_lane matrix_A s_as_ntt k.val (ℓ/16) (ℓ%16). + -- Use h_t1_lift : lift_poly_mont t1 = Spec.poly_reducing_from_i32_array_pure acc_slice. + -- And h_acc_final_lane : per-(j, ℓ') foldl over [0, K.val) equation. + -- Step B.1: (lift_poly_mont t1).val[ℓ]! = mont_reduce_pure (lift_fe_int acc_final.val[ℓ]). + have h_lift_mont_t1_lane : + (lift_poly_mont t1).val[ℓ]! + = Spec.mont_reduce_pure (lift_fe_int (acc_final.val[ℓ]!).val) := by + rw [h_t1_lift] + unfold Spec.poly_reducing_from_i32_array_pure + show ((List.range 256).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (acc_slice.val[i]!).val)))[ℓ]! = _ + have h_len : ((List.range 256).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (acc_slice.val[i]!).val))).length = 256 := by simp + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map, List.getElem_range, h_acc_slice_val] + -- Step B.2: mont_reduce_pure (lift_fe_int acc_final[ℓ]) = canonical_row_sum (no outer × 1353). + have h_acc_final_at_ℓ : + Spec.mont_reduce_pure (lift_fe_int (acc_final.val[ℓ]!).val) + = (List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[k.val * K.val + c]!.coefficients.val[ℓ / 16]!)) + (lift_chunk_mont (s_as_ntt.val[c]!.coefficients.val[ℓ / 16]!)) + (Spec.zeta_at (64 + 4 * (ℓ / 16))) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 1)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 2)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 3))).val[ℓ % 16]!)) + (Spec.mont_reduce_pure (lift_fe_int 0)) := by + have h_at := h_acc_final_lane (ℓ / 16) hℓ_div_lt (ℓ % 16) hℓ_mod_lt + simp only [← List.getElem!_eq_getElem?_getD] at h_at + rw [hℓ_decomp] at h_at + -- The init term in h_at is mont_reduce_pure (lift_fe_int (acc_zero[ℓ]).val). + -- acc_zero[ℓ] = 0, so .val = 0. + have h_zero_val : ((0#i32 : Std.I32).val) = 0 := rfl + rw [h_acc_zero_get ℓ hℓ] at h_at + rw [h_zero_val] at h_at + -- Convert (16 * (ℓ / 16) + ℓ % 16) to ℓ in the init too — already done by hℓ_decomp. + exact h_at + -- Step B.3: combine bridge with canonical_row_sum_lane definition. + have h_canon_unfold : + Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt k.val (ℓ / 16) (ℓ % 16) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[k.val * K.val + c]!.coefficients.val[ℓ / 16]!)) + (lift_chunk_mont (s_as_ntt.val[c]!.coefficients.val[ℓ / 16]!)) + (Spec.zeta_at (64 + 4 * (ℓ / 16))) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 1)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 2)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 3))).val[ℓ % 16]!)) + (Spec.mont_reduce_pure (lift_fe_int 0))) + (lift_fe_mont (1353#i16 : Std.I16)) := by + -- canonical_row_sum_lane is `attribute [local irreducible]` per the section, + -- but unfolding via with_unfolding_all once is necessary here. + with_unfolding_all rfl + -- Now we need: (lift_poly t1).val[ℓ]! = canonical_row_sum_lane .... + -- We use the bridge `lift_poly_mont_to_lift_poly`: + -- mul_pure ((lift_poly_mont t1).val[ℓ]!) (lift_fe_mont 1353) = (lift_poly t1).val[ℓ]!. + have h_bridge : libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_poly_mont t1).val[ℓ]!) (lift_fe_mont (1353#i16 : Std.I16)) + = (lift_poly t1).val[ℓ]! := lift_poly_mont_to_lift_poly t1 ℓ hℓ + -- Combine: (lift_poly t1).val[ℓ]! = mul_pure ((lift_poly_mont t1).val[ℓ]!) (lift_fe_mont 1353) + -- = mul_pure (foldl ...) (lift_fe_mont 1353) + -- = canonical_row_sum_lane. + have h_lift_t1_lane : + (lift_poly t1).val[ℓ]! + = Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt k.val (ℓ / 16) (ℓ % 16) := by + rw [← h_bridge, h_lift_mont_t1_lane, h_acc_final_at_ℓ, h_canon_unfold] + -- Show pre3 = error_as_ntt.val[k.val]! by definitional unfolding (set above). + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_poly t1).val[ℓ]!) (lift_fe_mont (1353#i16 : Std.I16))) + ((lift_poly pre3).val[ℓ]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt k.val (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16))) + ((lift_poly error_as_ntt.val[k.val]!).val[ℓ]!) + rw [h_lift_t1_lane] + · -- Conjunct (2): rows outside [start, s_iter) unchanged. + intro r hr_lt_K hr_disj + rw [hs_iter_eq] at hr_disj + -- r < start.val ∨ k+1 ≤ r — so in particular r ≠ k (since if r = k, hr_disj says r < start + -- which contradicts h_ge ≤ r, OR k+1 ≤ r = k which is false). + have hr_ne : r ≠ k.val := by + rcases hr_disj with hr_lt_start | hr_ge_succ + · -- r < start ≤ k, so r ≠ k. + omega + · -- k+1 ≤ r, so r > k. + omega + rw [h_t_as_ntt_new_ne r hr_ne] + -- Now need: t_as_ntt.val[r]! = t_as_ntt_init.val[r]!. + apply h_inv_undone r hr_lt_K + rcases hr_disj with hr_lt_start | hr_ge_succ + · exact Or.inl hr_lt_start + · -- k+1 ≤ r, so k ≤ r. + exact Or.inr (by omega) + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _; exact h_inv_pure + · -- `None` branch: k ≥ K, done. + have hk_ge : k.val ≥ K.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = K.val := by omega + have h_iter_none : + ⦃ ⌜ True ⌝ ⦄ + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ r = ((none : Option Std.Usize), + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun hlt => absurd hlt (Nat.not_lt.mpr hk_ge)) + (fun _ => by dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_none + have h_body : + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1.body + (vectortraitsOperationsInst := portable_ops_inst) matrix_A s_as_ntt error_as_ntt s_cache + { start := k, «end» := K } t_as_ntt accumulator + = .ok (ControlFlow.done (t_as_ntt, accumulator)) := by + unfold libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_post] at hv_iter_eq + rw [hv_iter_eq] + rfl + apply triple_of_ok_fc h_body + show Stage3MontStripFC.rows_step_post matrix_A s_as_ntt error_as_ntt t_as_ntt_init start k + (.done (t_as_ntt, accumulator)) + show (Stage3MontStripFC.rows_inv matrix_A s_as_ntt error_as_ntt t_as_ntt_init start + K t_as_ntt accumulator).holds + unfold Stage3MontStripFC.rows_inv + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ r : Nat, start.val ≤ r → r < K.val → ∀ ℓ : Nat, ℓ < 256 → + (lift_poly t_as_ntt.val[r]!).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt r (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16))) + ((lift_poly error_as_ntt.val[r]!).val[ℓ]!)) + ∧ (∀ r : Nat, r < K.val → (r < start.val ∨ K.val ≤ r) → + t_as_ntt.val[r]! = t_as_ntt_init.val[r]!) := by + refine ⟨?_, ?_⟩ + · intro r hr_ge hr_lt ℓ hℓ + have h_eq := h_inv_done r hr_ge (by rw [hk_eq]; exact hr_lt) ℓ hℓ + exact h_eq + · intro r hr_lt_K hr_disj + exact h_inv_undone r hr_lt_K (by + rcases hr_disj with hl | hr + · exact Or.inl hl + · exact Or.inr (by rw [hk_eq]; exact hr)) + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _; exact h_inv_pure + +/-- L7.1 Stage 3 — `matrix.compute_As_plus_e_loop1`: the outer rows loop over + `i ∈ [start, K)`. Each iteration re-zeros the accumulator, calls Stage 2 + (`compute_As_plus_e_loop1_loop0_fc`) for the row's column sum, converts via + L6.7 (`poly_reducing_from_i32_array_fc`) to Mont FE form, then applies + L6.5 (`add_standard_error_reduce_fc`) (× lift_fe_mont 1353 + error) to + produce the canonical FE row. The Mont→canonical bridge step (× 1353) is + absorbed inside `canonical_row_sum_lane`; the OUTER × 1353 in the + invariant comes from L6.5's own mul step. + + POST: `rows_inv` holds at k = K — i.e. for each row `r ∈ [start, K)`, + every lane `ℓ ∈ [0, 256)` equals + `add_pure (mul_pure (canonical_row_sum_lane matrix_A s_as_ntt r (ℓ/16) (ℓ%16)) + (lift_fe_mont 1353)) + ((lift_poly error_as_ntt.val[r]!).val[ℓ]!)`, + and rows outside `[start, K)` are unchanged from `t_as_ntt_init`. + + PRE: standard 16×16 bounds (3328/3328/29439) on matrix/s/error entries, + `hAlen : matrix_A.length = K·K`, `hK : K.val ≤ 4` (drives Stage 2's + K·2^25 ≤ 2^27 ≤ 2^16·3328 reasoning at L6.7), `h_start_le_K : start ≤ K`, + and the cache-post hypothesis `h_cache` — at every column `c < K`, + `s_cache.val[c]!` satisfies `accumulating_ntt_multiply_poly_cache_post` + against `s_as_ntt.val[c]!`. The accumulator passed in is rewritten on each + iteration (Stage 2 starts from a freshly-zeroed accumulator). -/ +@[spec] +theorem compute_As_plus_e_loop1_fc + {K : Std.Usize} + (t_as_ntt_init : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt error_as_ntt s_cache : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (accumulator : Std.Array Std.I32 256#usize) + (start : Std.Usize) + (hK : K.val ≤ 4) + (h_start_le_K : start.val ≤ K.val) + (hAlen : matrix_A.length = (K.val * K.val : Nat)) + (h_matrix_bnd : ∀ k : Fin matrix_A.length, ∀ a b : Fin 16, + ((matrix_A.val[k.val]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_s_bnd : ∀ k : Fin K.val, ∀ a b : Fin 16, + ((s_as_ntt.val[k.val]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_error_bnd : ∀ k : Fin K.val, ∀ a b : Fin 16, + ((error_as_ntt.val[k.val]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 29439) + (h_cache : ∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post (s_as_ntt.val[c]!) (s_cache.val[c]!)) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1 + (vectortraitsOperationsInst := portable_ops_inst) + { start := start, «end» := K } t_as_ntt_init matrix_A s_as_ntt error_as_ntt s_cache accumulator + ⦃ ⇓ p => ⌜ (Stage3MontStripFC.rows_inv matrix_A s_as_ntt error_as_ntt t_as_ntt_init start K p.1 p.2).holds ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1 + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, p) => + libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1.body + (vectortraitsOperationsInst := portable_ops_inst) matrix_A s_as_ntt error_as_ntt s_cache + iter1 p.1 p.2) + (β := Stage3MontStripFC.TVec K × Stage3MontStripFC.Acc) + (t_as_ntt_init, accumulator) + start K + (fun k p => Stage3MontStripFC.rows_inv matrix_A s_as_ntt error_as_ntt t_as_ntt_init start k p.1 p.2) + h_start_le_K + (by + -- Base case at k = start: rows_inv holds trivially. + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_⟩ + · -- (1) Vacuous: r ∈ [start, start) is empty. + intro r hr_ge hr_lt _ _; omega + · -- (2) For r ∉ [start, start), trivially t_as_ntt_init[r] = t_as_ntt_init[r]. + intro r _ _; trivial) + ?_) + · -- Post entailment: at k = K, rows_inv holds. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (Stage3MontStripFC.rows_inv matrix_A s_as_ntt error_as_ntt t_as_ntt_init start + K r.1 r.2).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + exact h_inv_holds + · -- Step entailment. + intro p k h_ge h_le hinv + have h_step := compute_As_plus_e_loop1_step_lemma_fc + matrix_A s_as_ntt error_as_ntt s_cache t_as_ntt_init start hK hAlen + h_matrix_bnd h_s_bnd h_error_bnd h_cache p.1 p.2 k h_ge h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', t_acc⟩ | y + · have hP : Stage3MontStripFC.rows_step_post matrix_A s_as_ntt error_as_ntt t_as_ntt_init start k + (.cont (iter', t_acc.1, t_acc.2)) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Stage3MontStripFC.rows_step_post] using hP + · have hP : Stage3MontStripFC.rows_step_post matrix_A s_as_ntt error_as_ntt t_as_ntt_init start k + (.done (y.1, y.2)) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Stage3MontStripFC.rows_step_post] using hP + +set_option maxHeartbeats 16000000 in +/-- L7.1 Stage 4a — row-0 finalization helper. + + Given a row-0 column-loop output `accumulator` satisfying `Stage1FillCacheFC.row0_inv` + at k=K (with `acc_init = accumulator` itself — i.e. the lemma's caller passes + the original L7.1 accumulator and the Stage 1 output coincides at this slot, + consistent with the calling pattern at `compute_As_plus_e_loop0_fc`), executes the bind chain + `Array.to_slice + index_mut_usize t_as_ntt 0 + L6.7 + index_mut t_as_ntt1 0 + + index_usize error_as_ntt 0 + L6.5` and produces `a` such that: + + (1) For row 0, every lane ℓ < 256: + `(lift_poly a.val[0]!).val[ℓ]! + = add_pure + (mul_pure (canonical_row_sum_lane matrix_A s_as_ntt 0 (ℓ/16) (ℓ%16)) + (lift_fe_mont 1353)) + ((lift_poly error_as_ntt.val[0]!).val[ℓ]!)`. + (2) For rows r > 0: `a.val[r]! = t_as_ntt.val[r]!`. + + Mirrors `compute_As_plus_e_loop1_step_lemma_fc` structurally + — the `.cont` branch's bind chain at lines 30880-31320, with row index `k` + replaced by `0#usize` and matrix lane index `k.val * K.val + c` replaced by + just `c` (since `0 * K + c = c`). + + Extra PRE beyond the template: `h_acc_zero` collapses row0_inv's foldl seed + `mont_reduce_pure (lift_fe_int accumulator[16j+ℓ].val)` to + `mont_reduce_pure (lift_fe_int 0)` — matching `canonical_row_sum_lane`'s + init. `h_acc_lane_bnd` discharges the L6.7 PRE on the slice. -/ +theorem compute_As_plus_e_row0_finalize_fc + {K : Std.Usize} + (t_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt error_as_ntt s_cache : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init accumulator : Stage3MontStripFC.Acc) + (s_cache_fin : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (hK_pos : 0 < K.val) + (h_error_bnd : ∀ k : Fin K.val, ∀ a b : Fin 16, + ((error_as_ntt.val[k.val]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 29439) + (h_acc_zero : ∀ n : Nat, n < 256 → acc_init.val[n]! = (0#i32 : Std.I32)) + (h_acc_lane_bnd : ∀ n : Nat, n < 256 → + (accumulator.val[n]!).val.natAbs ≤ 2^16 * 3328) + (h_row0_inv : (Stage1FillCacheFC.row0_inv matrix_A s_as_ntt acc_init s_cache K accumulator + s_cache_fin).holds) : + ⦃ ⌜ True ⌝ ⦄ + (do + let s ← Aeneas.Std.lift (Aeneas.Std.Array.to_slice accumulator) + let (pre, index_mut_back) ← Aeneas.Std.Array.index_mut_usize t_as_ntt 0#usize + let pre1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst s pre + let t_as_ntt1 := index_mut_back pre1 + let (pre2, index_mut_back1) ← Aeneas.Std.Array.index_mut_usize t_as_ntt1 0#usize + let pre3 ← Aeneas.Std.Array.index_usize error_as_ntt 0#usize + let pre4 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + portable_ops_inst pre2 pre3 + .ok (index_mut_back1 pre4) : + Result (Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K)) + ⦃ ⇓ a => ⌜ + (∀ ℓ : Nat, ℓ < 256 → + (lift_poly a.val[0]!).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt 0 (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16))) + ((lift_poly error_as_ntt.val[0]!).val[ℓ]!)) + ∧ (∀ r : Nat, 0 < r → r < K.val → + a.val[r]! = t_as_ntt.val[r]!) ⌝ ⦄ := by + have h_t_as_ntt_len : t_as_ntt.length = K.val := Std.Array.length_eq t_as_ntt + have h_error_len : error_as_ntt.length = K.val := Std.Array.length_eq error_as_ntt + -- Convenience: (0#usize).val = 0. + have h_zero_val : (0#usize : Std.Usize).val = 0 := rfl + -- Destructure the 4-conjunct row0_inv (we only need (1)). + obtain ⟨h_row0_lane, _h_row0_bnd, _h_row0_cache_pop, _h_row0_cache_unch⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_row0_inv + -- (1) acc_slice := Array.to_slice accumulator. + set acc_slice : Slice Std.I32 := Aeneas.Std.Array.to_slice accumulator with h_acc_slice_def + have h_acc_slice_val : acc_slice.val = accumulator.val := + Aeneas.Std.Array.val_to_slice accumulator + have h_acc_slice_len : acc_slice.length = 256 := by + show (Aeneas.Std.Array.to_slice accumulator).length = 256 + rw [Aeneas.Std.Array.length_to_slice]; rfl + -- (2) Array.index_mut_usize t_as_ntt 0 → (t_as_ntt[0]!, set t_as_ntt 0). + set pre : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + t_as_ntt.val[(0#usize : Std.Usize).val]! with h_pre_def + have h_idx_mut : Aeneas.Std.Array.index_mut_usize t_as_ntt (0#usize : Std.Usize) + = .ok (pre, t_as_ntt.set (0#usize : Std.Usize)) := by + unfold Aeneas.Std.Array.index_mut_usize + have h_idx : Aeneas.Std.Array.index_usize t_as_ntt (0#usize : Std.Usize) = .ok pre := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq t_as_ntt (0#usize : Std.Usize) + (by rw [h_t_as_ntt_len]; exact hK_pos) + rw [h_idx]; rfl + -- (3) Apply L6.7 on acc_slice + pre. Use h_acc_lane_bnd via h_acc_slice_val. + have h_acc_slice_lane_bnd : ∀ n : Nat, n < 256 → + (acc_slice.val[n]!).val.natAbs ≤ 2^16 * 3328 := by + intro n hn; rw [h_acc_slice_val]; exact h_acc_lane_bnd n hn + have h_l67 := + poly_reducing_from_i32_array_fc acc_slice pre h_acc_slice_len h_acc_slice_lane_bnd + obtain ⟨t1, h_t1_eq, h_t1_post⟩ := triple_exists_ok_fc h_l67 + obtain ⟨h_t1_lift, h_t1_bnd⟩ := h_t1_post + -- (4) t_as_ntt1 := set t_as_ntt 0 t1. + set t_as_ntt1 : Stage3MontStripFC.TVec K := t_as_ntt.set (0#usize : Std.Usize) t1 + with h_t_as_ntt1_def + have h_t_as_ntt1_at : t_as_ntt1.val[(0#usize : Std.Usize).val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq t_as_ntt (0#usize : Std.Usize) + (0#usize : Std.Usize).val t1 + ⟨rfl, by rw [h_t_as_ntt_len]; exact hK_pos⟩ + have h_t_as_ntt1_ne : ∀ j : Nat, j ≠ (0#usize : Std.Usize).val → + t_as_ntt1.val[j]! = t_as_ntt.val[j]! := by + intro j hj + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne t_as_ntt (0#usize : Std.Usize) j t1 + (fun h => hj h.symm) + have h_t_as_ntt1_len : t_as_ntt1.length = K.val := Std.Array.length_eq t_as_ntt1 + -- (5) Array.index_mut_usize t_as_ntt1 0 → (t1, set t_as_ntt1 0). + have h_idx_mut1 : Aeneas.Std.Array.index_mut_usize t_as_ntt1 (0#usize : Std.Usize) + = .ok (t1, t_as_ntt1.set (0#usize : Std.Usize)) := by + unfold Aeneas.Std.Array.index_mut_usize + have h_idx : Aeneas.Std.Array.index_usize t_as_ntt1 (0#usize : Std.Usize) = .ok t1 := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq t_as_ntt1 (0#usize : Std.Usize) + (by rw [h_t_as_ntt1_len]; exact hK_pos) + rw [h_t_as_ntt1_at] at this + exact this + rw [h_idx]; rfl + -- (6) Array.index_usize error_as_ntt 0 → error_as_ntt[0]!. + set pre3 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + error_as_ntt.val[(0#usize : Std.Usize).val]! with h_pre3_def + have h_idx_err : Aeneas.Std.Array.index_usize error_as_ntt (0#usize : Std.Usize) = .ok pre3 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq error_as_ntt (0#usize : Std.Usize) + (by rw [h_error_len]; exact hK_pos) + -- (7) Apply L6.5 on (t1, pre3). + have h_t1_self_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((t1.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro chunk hchunk ℓ hℓ + have h_b := h_t1_bnd chunk hchunk ℓ hℓ + omega + have h_pre3_error_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((pre3.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439 := + fun chunk hchunk ℓ hℓ => + h_error_bnd ⟨(0#usize : Std.Usize).val, hK_pos⟩ ⟨chunk, hchunk⟩ ⟨ℓ, hℓ⟩ + have h_l65 := + add_standard_error_reduce_fc t1 pre3 h_t1_self_bnd h_pre3_error_bnd + obtain ⟨pre4, h_pre4_eq, h_pre4_post⟩ := triple_exists_ok_fc h_l65 + -- (8) t_as_ntt_new := set t_as_ntt1 0 pre4. + set t_as_ntt_new : Stage3MontStripFC.TVec K := t_as_ntt1.set (0#usize : Std.Usize) pre4 + with h_t_as_ntt_new_def + have h_t_as_ntt_new_at : t_as_ntt_new.val[0]! = pre4 := by + have h := Aeneas.Std.Array.getElem!_Nat_set_eq t_as_ntt1 (0#usize : Std.Usize) + 0 pre4 ⟨rfl, by rw [h_t_as_ntt1_len]; exact hK_pos⟩ + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h + have h_t_as_ntt_new_ne : ∀ j : Nat, j ≠ 0 → + t_as_ntt_new.val[j]! = t_as_ntt.val[j]! := by + intro j hj + have h1 : t_as_ntt_new.val[j]! = t_as_ntt1.val[j]! := by + have := Aeneas.Std.Array.getElem!_Nat_set_ne t_as_ntt1 (0#usize : Std.Usize) j pre4 + (fun h => hj (by rw [← h_zero_val]; exact h.symm)) + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using this + rw [h1] + exact h_t_as_ntt1_ne j (by rw [h_zero_val]; exact hj) + -- (9) Body equation: reduce do-block to .ok t_as_ntt_new. + have h_body : + (do + let s ← Aeneas.Std.lift (Aeneas.Std.Array.to_slice accumulator) + let (pre, index_mut_back) ← Aeneas.Std.Array.index_mut_usize t_as_ntt 0#usize + let pre1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst s pre + let t_as_ntt1 := index_mut_back pre1 + let (pre2, index_mut_back1) ← Aeneas.Std.Array.index_mut_usize t_as_ntt1 0#usize + let pre3 ← Aeneas.Std.Array.index_usize error_as_ntt 0#usize + let pre4 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + portable_ops_inst pre2 pre3 + .ok (index_mut_back1 pre4) : + Result (Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K)) + = .ok t_as_ntt_new := by + show ((do + let s := Aeneas.Std.Array.to_slice accumulator + let (pre, index_mut_back) ← Aeneas.Std.Array.index_mut_usize t_as_ntt 0#usize + let pre1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst s pre + let t_as_ntt1 := index_mut_back pre1 + let (pre2, index_mut_back1) ← Aeneas.Std.Array.index_mut_usize t_as_ntt1 0#usize + let pre3 ← Aeneas.Std.Array.index_usize error_as_ntt 0#usize + let pre4 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + portable_ops_inst pre2 pre3 + .ok (index_mut_back1 pre4)) + : Result _) = _ + rw [h_idx_mut] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let pre1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst (Aeneas.Std.Array.to_slice accumulator) pre + let t_as_ntt1 := t_as_ntt.set (0#usize : Std.Usize) pre1 + let (pre2, index_mut_back1) ← Aeneas.Std.Array.index_mut_usize t_as_ntt1 0#usize + let pre3 ← Aeneas.Std.Array.index_usize error_as_ntt 0#usize + let pre4 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + portable_ops_inst pre2 pre3 + .ok (index_mut_back1 pre4)) + : Result _) = _ + have h_t1_eq' : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + (vectortraitsOperationsInst := portable_ops_inst) + (Aeneas.Std.Array.to_slice accumulator) pre = .ok t1 := h_t1_eq + rw [h_t1_eq'] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_mut1] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let pre3 ← Aeneas.Std.Array.index_usize error_as_ntt 0#usize + let pre4 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + portable_ops_inst t1 pre3 + .ok (t_as_ntt1.set (0#usize : Std.Usize) pre4)) + : Result _) = _ + rw [h_idx_err] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_pre4_eq] + simp only [Aeneas.Std.bind_tc_ok] + rfl + apply triple_of_ok_fc h_body + -- (10) Discharge the 2-conjunct post. + refine ⟨?_, ?_⟩ + · -- Conjunct (1): per-lane characterization at row 0. + intro ℓ hℓ + rw [h_t_as_ntt_new_at] + -- Now: (lift_poly pre4).val[ℓ]! = add_pure (mul_pure canonical_row_sum_lane 1353) + -- ((lift_poly error_as_ntt[0]!).val[ℓ]!). + have hℓ_div_lt : ℓ / 16 < 16 := Nat.div_lt_iff_lt_mul (by decide : 0 < 16) |>.mpr hℓ + have hℓ_mod_lt : ℓ % 16 < 16 := Nat.mod_lt _ (by decide : 0 < 16) + have hℓ_decomp : 16 * (ℓ / 16) + ℓ % 16 = ℓ := by + have := Nat.div_add_mod ℓ 16 + omega + -- Step A: (lift_poly pre4).val[ℓ]! = add_pure (mul_pure ((lift_poly t1).val[ℓ]!) 1353) + -- ((lift_poly pre3).val[ℓ]!). + have h_lift_pre4_lane : + (lift_poly pre4).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_poly t1).val[ℓ]!) (lift_fe_mont (1353#i16 : Std.I16))) + ((lift_poly pre3).val[ℓ]!) := by + rw [h_pre4_post] + unfold Spec.add_standard_error_reduce_pure + unfold Spec.flatten_chunks + show ((List.range 256).map (fun j => + ((Std.Array.make 16#usize ((List.range 16).map (fun kk => + Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at (lift_poly t1) kk) + (Spec.chunk_at (lift_poly pre3) kk))) (by simp)).val[j / 16]!).val[j % 16]!))[ℓ]! + = _ + have h_len_outer : ((List.range 256).map (fun j => + ((Std.Array.make 16#usize ((List.range 16).map (fun kk => + Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at (lift_poly t1) kk) + (Spec.chunk_at (lift_poly pre3) kk))) (by simp)).val[j / 16]!).val[j % 16]!)).length = 256 := by + simp + rw [getElem!_pos _ ℓ (by rw [h_len_outer]; exact hℓ)] + rw [List.getElem_map, List.getElem_range] + have h_chunks_at : + ((Std.Array.make 16#usize ((List.range 16).map (fun kk => + Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at (lift_poly t1) kk) + (Spec.chunk_at (lift_poly pre3) kk))) (by simp)).val[ℓ / 16]!) + = Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at (lift_poly t1) (ℓ / 16)) + (Spec.chunk_at (lift_poly pre3) (ℓ / 16)) := by + show ((List.range 16).map (fun kk => + Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at (lift_poly t1) kk) + (Spec.chunk_at (lift_poly pre3) kk)))[ℓ / 16]! = _ + have h_len_inner : ((List.range 16).map (fun kk => + Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at (lift_poly t1) kk) + (Spec.chunk_at (lift_poly pre3) kk))).length = 16 := by simp + rw [getElem!_pos _ (ℓ / 16) (by rw [h_len_inner]; exact hℓ_div_lt)] + rw [List.getElem_map, List.getElem_range] + rw [h_chunks_at] + unfold Spec.chunk_add_standard_error_reduce_pure + show ((List.range 16).map (fun ℓ' => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((Spec.chunk_at (lift_poly t1) (ℓ / 16)).val[ℓ']!) + (lift_fe_mont (1353#i16 : Std.I16))) + ((Spec.chunk_at (lift_poly pre3) (ℓ / 16)).val[ℓ']!)))[ℓ % 16]! = _ + have h_len_chunk : ((List.range 16).map (fun ℓ' => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((Spec.chunk_at (lift_poly t1) (ℓ / 16)).val[ℓ']!) + (lift_fe_mont (1353#i16 : Std.I16))) + ((Spec.chunk_at (lift_poly pre3) (ℓ / 16)).val[ℓ']!))).length = 16 := by simp + rw [getElem!_pos _ (ℓ % 16) (by rw [h_len_chunk]; exact hℓ_mod_lt)] + rw [List.getElem_map, List.getElem_range] + have h_t1_chunk_at : + (Spec.chunk_at (lift_poly t1) (ℓ / 16)).val[ℓ % 16]! + = (lift_poly t1).val[ℓ]! := by + unfold Spec.chunk_at + show ((List.range 16).map + (fun j => (lift_poly t1).val[16 * (ℓ / 16) + j]!))[ℓ % 16]! = _ + have h_len_chunk_at : ((List.range 16).map + (fun j => (lift_poly t1).val[16 * (ℓ / 16) + j]!)).length = 16 := by simp + rw [getElem!_pos _ (ℓ % 16) (by rw [h_len_chunk_at]; exact hℓ_mod_lt)] + rw [List.getElem_map, List.getElem_range, hℓ_decomp] + have h_pre3_chunk_at : + (Spec.chunk_at (lift_poly pre3) (ℓ / 16)).val[ℓ % 16]! + = (lift_poly pre3).val[ℓ]! := by + unfold Spec.chunk_at + show ((List.range 16).map + (fun j => (lift_poly pre3).val[16 * (ℓ / 16) + j]!))[ℓ % 16]! = _ + have h_len_chunk_at : ((List.range 16).map + (fun j => (lift_poly pre3).val[16 * (ℓ / 16) + j]!)).length = 16 := by simp + rw [getElem!_pos _ (ℓ % 16) (by rw [h_len_chunk_at]; exact hℓ_mod_lt)] + rw [List.getElem_map, List.getElem_range, hℓ_decomp] + rw [h_t1_chunk_at, h_pre3_chunk_at] + rw [h_lift_pre4_lane] + -- Step B: (lift_poly t1).val[ℓ]! = canonical_row_sum_lane matrix_A s_as_ntt 0 (ℓ/16) (ℓ%16). + have h_lift_mont_t1_lane : + (lift_poly_mont t1).val[ℓ]! + = Spec.mont_reduce_pure (lift_fe_int (accumulator.val[ℓ]!).val) := by + rw [h_t1_lift] + unfold Spec.poly_reducing_from_i32_array_pure + show ((List.range 256).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (acc_slice.val[i]!).val)))[ℓ]! = _ + have h_len : ((List.range 256).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (acc_slice.val[i]!).val))).length = 256 := by simp + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map, List.getElem_range, h_acc_slice_val] + -- Step B.2: mont_reduce_pure (lift_fe_int accumulator[ℓ]) = foldl (no outer × 1353) + -- with seed mont_reduce_pure (lift_fe_int 0). + -- The row0_inv conjunct (1) at j = ℓ/16, ℓ' = ℓ%16: + -- mont_reduce_pure (lift_fe_int (accumulator[16*(ℓ/16)+ℓ%16]).val) + -- = (List.range K.val).foldl ... (mont_reduce_pure (lift_fe_int (accumulator[16*(ℓ/16)+ℓ%16]).val)) + -- where the foldl uses matrix_A[c] (not k*K+c) and s_as_ntt[c]. + -- h_acc_zero collapses the seed via accumulator[ℓ] = 0#i32 → .val = 0. + have h_acc_at_ℓ : + Spec.mont_reduce_pure (lift_fe_int (accumulator.val[ℓ]!).val) + = (List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[c]!.coefficients.val[ℓ / 16]!)) + (lift_chunk_mont (s_as_ntt.val[c]!.coefficients.val[ℓ / 16]!)) + (Spec.zeta_at (64 + 4 * (ℓ / 16))) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 1)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 2)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 3))).val[ℓ % 16]!)) + (Spec.mont_reduce_pure (lift_fe_int 0)) := by + have h_at := h_row0_lane (ℓ / 16) hℓ_div_lt (ℓ % 16) hℓ_mod_lt + rw [hℓ_decomp] at h_at + -- h_at: mont_reduce_pure (lift_fe_int (acc_final[ℓ]).val) + -- = foldl ... (mont_reduce_pure (lift_fe_int (acc_init[ℓ]).val)) + -- Goal seed: mont_reduce_pure (lift_fe_int 0). The acc_init seed collapses to 0 + -- via h_acc_zero (acc_init is the loop0 INPUT accumulator, which is zero). + have h_z := h_acc_zero ℓ hℓ + have h_zero_i32_val : ((0#i32 : Std.I32).val) = 0 := rfl + have h_collapse : (acc_init.val[ℓ]!).val = 0 := by rw [h_z]; exact h_zero_i32_val + rw [h_collapse] at h_at + exact h_at + -- Step B.3: combine via canonical_row_sum_lane. + -- For row i = 0, canonical_row_sum_lane's internal index 0*K.val + c = c. + have h_canon_unfold : + Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt 0 (ℓ / 16) (ℓ % 16) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[0 * K.val + c]!.coefficients.val[ℓ / 16]!)) + (lift_chunk_mont (s_as_ntt.val[c]!.coefficients.val[ℓ / 16]!)) + (Spec.zeta_at (64 + 4 * (ℓ / 16))) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 1)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 2)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 3))).val[ℓ % 16]!)) + (Spec.mont_reduce_pure (lift_fe_int 0))) + (lift_fe_mont (1353#i16 : Std.I16)) := by + with_unfolding_all rfl + -- 0 * K.val + c = c, so the matrix index in h_canon_unfold matches h_acc_at_ℓ. + have h_zero_K_c : ∀ c : Nat, 0 * K.val + c = c := by intro c; omega + -- Bridge: lift_poly_mont_to_lift_poly converts the Mont-lift to the canonical lift. + have h_bridge : libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_poly_mont t1).val[ℓ]!) (lift_fe_mont (1353#i16 : Std.I16)) + = (lift_poly t1).val[ℓ]! := lift_poly_mont_to_lift_poly t1 ℓ hℓ + have h_lift_t1_lane : + (lift_poly t1).val[ℓ]! + = Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt 0 (ℓ / 16) (ℓ % 16) := by + rw [← h_bridge, h_lift_mont_t1_lane, h_acc_at_ℓ, h_canon_unfold] + -- Both sides have the same foldl, but matrix indices differ by 0*K.val+c = c. + -- Since 0*K.val = 0, this is just c = c. + simp only [Nat.zero_mul, Nat.zero_add] + -- pre3 = error_as_ntt.val[0]! (definitionally, since (0#usize).val = 0). + -- Goal: add_pure (mul_pure (lift_poly t1)[ℓ] 1353) (lift_poly pre3)[ℓ] + -- = add_pure (mul_pure (canonical_row_sum_lane ...) 1353) (lift_poly error_as_ntt[0])[ℓ] + rw [h_lift_t1_lane] + -- Now match pre3 = error_as_ntt.val[0]! on the RHS. + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt 0 (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16))) + ((lift_poly (error_as_ntt.val[(0#usize : Std.Usize).val]!)).val[ℓ]!) + = _ + rfl + · -- Conjunct (2): rows r > 0 unchanged. + intro r hr_pos _hr_lt_K + have hr_ne : r ≠ 0 := by omega + exact h_t_as_ntt_new_ne r hr_ne + +end L7_1c_irreducible + +/-! ## §L7.1 Stage 4b — hacspec-side bridge lemma. + + Given the per-row, per-lane characterization of `t_as_ntt_final`, + proves the hacspec `compute_As_plus_e` equation that L7.1's POST + demands. The proof unfolds `compute_As_plus_e` to its do-block + `multiply_matrix_by_column ; add_vectors` and uses three layered + `from_fn_pure_eq` applications. -/ + +namespace Stage4MatrixAddFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Clone of `polynomial.add_to_ring_element_eq_ok` for the byte-identical + `matrix.add_polynomials` closure (both compile from the same Rust + source pattern; the closure bodies are identical up to namespace). -/ +-- Public (exported for L7.4 `compute_message_acc_bridge`): per-step reduction of +-- the hacspec `matrix.add_polynomials` to its pure-lane `add_pure` array form. +-- Visibility-only change (proof/statement unchanged). +theorem matrix_add_polynomials_eq_ok + (lhs rhs : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + hacspec_ml_kem.matrix.add_polynomials lhs rhs + = .ok ⟨(List.range 256).map (fun k => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lhs.val[k]!) (rhs.val[k]!)), + by simp [List.length_map, List.length_range]⟩ := by + set f : Nat → hacspec_ml_kem.parameters.FieldElement := + fun k => libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lhs.val[k]!) (rhs.val[k]!) with hf_def + have hpure : ∀ k : Nat, k < (256#usize : Std.Usize).val → + (hacspec_ml_kem.matrix.add_polynomials.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement + : CoreModels.core.ops.function.Fn _ _ _).FnMutInst.call_mut + (lhs, rhs) ⟨BitVec.ofNat _ k⟩ + = .ok (f k, (lhs, rhs)) := by + intro k hk + have hk' : k < 256 := hk + show hacspec_ml_kem.matrix.add_polynomials.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + (lhs, rhs) ⟨BitVec.ofNat _ k⟩ = .ok (f k, (lhs, rhs)) + unfold hacspec_ml_kem.matrix.add_polynomials.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + unfold hacspec_ml_kem.matrix.add_polynomials.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement.call + have hk_us : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := by + show (BitVec.ofNat _ k).toNat = k + apply Nat.mod_eq_of_lt + have : k < 2^System.Platform.numBits := by + have hbits : 2^16 ≤ 2^System.Platform.numBits := + Nat.pow_le_pow_right (by decide) (by + cases System.Platform.numBits_eq with + | inl h => rw [h]; decide + | inr h => rw [h]; decide) + omega + exact this + have hlhs_len : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val < lhs.length := by + rw [hk_us]; show k < lhs.val.length + rw [lhs.property]; exact hk + have hrhs_len : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val < rhs.length := by + rw [hk_us]; show k < rhs.val.length + rw [rhs.property]; exact hk + have h_lhs_idx : + Std.Array.index_usize lhs (⟨BitVec.ofNat _ k⟩ : Std.Usize) + = .ok (lhs.val[k]!) := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq lhs + (⟨BitVec.ofNat _ k⟩ : Std.Usize) hlhs_len + rw [hk_us] at this; exact this + have h_rhs_idx : + Std.Array.index_usize rhs (⟨BitVec.ofNat _ k⟩ : Std.Usize) + = .ok (rhs.val[k]!) := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq rhs + (⟨BitVec.ofNat _ k⟩ : Std.Usize) hrhs_len + rw [hk_us] at this; exact this + have h_add := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_eq_ok (lhs.val[k]!) (rhs.val[k]!) + change (do + let fe ← (do + let fe ← Std.Array.index_usize lhs ⟨BitVec.ofNat _ k⟩ + let i ← lift (Std.UScalar.cast .U32 fe.val) + let fe1 ← Std.Array.index_usize rhs ⟨BitVec.ofNat _ k⟩ + let i1 ← lift (Std.UScalar.cast .U32 fe1.val) + let i2 ← i + i1 + let i3 ← lift (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS) + let i4 ← i2 % i3 + let i5 ← lift (Std.UScalar.cast .U16 i4) + hacspec_ml_kem.parameters.FieldElement.new i5) + Result.ok (fe, lhs, rhs)) = Result.ok (f k, lhs, rhs) + rw [h_lhs_idx]; simp only [bind_tc_ok] + rw [h_rhs_idx]; simp only [bind_tc_ok] + unfold hacspec_ml_kem.parameters.FieldElement.add at h_add + rw [h_add] + simp only [bind_tc_ok, hf_def] + have h_from_fn := + libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq + (T := hacspec_ml_kem.parameters.FieldElement) + (F := hacspec_ml_kem.matrix.add_polynomials.closure) + (N := 256#usize) + (inst := hacspec_ml_kem.matrix.add_polynomials.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement) + (c := (lhs, rhs)) + (f := f) + hpure + unfold hacspec_ml_kem.matrix.add_polynomials + unfold hacspec_ml_kem.parameters.createi + show core.array.from_fn 256#usize _ (lhs, rhs) = _ + exact h_from_fn + +/-- **Helper 1.** Single-product lane-eq characterization: for any two + `PolynomialRingElement`s `a, b`, lane `ℓ` of `Spec.multiply_ntts_pure + (lift_poly a) (lift_poly b)` equals the corresponding chunk-projected + `Spec.ntt_multiply_pure_no_acc` lane on `lift_chunk_mont`-style chunks, + scaled by `(lift_fe_mont 1353)²`. + + Composes `Spec.multiply_ntts_pure_eq_chunked_no_acc` + with Helper 2 `Stage3MontStripFC.ntt_multiply_pure_no_acc_chunk_at_lift_poly_eq`. -/ +theorem multiply_ntts_lane_eq_canonical_factor + (a b : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (ℓ : Nat) (hℓ : ℓ < 256) : + (Spec.multiply_ntts_pure (lift_poly a) (lift_poly b)).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont a.coefficients.val[ℓ / 16]!) + (lift_chunk_mont b.coefficients.val[ℓ / 16]!) + (Spec.zeta_at (64 + 4 * (ℓ / 16))) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 1)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 2)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 3))).val[ℓ % 16]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (1353#i16 : Std.I16)) + (lift_fe_mont (1353#i16 : Std.I16))) := by + have h_div_lt : ℓ / 16 < 16 := by omega + have h_mod_lt : ℓ % 16 < 16 := Nat.mod_lt _ (by decide) + -- Step 1: rewrite Spec.multiply_ntts_pure via the chunked form. + rw [Spec.multiply_ntts_pure_eq_chunked_no_acc] + -- Step 2: project lane ℓ through flatten_chunks via the inner list. + unfold Spec.flatten_chunks + -- The Std.Array.make's `.val[k]!` lookup is direct list-lookup. + show ((List.range 256).map (fun j => + (((List.range 16).map (fun j' => + Spec.ntt_multiply_pure_no_acc + (Spec.chunk_at (lift_poly a) j') (Spec.chunk_at (lift_poly b) j') + (Spec.zeta_at (64 + 4 * j')) (Spec.zeta_at (64 + 4 * j' + 1)) + (Spec.zeta_at (64 + 4 * j' + 2)) (Spec.zeta_at (64 + 4 * j' + 3))) + )[j / 16]!).val[j % 16]!))[ℓ]! = _ + rw [getElem!_pos _ ℓ (by simp [List.length_map, List.length_range, hℓ])] + rw [List.getElem_map, List.getElem_range] + -- Now we have: `(((List.range 16).map f)[ℓ / 16]!).val[ℓ % 16]! = RHS`. + -- Reduce the inner getElem!. + rw [getElem!_pos _ (ℓ / 16) (by simp [List.length_map, List.length_range, h_div_lt])] + rw [List.getElem_map, List.getElem_range] + -- Now apply Helper 2. + exact Stage3MontStripFC.ntt_multiply_pure_no_acc_chunk_at_lift_poly_eq + a b (ℓ / 16) h_div_lt + (Spec.zeta_at (64 + 4 * (ℓ / 16))) (Spec.zeta_at (64 + 4 * (ℓ / 16) + 1)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 2)) (Spec.zeta_at (64 + 4 * (ℓ / 16) + 3)) + (ℓ % 16) h_mod_lt + +set_option maxHeartbeats 1000000 in +/-- **Helper 2.** Generic foldl/mul distributivity: pulling a uniform + right-multiplicand `K` out of every accumulator step. + + `mul_pure (foldl (fun s x => add_pure s (f x)) seed L) K + = foldl (fun s x => add_pure s (mul_pure (f x) K)) (mul_pure seed K) L`. -/ +theorem foldl_add_mul_distrib + {α : Type} (L : List α) + (f : α → hacspec_ml_kem.parameters.FieldElement) + (seed K : hacspec_ml_kem.parameters.FieldElement) : + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (L.foldl (fun s x => libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + s (f x)) seed) K + = L.foldl (fun s x => libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + s + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure (f x) K)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure seed K) := by + -- Key fact: `mul_pure (add_pure a b) K = add_pure (mul_pure a K) (mul_pure b K)`. + -- We prove this via ZMod 3329 projection + canonical round-trip. + have h_distrib : + ∀ a b K' : hacspec_ml_kem.parameters.FieldElement, + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure a b) K' + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a K') + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure b K') := by + intro a b K' + have h_canon_lhs : libcrux_iot_ml_kem.Spec.Pure.Canonical + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure a b) K') := + libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure _ _ + have h_canon_rhs : libcrux_iot_ml_kem.Spec.Pure.Canonical + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a K') + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure b K')) := + libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure _ _ + have h_canon_to_lt : ∀ x : hacspec_ml_kem.parameters.FieldElement, + libcrux_iot_ml_kem.Spec.Pure.Canonical x → x.val.val < 3329 := by + intro x hx + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at hx + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at hx + exact hx + have h_lt_lhs := h_canon_to_lt _ h_canon_lhs + have h_lt_rhs := h_canon_to_lt _ h_canon_rhs + rw [← feOfZMod_zmodOfFE_of_canonical _ h_lt_lhs, + ← feOfZMod_zmodOfFE_of_canonical _ h_lt_rhs] + congr 1 + simp only [L2_8c.zmodOfFE_add_pure, L2_8c.zmodOfFE_mul_pure] + ring + -- Now induction on L. We need an aux that handles the changing seed. + induction L generalizing seed with + | nil => simp + | cons h t ih => + simp only [List.foldl_cons] + have h_step := ih + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure seed (f h)) + rw [h_step] + -- Goal: foldl ... (mul_pure (add_pure seed (f h)) K) + -- = foldl ... (add_pure (mul_pure seed K) (mul_pure (f h) K)) + rw [h_distrib seed (f h) K] + +/-- Canonical row-sum at column step k (foldl form in `lift_chunk_mont`, + canonical post-scale via `mul_pure 1353 1353`). This is the partial-sum + value at lane ℓ that the column loop produces at step k. -/ +noncomputable def col_loop_lane_at_step + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (i : Nat) (k : Nat) (ℓ : Nat) : + hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((List.range k).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[i * K.val + c]!.coefficients.val[ℓ / 16]!)) + (lift_chunk_mont (s_as_ntt.val[c]!.coefficients.val[ℓ / 16]!)) + (Spec.zeta_at (64 + 4 * (ℓ / 16))) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 1)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 2)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 3))).val[ℓ % 16]!)) + (Spec.mont_reduce_pure (lift_fe_int 0))) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (1353#i16 : Std.I16)) + (lift_fe_mont (1353#i16 : Std.I16))) + +/-- The expected result of `multiply_matrix_by_column_at` at step k: + an `Array FE 256` whose lane ℓ equals `col_loop_lane_at_step ... k ℓ`. -/ +noncomputable def col_loop_result_at_step + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (i : Nat) (k : Nat) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + ⟨(List.range 256).map (fun ℓ => col_loop_lane_at_step matrix_A s_as_ntt i k ℓ), + by simp [List.length_map, List.length_range]⟩ + +/-- The lane-ℓ equation of `col_loop_result_at_step`. -/ +theorem col_loop_result_at_step_val_lane + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (i : Nat) (k : Nat) (ℓ : Nat) (hℓ : ℓ < 256) : + (col_loop_result_at_step matrix_A s_as_ntt i k).val[ℓ]! + = col_loop_lane_at_step matrix_A s_as_ntt i k ℓ := by + unfold col_loop_result_at_step + show ((List.range 256).map (fun ℓ' => col_loop_lane_at_step matrix_A s_as_ntt i k ℓ'))[ℓ]! = _ + rw [getElem!_pos _ ℓ (by simp [List.length_map, List.length_range, hℓ])] + rw [List.getElem_map, List.getElem_range] + +/-- **Helper 3a.** Base-case lane equation: at step k=0, every lane of + `col_loop_result_at_step` equals `parameters.FieldElement.new 0`. -/ +theorem col_loop_lane_at_step_zero + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (i : Nat) (ℓ : Nat) : + col_loop_lane_at_step matrix_A s_as_ntt i 0 ℓ + = ({ val := 0#u16 } : hacspec_ml_kem.parameters.FieldElement) := by + unfold col_loop_lane_at_step + rw [List.range_zero, List.foldl_nil] + -- Now: mul_pure (mont_reduce_pure (lift_fe_int 0)) (mul_pure 1353 1353) = ⟨0#u16⟩. + -- Both sides have ZMod 3329 value 0, both are canonical. + have h_canon_lhs : libcrux_iot_ml_kem.Spec.Pure.Canonical + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Spec.mont_reduce_pure (lift_fe_int 0)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (1353#i16 : Std.I16)) (lift_fe_mont (1353#i16 : Std.I16)))) := + libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure _ _ + have h_canon_rhs : libcrux_iot_ml_kem.Spec.Pure.Canonical + ({ val := 0#u16 } : hacspec_ml_kem.parameters.FieldElement) := by + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] + decide + have h_canon_to_lt : ∀ x : hacspec_ml_kem.parameters.FieldElement, + libcrux_iot_ml_kem.Spec.Pure.Canonical x → x.val.val < 3329 := by + intro x hx + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at hx + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at hx; exact hx + have h_lt_lhs := h_canon_to_lt _ h_canon_lhs + have h_lt_rhs := h_canon_to_lt _ h_canon_rhs + rw [← feOfZMod_zmodOfFE_of_canonical _ h_lt_lhs, + ← feOfZMod_zmodOfFE_of_canonical _ h_lt_rhs] + congr 1 + rw [L2_8c.zmodOfFE_mul_pure] + unfold Spec.mont_reduce_pure + rw [zmodOfFE_feOfZMod] + unfold lift_fe_int + rw [zmodOfFE_feOfZMod] + -- LHS: 0 * 169 * 169 * (zmodOfFE ...) = 0; RHS: zmodOfFE ⟨0#u16⟩ = 0. + unfold zmodOfFE + simp + +set_option maxHeartbeats 4000000 in +/-- **Helper 3b.** Step lemma: at step k < K, taking one column iteration + transforms `col_loop_result_at_step ... k` into `col_loop_result_at_step ... (k+1)`. -/ +theorem col_loop_lane_at_step_succ + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (i : Nat) (k : Nat) (ℓ : Nat) (hℓ : ℓ < 256) : + col_loop_lane_at_step matrix_A s_as_ntt i (k + 1) ℓ + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (col_loop_lane_at_step matrix_A s_as_ntt i k ℓ) + (Spec.multiply_ntts_pure + (lift_poly matrix_A.val[i * K.val + k]!) (lift_poly s_as_ntt.val[k]!)).val[ℓ]! := by + unfold col_loop_lane_at_step + -- LHS: mul_pure (foldl_{k+1}) (mul_pure 1353 1353). + rw [List.range_succ, List.foldl_append, List.foldl_cons, List.foldl_nil] + -- Now LHS = mul_pure (add_pure (foldl_k) (no_acc_lane_at_k)) (mul_pure 1353 1353). + -- Distribute via h_distrib (essentially Helper 2's per-pair fact, inlined). + -- Specifically: mul_pure (add_pure x y) z = add_pure (mul_pure x z) (mul_pure y z). + have h_distrib : + ∀ a b c : hacspec_ml_kem.parameters.FieldElement, + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure a b) c + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a c) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure b c) := by + intro a b c + have h_canon_lhs : libcrux_iot_ml_kem.Spec.Pure.Canonical + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure a b) c) := + libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure _ _ + have h_canon_rhs : libcrux_iot_ml_kem.Spec.Pure.Canonical + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a c) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure b c)) := + libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure _ _ + have h_canon_to_lt : ∀ x : hacspec_ml_kem.parameters.FieldElement, + libcrux_iot_ml_kem.Spec.Pure.Canonical x → x.val.val < 3329 := by + intro x hx + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at hx + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at hx; exact hx + rw [← feOfZMod_zmodOfFE_of_canonical _ (h_canon_to_lt _ h_canon_lhs), + ← feOfZMod_zmodOfFE_of_canonical _ (h_canon_to_lt _ h_canon_rhs)] + congr 1 + simp only [L2_8c.zmodOfFE_add_pure, L2_8c.zmodOfFE_mul_pure] + ring + rw [h_distrib] + -- Now LHS = add_pure (mul_pure (foldl_k) (mul_pure 1353 1353)) (mul_pure (no_acc_lane_at_k) (mul_pure 1353 1353)) + -- The first summand is `col_loop_lane_at_step ... k ℓ` (after unfold). + -- The second summand should equal `(Spec.multiply_ntts_pure (lift_poly matrix_A.val[i*K+k]!) (lift_poly s_as_ntt.val[k]!)).val[ℓ]!` + -- via Helper 1. + congr 1 + -- Apply Helper 1. + rw [multiply_ntts_lane_eq_canonical_factor _ _ ℓ hℓ] + +/-- **Helper 3c.** Closing equation: `col_loop_lane_at_step ... K.val ℓ = mul_pure (canonical_row_sum_lane ...) (lift_fe_mont 1353)`. -/ +theorem col_loop_lane_at_step_K_eq_canonical + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (i : Nat) (ℓ : Nat) : + col_loop_lane_at_step matrix_A s_as_ntt i K.val ℓ + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt i (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16)) := by + unfold col_loop_lane_at_step Stage3MontStripFC.canonical_row_sum_lane + -- LHS: mul_pure (foldl ...) (mul_pure 1353 1353). + -- RHS: mul_pure (mul_pure (foldl ...) 1353) 1353. + -- Same foldl on both sides; associativity of mul_pure is the only step. + -- Use canonical round-trip + ring. + have h_canon_to_lt : ∀ x : hacspec_ml_kem.parameters.FieldElement, + libcrux_iot_ml_kem.Spec.Pure.Canonical x → x.val.val < 3329 := by + intro x hx + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at hx + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at hx; exact hx + -- Apply the canonical-rewrite via the result of mul_pure being canonical. + set foldl_sum := (List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (matrix_A.val[i * K.val + c]!.coefficients.val[ℓ / 16]!)) + (lift_chunk_mont (s_as_ntt.val[c]!.coefficients.val[ℓ / 16]!)) + (Spec.zeta_at (64 + 4 * (ℓ / 16))) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 1)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 2)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 3))).val[ℓ % 16]!)) + (Spec.mont_reduce_pure (lift_fe_int 0)) with h_fs_def + set mont1353 := lift_fe_mont (1353#i16 : Std.I16) + -- Both LHS and RHS are products of `mul_pure`, hence canonical. + have h_canon_lhs := libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure + foldl_sum + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure mont1353 mont1353) + have h_canon_rhs := libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure foldl_sum mont1353) + mont1353 + rw [← feOfZMod_zmodOfFE_of_canonical _ (h_canon_to_lt _ h_canon_lhs), + ← feOfZMod_zmodOfFE_of_canonical _ (h_canon_to_lt _ h_canon_rhs)] + apply congrArg + simp only [L2_8c.zmodOfFE_mul_pure] + ring + +set_option maxHeartbeats 16000000 in +set_option maxRecDepth 1000 in +theorem multiply_matrix_by_column_at_eq + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (hAlen : matrix_A.length = (K.val * K.val : Nat)) + (i : Std.Usize) (hi : i.val < K.val) : + hacspec_ml_kem.matrix.multiply_matrix_by_column_at + (lift_matrix_from_slice matrix_A K) (lift_vec s_as_ntt) i + = .ok ⟨(List.range 256).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt i.val + (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16))), + by simp [List.length_map, List.length_range]⟩ := by + -- Reduce the target: the .ok's payload coincides with `col_loop_result_at_step ... K.val` + -- after applying `col_loop_lane_at_step_K_eq_canonical`. + have h_target_eq : + (⟨(List.range 256).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt i.val + (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16))), + by simp [List.length_map, List.length_range]⟩ : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) = + col_loop_result_at_step matrix_A s_as_ntt i.val K.val := by + apply Subtype.ext + unfold col_loop_result_at_step + show _ = (List.range 256).map _ + apply List.map_congr_left + intro ℓ _ + rw [col_loop_lane_at_step_K_eq_canonical] + rw [h_target_eq] + -- Now we need: multiply_matrix_by_column_at lift_M lift_S i = .ok (col_loop_result_at_step ... K.val). + -- Use the loop equation directly. + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at + unfold hacspec_ml_kem.parameters.FieldElement.new + simp only [bind_tc_ok] + -- Now goal: multiply_matrix_by_column_at_loop ⟨0, K⟩ lift_M lift_S i (Array.repeat 256 ⟨0⟩) = .ok ... + -- Use loop_range_spec_usize with `inv k r = pure (r = col_loop_result_at_step ... k.val)`. + -- Step 1: get the Triple form, then extract the .ok form. + have h_triple : ⦃ ⌜ True ⌝ ⦄ + hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop + ({ start := 0#usize, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize) + (lift_matrix_from_slice matrix_A K) (lift_vec s_as_ntt) i + (Std.Array.repeat (256#usize : Std.Usize) + ({ val := 0#u16 } : hacspec_ml_kem.parameters.FieldElement)) + ⦃ ⇓ r => ⌜ r = col_loop_result_at_step matrix_A s_as_ntt i.val K.val ⌝ ⦄ := by + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun p : CoreModels.core.ops.range.Range Std.Usize × + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize => + hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body + (lift_matrix_from_slice matrix_A K) (lift_vec s_as_ntt) i p.1 p.2) + (β := Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (Std.Array.repeat (256#usize : Std.Usize) + ({ val := 0#u16 } : hacspec_ml_kem.parameters.FieldElement)) + 0#usize K + (fun k result => pure (result = col_loop_result_at_step matrix_A s_as_ntt i.val k.val)) + (Nat.zero_le _) + (by + -- Base: init = col_loop_result_at_step ... 0. + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + apply Subtype.ext + rw [Std.Array.repeat_val] + unfold col_loop_result_at_step + show List.replicate 256 _ = (List.range 256).map _ + apply List.ext_getElem + · rw [List.length_replicate, List.length_map, List.length_range] + intro n h_n_lhs _ + have h_n_lt : n < 256 := by + rw [List.length_replicate] at h_n_lhs; exact h_n_lhs + rw [List.getElem_replicate, List.getElem_map, List.getElem_range] + show _ = col_loop_lane_at_step matrix_A s_as_ntt i.val 0 n + rw [col_loop_lane_at_step_zero]) + ?_) + · -- Post entailment. + rw [PostCond.entails_noThrow] + intro r hh + have h_eq : (pure (r = col_loop_result_at_step matrix_A s_as_ntt i.val K.val) + : Result Prop).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_eq + · -- Step. + intro acc k h_ge h_le hinv + have h_acc_eq : acc = col_loop_result_at_step matrix_A s_as_ntt i.val k.val := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using hinv + subst h_acc_eq + -- Body: Range.next; if Some j (j < K), then index_usize + multiply_ntts + add_polynomials. + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body + by_cases h_lt : k.val < K.val + · -- `Some k` branch. + have h_iter_step : + ⦃ ⌜ True ⌝ ⦄ + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ ∃ s : Std.Usize, s.val = k.val + 1 ∧ + r = (some k, + ({ start := s, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_step + obtain ⟨s_iter, hs_iter_val, hv_iter_pair⟩ := hv_iter_post + -- Compute (lift_matrix_from_slice matrix_A K).val[k.val]! — the k-th column. + have h_lift_M_len : (lift_matrix_from_slice matrix_A K).length = K.val := + Std.Array.length_eq _ + set col_k : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K := + (lift_matrix_from_slice matrix_A K).val[k.val]! with h_col_k_def + have h_idx_col : Aeneas.Std.Array.index_usize (lift_matrix_from_slice matrix_A K) k + = .ok col_k := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq _ k + (by rw [h_lift_M_len]; exact h_lt) + -- The k-th column at row i is lift_poly matrix_A.val[i.val * K.val + k.val]!. + have h_col_k_val : col_k.val[i.val]! + = lift_poly matrix_A.val[i.val * K.val + k.val]! := by + rw [h_col_k_def] + unfold lift_matrix_from_slice + show ((List.range K.val).map (fun j' => + Std.Array.make K + ((List.range K.val).map (fun i' => + lift_poly matrix_A.val[i' * K.val + j']!)) + (by simp)))[k.val]!.val[i.val]! = _ + rw [getElem!_pos _ k.val (by simp [List.length_map, List.length_range]; exact h_lt)] + rw [List.getElem_map, List.getElem_range] + show ((List.range K.val).map (fun i' => + lift_poly matrix_A.val[i' * K.val + k.val]!))[i.val]! = _ + rw [getElem!_pos _ i.val (by simp [List.length_map, List.length_range]; exact hi)] + rw [List.getElem_map, List.getElem_range] + have h_col_k_len : col_k.length = K.val := Std.Array.length_eq _ + set a1 : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + lift_poly matrix_A.val[i.val * K.val + k.val]! with h_a1_def + have h_idx_a1 : Aeneas.Std.Array.index_usize col_k i = .ok a1 := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq col_k i + (by rw [h_col_k_len]; exact hi) + rw [h_col_k_val] at this; exact this + -- Compute (lift_vec s_as_ntt).val[k.val]! = lift_poly s_as_ntt.val[k.val]!. + have h_lift_S_len : (lift_vec s_as_ntt).length = K.val := Std.Array.length_eq _ + set a2 : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + lift_poly s_as_ntt.val[k.val]! with h_a2_def + have h_lift_S_val : (lift_vec s_as_ntt).val[k.val]! = a2 := by + rw [h_a2_def] + unfold lift_vec + show (s_as_ntt.val.map lift_poly)[k.val]! = _ + have h_len_s : s_as_ntt.val.length = K.val := Std.Array.length_eq _ + rw [getElem!_pos _ k.val (by rw [List.length_map, h_len_s]; exact h_lt)] + rw [List.getElem_map] + rw [show s_as_ntt.val[k.val] = s_as_ntt.val[k.val]! from + (getElem!_pos _ k.val (by rw [h_len_s]; exact h_lt)).symm] + have h_idx_a2 : Aeneas.Std.Array.index_usize (lift_vec s_as_ntt) k = .ok a2 := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq (lift_vec s_as_ntt) k + (by rw [h_lift_S_len]; exact h_lt) + rw [h_lift_S_val] at this; exact this + -- multiply_ntts a1 a2 = .ok (Spec.multiply_ntts_pure a1 a2). + have h_mult_eq : hacspec_ml_kem.ntt.multiply_ntts a1 a2 + = .ok (Spec.multiply_ntts_pure a1 a2) := by + unfold Spec.multiply_ntts_pure + rw [HelpersFC.multiply_ntts_eq_pure_array] + -- add_polynomials previous product = .ok new_acc. + have h_add_eq := Stage4MatrixAddFC.matrix_add_polynomials_eq_ok + (col_loop_result_at_step matrix_A s_as_ntt i.val k.val) + (Spec.multiply_ntts_pure a1 a2) + -- Show the new accumulator equals col_loop_result_at_step ... (k.val + 1). + set new_acc : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + ⟨(List.range 256).map (fun n => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (col_loop_result_at_step matrix_A s_as_ntt i.val k.val).val[n]! + (Spec.multiply_ntts_pure a1 a2).val[n]!), + by simp [List.length_map, List.length_range]⟩ with h_new_acc_def + have h_new_acc_eq : new_acc + = col_loop_result_at_step matrix_A s_as_ntt i.val (k.val + 1) := by + unfold col_loop_result_at_step + apply Subtype.ext + rw [h_new_acc_def] + apply List.map_congr_left + intro n hn_mem + have hn_lt : n < 256 := List.mem_range.mp hn_mem + rw [col_loop_result_at_step_val_lane _ _ _ _ _ hn_lt] + rw [col_loop_lane_at_step_succ _ _ _ _ _ hn_lt] + -- Body equation: drive the do-block to .ok (cont ...). + have h_body : + (fun p : CoreModels.core.ops.range.Range Std.Usize × + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize => + hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body + (lift_matrix_from_slice matrix_A K) (lift_vec s_as_ntt) i p.1 p.2) + ({ start := k, «end» := K }, + col_loop_result_at_step matrix_A s_as_ntt i.val k.val) + = .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + col_loop_result_at_step matrix_A s_as_ntt i.val (k.val + 1))) := by + show hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body + (lift_matrix_from_slice matrix_A K) (lift_vec s_as_ntt) i + { start := k, «end» := K } + (col_loop_result_at_step matrix_A s_as_ntt i.val k.val) = _ + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_pair] at hv_iter_eq + rw [hv_iter_eq] + simp only [Aeneas.Std.bind_tc_ok] + -- Force destructure (some k, iter') into the some j=k branch. + show ((do + let a ← Aeneas.Std.Array.index_usize + (lift_matrix_from_slice matrix_A K) k + let a1' ← Aeneas.Std.Array.index_usize a i + let a2' ← Aeneas.Std.Array.index_usize (lift_vec s_as_ntt) k + let product ← hacspec_ml_kem.ntt.multiply_ntts a1' a2' + let result1 ← hacspec_ml_kem.matrix.add_polynomials + (col_loop_result_at_step matrix_A s_as_ntt i.val k.val) product + Aeneas.Std.Result.ok (ControlFlow.cont + (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), result1))) + : Result _) = _ + rw [h_idx_col] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_a1] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_a2] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_mult_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_add_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [← h_new_acc_eq] + apply triple_of_ok_fc h_body + -- Discharge step_post for .cont branch. + refine ⟨h_lt, rfl, hs_iter_val, ?_⟩ + show (pure (col_loop_result_at_step matrix_A s_as_ntt i.val (k.val + 1) + = col_loop_result_at_step matrix_A s_as_ntt i.val s_iter.val) + : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + rw [hs_iter_val] + rfl + · -- `None` branch: k ≥ K, body returns .ok (.done result). + have hk_ge : k.val ≥ K.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = K.val := by omega + have h_iter_none : + ⦃ ⌜ True ⌝ ⦄ + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ r = ((none : Option Std.Usize), + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun hlt => absurd hlt (Nat.not_lt.mpr hk_ge)) + (fun _ => by dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_none + have h_body : + (fun p : CoreModels.core.ops.range.Range Std.Usize × + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize => + hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body + (lift_matrix_from_slice matrix_A K) (lift_vec s_as_ntt) i p.1 p.2) + ({ start := k, «end» := K }, + col_loop_result_at_step matrix_A s_as_ntt i.val k.val) + = .ok (ControlFlow.done + (col_loop_result_at_step matrix_A s_as_ntt i.val k.val)) := by + show hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body + (lift_matrix_from_slice matrix_A K) (lift_vec s_as_ntt) i + { start := k, «end» := K } + (col_loop_result_at_step matrix_A s_as_ntt i.val k.val) = _ + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_post] at hv_iter_eq + rw [hv_iter_eq] + rfl + apply triple_of_ok_fc h_body + show (pure (col_loop_result_at_step matrix_A s_as_ntt i.val k.val + = col_loop_result_at_step matrix_A s_as_ntt i.val K.val) + : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + rw [hk_eq] + rfl + -- Now extract the .ok form from h_triple. + match h_loop : hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop + ({ start := 0#usize, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize) + (lift_matrix_from_slice matrix_A K) (lift_vec s_as_ntt) i + (Std.Array.repeat (256#usize : Std.Usize) + ({ val := 0#u16 } : hacspec_ml_kem.parameters.FieldElement)), + h_triple with + | .ok r, h => + have hr : r = col_loop_result_at_step matrix_A s_as_ntt i.val K.val := by + simpa [Std.Do.Triple, Std.Do.WP.wp] using h + rw [hr] + | .fail _, h => exact absurd h (by simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div, h => exact absurd h (by simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply]) + +set_option maxHeartbeats 4000000 in +/-- **Helper 4.** Outer `createi K` wrapper around Helper 3 producing the + full `multiply_matrix_by_column` result. -/ +theorem multiply_matrix_by_column_eq + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (hAlen : matrix_A.length = (K.val * K.val : Nat)) : + hacspec_ml_kem.matrix.multiply_matrix_by_column + (lift_matrix_from_slice matrix_A K) (lift_vec s_as_ntt) + = .ok ⟨(List.range K.val).map (fun i => + (⟨(List.range 256).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt i + (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16))), + by simp [List.length_map, List.length_range]⟩ : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize)), + by simp [List.length_map, List.length_range]⟩ := by + set f : Nat → Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + fun i => ⟨(List.range 256).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt i + (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16))), + by simp [List.length_map, List.length_range]⟩ with hf_def + have hpure : ∀ k : Nat, k < K.val → + (hacspec_ml_kem.matrix.multiply_matrix_by_column.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256 + K).FnMutInst.call_mut + (lift_matrix_from_slice matrix_A K, lift_vec s_as_ntt) ⟨BitVec.ofNat _ k⟩ + = .ok (f k, (lift_matrix_from_slice matrix_A K, lift_vec s_as_ntt)) := by + intro k hk + show hacspec_ml_kem.matrix.multiply_matrix_by_column.closure.Insts.CoreOpsFunctionFnMutTupleUsizeArrayFieldElement256.call_mut + (lift_matrix_from_slice matrix_A K, lift_vec s_as_ntt) ⟨BitVec.ofNat _ k⟩ + = .ok (f k, (lift_matrix_from_slice matrix_A K, lift_vec s_as_ntt)) + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column.closure.Insts.CoreOpsFunctionFnMutTupleUsizeArrayFieldElement256.call_mut + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256.call + -- Reduce to: multiply_matrix_by_column_at lift_M lift_S ⟨k⟩. + have hk_val : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := by + show (BitVec.ofNat _ k).toNat = k + apply Nat.mod_eq_of_lt + have hK_lt : K.val < 2^System.Platform.numBits := by + have h := K.hBounds + simp [] at h + omega + exact Nat.lt_of_lt_of_le hk (Nat.le_of_lt hK_lt) + -- Need: multiply_matrix_by_column_at lift_M lift_S ⟨k⟩ = .ok (f k). + have hk_lt : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val < K.val := by + rw [hk_val]; exact hk + have h_mmbc_at := multiply_matrix_by_column_at_eq matrix_A s_as_ntt hAlen + (⟨BitVec.ofNat _ k⟩ : Std.Usize) hk_lt + show (do let a ← hacspec_ml_kem.matrix.multiply_matrix_by_column_at + (lift_matrix_from_slice matrix_A K) (lift_vec s_as_ntt) + (⟨BitVec.ofNat _ k⟩ : Std.Usize) + .ok (a, (lift_matrix_from_slice matrix_A K, lift_vec s_as_ntt))) = + .ok (f k, _) + rw [h_mmbc_at]; simp only [bind_tc_ok] + rw [hk_val] + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column + unfold hacspec_ml_kem.parameters.createi + have h_from_fn := libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq + (T := Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (F := hacspec_ml_kem.matrix.multiply_matrix_by_column.closure K) + (N := K) + (inst := (hacspec_ml_kem.matrix.multiply_matrix_by_column.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256 K).FnMutInst) + (c := (lift_matrix_from_slice matrix_A K, lift_vec s_as_ntt)) + (f := f) + hpure + show core.array.from_fn K _ _ = _ + exact h_from_fn + +set_option maxHeartbeats 32000000 in +/-- **Helper 5 (main bridge).** Given the per-row, per-lane characterization + of `t_as_ntt_final`, proves the hacspec `compute_As_plus_e` equation that + L7.1's POST demands. -/ +theorem hacspec_compute_As_plus_e_eq_of_lane_eq + {K : Std.Usize} + (matrix_A : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt error_as_ntt t_as_ntt_final : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (hAlen : matrix_A.length = (K.val * K.val : Nat)) + (h_lane_eq : ∀ r : Nat, r < K.val → ∀ ℓ : Nat, ℓ < 256 → + (lift_poly t_as_ntt_final.val[r]!).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt r (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16))) + ((lift_poly error_as_ntt.val[r]!).val[ℓ]!)) : + hacspec_ml_kem.matrix.compute_As_plus_e + (lift_matrix_from_slice matrix_A K) + (lift_vec s_as_ntt) (lift_vec error_as_ntt) + = .ok (lift_vec t_as_ntt_final) := by + unfold hacspec_ml_kem.matrix.compute_As_plus_e + -- Step 1: replace `multiply_matrix_by_column ... = .ok (P_arr)` via Helper 4. + rw [multiply_matrix_by_column_eq matrix_A s_as_ntt hAlen] + simp only [bind_tc_ok] + -- Step 2: unfold add_vectors and apply from_fn_pure_eq. + set P_arr : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K := + ⟨(List.range K.val).map (fun i => + (⟨(List.range 256).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt i + (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16))), + by simp [List.length_map, List.length_range]⟩ : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize)), + by simp [List.length_map, List.length_range]⟩ with hP_def + -- For lookup at row r < K.val. + have h_P_at : ∀ r : Nat, r < K.val → + P_arr.val[r]! = (⟨(List.range 256).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt r + (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16))), + by simp [List.length_map, List.length_range]⟩ : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) := by + intro r hr + rw [hP_def] + show ((List.range K.val).map _)[r]! = _ + rw [getElem!_pos _ r (by simp [List.length_map, List.length_range, hr])] + rw [List.getElem_map, List.getElem_range] + -- For lane lookup inside P_arr[r]. + have h_P_at_lane : ∀ r : Nat, r < K.val → ∀ ℓ : Nat, ℓ < 256 → + (P_arr.val[r]!).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt r + (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16)) := by + intro r hr ℓ hℓ + rw [h_P_at r hr] + show ((List.range 256).map (fun ℓ' => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt r + (ℓ' / 16) (ℓ' % 16)) + (lift_fe_mont (1353#i16 : Std.I16))))[ℓ]! = _ + rw [getElem!_pos _ ℓ (by simp [List.length_map, List.length_range, hℓ])] + rw [List.getElem_map, List.getElem_range] + -- Now: hacspec_ml_kem.matrix.add_vectors P_arr (lift_vec error_as_ntt) = .ok (lift_vec t_as_ntt_final). + unfold hacspec_ml_kem.matrix.add_vectors + unfold hacspec_ml_kem.parameters.createi + set f_out : Nat → Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + fun r => (lift_vec t_as_ntt_final).val[r]! with hf_def + -- Per-call_mut equation. + have hpure : ∀ r : Nat, r < K.val → + (hacspec_ml_kem.matrix.add_vectors.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256 K).FnMutInst.call_mut + (P_arr, lift_vec error_as_ntt) ⟨BitVec.ofNat _ r⟩ + = .ok (f_out r, (P_arr, lift_vec error_as_ntt)) := by + intro r hr + show hacspec_ml_kem.matrix.add_vectors.closure.Insts.CoreOpsFunctionFnMutTupleUsizeArrayFieldElement256.call_mut + (P_arr, lift_vec error_as_ntt) ⟨BitVec.ofNat _ r⟩ + = .ok (f_out r, (P_arr, lift_vec error_as_ntt)) + unfold hacspec_ml_kem.matrix.add_vectors.closure.Insts.CoreOpsFunctionFnMutTupleUsizeArrayFieldElement256.call_mut + unfold hacspec_ml_kem.matrix.add_vectors.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256.call + have hr_val : (⟨BitVec.ofNat _ r⟩ : Std.Usize).val = r := by + show (BitVec.ofNat _ r).toNat = r + apply Nat.mod_eq_of_lt + have hK_lt : K.val < 2^System.Platform.numBits := by + have h := K.hBounds + simp [] at h + omega + exact Nat.lt_of_lt_of_le hr (Nat.le_of_lt hK_lt) + have hP_len : P_arr.length = K.val := by + rw [hP_def] + show ((List.range K.val).map _).length = K.val + simp [List.length_map, List.length_range] + have hE_len : (lift_vec error_as_ntt).length = K.val := by + unfold lift_vec + show (error_as_ntt.val.map lift_poly).length = K.val + rw [List.length_map, error_as_ntt.property] + have hr_lt_P : (⟨BitVec.ofNat _ r⟩ : Std.Usize).val < P_arr.length := by + rw [hr_val, hP_len]; exact hr + have hr_lt_E : (⟨BitVec.ofNat _ r⟩ : Std.Usize).val < (lift_vec error_as_ntt).length := by + rw [hr_val, hE_len]; exact hr + have h_idx_P : Std.Array.index_usize P_arr (⟨BitVec.ofNat _ r⟩ : Std.Usize) + = .ok (P_arr.val[r]!) := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq P_arr + (⟨BitVec.ofNat _ r⟩ : Std.Usize) hr_lt_P + rw [hr_val] at this; exact this + have h_idx_E : Std.Array.index_usize (lift_vec error_as_ntt) (⟨BitVec.ofNat _ r⟩ : Std.Usize) + = .ok ((lift_vec error_as_ntt).val[r]!) := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq (lift_vec error_as_ntt) + (⟨BitVec.ofNat _ r⟩ : Std.Usize) hr_lt_E + rw [hr_val] at this; exact this + -- Apply matrix_add_polynomials_eq_ok. + have h_add := matrix_add_polynomials_eq_ok + (P_arr.val[r]!) ((lift_vec error_as_ntt).val[r]!) + -- Reduce: the closure destructures (a, a1) := (P_arr, lift_vec error_as_ntt). + -- Use `show` to expose the destructure outcome (a → P_arr, a1 → lift_vec error_as_ntt). + show (do let fe ← (do + let a2 ← Std.Array.index_usize P_arr (⟨BitVec.ofNat _ r⟩ : Std.Usize) + let a3 ← Std.Array.index_usize (lift_vec error_as_ntt) (⟨BitVec.ofNat _ r⟩ : Std.Usize) + hacspec_ml_kem.matrix.add_polynomials a2 a3) + .ok (fe, P_arr, lift_vec error_as_ntt)) = _ + rw [h_idx_P]; simp only [bind_tc_ok] + rw [h_idx_E]; simp only [bind_tc_ok] + rw [h_add]; simp only [bind_tc_ok] + -- Now need: ok (⟨...add_pure ⟩, P, E) = ok (f_out r, P, E). + -- Beta-reduce f_out r. + show Result.ok (⟨List.map _ (List.range 256), _⟩, P_arr, lift_vec error_as_ntt) = + Result.ok ((lift_vec t_as_ntt_final).val[r]!, P_arr, lift_vec error_as_ntt) + have h_lift_t_at : (lift_vec t_as_ntt_final).val[r]! = lift_poly t_as_ntt_final.val[r]! := by + unfold lift_vec + show (t_as_ntt_final.val.map lift_poly)[r]! = _ + rw [getElem!_pos _ r (by rw [List.length_map, t_as_ntt_final.property]; exact hr)] + rw [List.getElem_map] + congr 1 + rw [getElem!_pos _ r (by rw [t_as_ntt_final.property]; exact hr)] + rw [h_lift_t_at] + congr 1 + -- Now: (⟨...⟩, P, E) = (lift_poly t.val[r]!, P, E). Reduce the tuple. + refine Prod.mk.injEq _ _ _ _ |>.mpr ⟨?_, rfl⟩ + -- Now we need: ⟨(List.range 256).map _, _⟩ = lift_poly t_as_ntt_final.val[r]!. + -- Apply per-lane equality via h_lane_eq. + -- Goal: ⟨List.map (fun k => add_pure ...), _⟩ = lift_poly t_as_ntt_final.val[r]! + -- Both sides are Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize. + -- Use Subtype.ext + manual val-level rewrites via rfl-shape lemmas. + have h_coe_mk : ∀ (l : List hacspec_ml_kem.parameters.FieldElement) + (h : l.length = (256#usize : Std.Usize).val), + ((⟨l, h⟩ : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : List _) = l := + fun _ _ => rfl + have h_make : ∀ (l : List hacspec_ml_kem.parameters.FieldElement) + (h : l.length = (256#usize : Std.Usize).val), + ((Std.Array.make 256#usize l h : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : List _) = l := + fun _ _ => rfl + apply Subtype.ext + unfold lift_poly + rw [h_coe_mk, h_make] + apply List.map_congr_left + intro ℓ hℓ_mem + have hℓ : ℓ < 256 := List.mem_range.mp hℓ_mem + have h_lane := h_lane_eq r hr ℓ hℓ + rw [h_P_at_lane r hr ℓ hℓ] + have h_lift_e_at : (lift_vec error_as_ntt).val[r]! = lift_poly error_as_ntt.val[r]! := by + unfold lift_vec + show (error_as_ntt.val.map lift_poly)[r]! = _ + rw [getElem!_pos _ r (by rw [List.length_map, error_as_ntt.property]; exact hr)] + rw [List.getElem_map] + congr 1 + rw [getElem!_pos _ r (by rw [error_as_ntt.property]; exact hr)] + rw [h_lift_e_at] + rw [← h_lane] + -- Goal: (lift_poly t_as_ntt_final.val[r]!).val[ℓ]! = lift_fe (...). + -- After unfolding lift_poly the LHS reduces to ((List.range 256).map ...)[ℓ]! + -- which rewrites to lift_fe (...) via getElem!_pos + getElem_map + getElem_range. + unfold lift_poly + show (((List.range 256).map + (fun j => lift_fe + ((t_as_ntt_final.val[r]!.coefficients.val[j / 16]!).elements.val[j % 16]!)))[ℓ]! + : hacspec_ml_kem.parameters.FieldElement) = + lift_fe ((t_as_ntt_final.val[r]!.coefficients.val[ℓ / 16]!).elements.val[ℓ % 16]!) + rw [getElem!_pos _ ℓ (by simp [List.length_map, List.length_range, hℓ])] + rw [List.getElem_map, List.getElem_range] + have h_from_fn := libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq + (T := Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (F := hacspec_ml_kem.matrix.add_vectors.closure K) + (N := K) + (inst := (hacspec_ml_kem.matrix.add_vectors.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256 K).FnMutInst) + (c := (P_arr, lift_vec error_as_ntt)) + (f := f_out) + hpure + show core.array.from_fn K _ _ = _ + rw [h_from_fn] + -- Now need: .ok ⟨(List.range K.val).map f_out, _⟩ = .ok (lift_vec t_as_ntt_final). + -- f_out r = (lift_vec t_as_ntt_final).val[r]!. So the LHS is the + -- array whose `.val[r] = (lift_vec t_as_ntt_final).val[r]!` for r < K.val. + congr 1 + apply Subtype.ext + unfold lift_vec + change (List.range K.val).map f_out = t_as_ntt_final.val.map lift_poly + apply List.ext_getElem + · simp [List.length_map, List.length_range, t_as_ntt_final.property] + intro n h_n_lhs _ + have h_n_lt : n < K.val := by + rw [List.length_map, List.length_range] at h_n_lhs; exact h_n_lhs + rw [List.getElem_map, List.getElem_range] + rw [hf_def] + -- f_out n = (lift_vec t_as_ntt_final).val[n]! = lift_poly t_as_ntt_final.val[n]!. + unfold lift_vec + show (t_as_ntt_final.val.map lift_poly)[n]! = (t_as_ntt_final.val.map lift_poly)[n] + rw [getElem!_pos _ n (by rw [List.length_map, t_as_ntt_final.property]; exact h_n_lt)] + +end Stage4MatrixAddFC + +/-! ## §L7 — matrix-level targets (4 theorems). + + These are the ultimate FC obligations: the impl matrix functions + must compute the same hacspec ring-element vector / single ring + element as their spec counterparts. -/ + +/-- L7.1 — `matrix.compute_As_plus_e`: product `A · s + e` of the + public-key generation step. Impl returns + `(t_as_ntt, s_cache, accumulator)`; project on `t_as_ntt`. + + PRE: + - `hAlen` : flat slice has K·K entries. + - `hK` : `K.val ≤ 4` (ML-KEM 768/1024 etc.; drives `K · 2^25 ≤ 2^27` + bound for `poly_reducing_from_i32_array_fc`). + - `h_matrix_bnd` : per-lane bound on `matrix_A`'s entries + (consumed by L6.3c `accumulating_ntt_multiply_*_poly_fc`). + - `h_s_bnd` : per-lane bound on `s_as_ntt`'s entries. + - `h_error_bnd` : per-lane bound on `error_as_ntt`'s entries + (the 29439 = 9 · 3271 ceiling required by L6.5 + `add_standard_error_reduce_fc`). + - `h_acc_bnd` : per-lane additive-budget bound on initial + `accumulator` (consumed by row-0 forward dep `Stage 1`, whose PRE + requires `acc[n] + K · 2^25 ≤ 2^30`; rows 1..K-1 re-zero the + accumulator inside `compute_As_plus_e_loop1`). -/ +@[spec] +theorem compute_As_plus_e_fc + {K : Std.Usize} + (t_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (matrix_A : Slice + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (s_as_ntt error_as_ntt s_cache : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (accumulator : Std.Array Std.I32 256#usize) + (hAlen : matrix_A.length = (K.val * K.val : Nat)) + (hK : K.val ≤ 4) + (h_matrix_bnd : ∀ k : Fin matrix_A.length, ∀ i j : Fin 16, + ((matrix_A.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_s_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((s_as_ntt.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_error_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((error_as_ntt.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 29439) + (h_acc_bnd : ∀ n : Fin 256, + (accumulator.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30) + (h_acc_zero : ∀ n : Nat, n < 256 → accumulator.val[n]! = (0#i32 : Std.I32)) + (hK_pos : 0 < K.val) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_As_plus_e + (vectortraitsOperationsInst := portable_ops_inst) + t_as_ntt matrix_A s_as_ntt error_as_ntt s_cache accumulator + ⦃ ⇓ p => ⌜ hacspec_ml_kem.matrix.compute_As_plus_e + (lift_matrix_from_slice matrix_A K) + (lift_vec s_as_ntt) (lift_vec error_as_ntt) + = .ok (lift_vec p.1) ⌝ ⦄ := by + -- Length facts. + have h_t_len : t_as_ntt.length = K.val := Std.Array.length_eq t_as_ntt + have h_err_len : error_as_ntt.length = K.val := Std.Array.length_eq error_as_ntt + -- Re-shape PRE bounds for sub-lemmas (Fin 16 ↔ Nat). + have h_s_bnd' : ∀ k : Fin K.val, ∀ a b : Fin 16, + ((s_as_ntt.val[k.val]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328 := + h_s_bnd + have h_error_bnd' : ∀ k : Fin K.val, ∀ a b : Fin 16, + ((error_as_ntt.val[k.val]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs + ≤ 29439 := h_error_bnd + -- ── S1: row-0 column loop (loop0). acc_init = input `accumulator` (zero). ── + obtain ⟨⟨cache1, acc2⟩, h_loop0_eq, h_row0⟩ := triple_exists_ok_fc + (compute_As_plus_e_loop0_fc matrix_A s_as_ntt s_cache accumulator hAlen + h_matrix_bnd h_s_bnd h_acc_bnd) + dsimp only at h_loop0_eq h_row0 + -- Destructure row0_inv: (1) lane, (2) acc bound, (3) cache populated, (4) cache unchanged. + obtain ⟨_h_row0_lane, h_acc2_bnd_raw, h_cache_done, _h_cache_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + Stage1FillCacheFC.row0_inv, ← List.getElem!_eq_getElem?_getD] using h_row0 + -- Cache-post bridge for loop1: row0_inv conjunct (3) at k = K. + have h_cache_post : ∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post (s_as_ntt.val[c]!) (cache1.val[c]!) := by + intro c hc; exact h_cache_done c hc + -- acc2 lane bound: ≤ 2^16·3328 (from acc2[n] ≤ acc_init[n] + K·2^25, acc_init = 0). + have h_acc2_lane_bnd : ∀ n : Nat, n < 256 → + (acc2.val[n]!).val.natAbs ≤ 2^16 * 3328 := by + intro n hn + have hb := h_acc2_bnd_raw n hn + have hz : (accumulator.val[n]!).val.natAbs = 0 := by + rw [h_acc_zero n hn]; rfl + rw [hz] at hb + have hK4 : K.val * 2^25 ≤ 4 * 2^25 := Nat.mul_le_mul_right _ hK + have h2 : (4 : Nat) * 2^25 ≤ 2^16 * 3328 := by decide + omega + -- ── Row-0 finalize: produces `a` with row-0 lane eq + rows>0 unchanged. ── + obtain ⟨a, h_fin_eq, h_a0_lane, h_a_unch⟩ := triple_exists_ok_fc + (compute_As_plus_e_row0_finalize_fc t_as_ntt matrix_A s_as_ntt error_as_ntt s_cache + accumulator acc2 cache1 hK_pos h_error_bnd h_acc_zero h_acc2_lane_bnd h_row0) + dsimp only at h_fin_eq h_a0_lane h_a_unch + -- ── S2: outer rows loop [1, K). t_as_ntt_init = a. acc seed = acc2 (loop1 re-zeros). ── + obtain ⟨⟨t_as_ntt2, accumulator2⟩, h_loop1_eq, h_rows⟩ := triple_exists_ok_fc + (compute_As_plus_e_loop1_fc a matrix_A s_as_ntt error_as_ntt cache1 acc2 1#usize hK + (by show (1#usize : Std.Usize).val ≤ K.val; exact hK_pos) hAlen + h_matrix_bnd h_s_bnd h_error_bnd h_cache_post) + dsimp only at h_loop1_eq h_rows + -- Destructure rows_inv: (1) done rows [1,K), (2) unchanged rows. + obtain ⟨h_rows_done, h_rows_undone⟩ := by + simpa [Stage3MontStripFC.rows_inv, Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + ← List.getElem!_eq_getElem?_getD] using h_rows + -- t_as_ntt2[0] = a[0] (loop1 starts at 1, leaves row 0 unchanged). + have h_t2_at0 : t_as_ntt2.val[0]! = a.val[0]! := by + exact h_rows_undone 0 (by omega) (Or.inl (by decide)) + -- ── Per-row, per-lane characterization of t_as_ntt2. ── + have h_lane_eq : ∀ r : Nat, r < K.val → ∀ ℓ : Nat, ℓ < 256 → + (lift_poly t_as_ntt2.val[r]!).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (Stage3MontStripFC.canonical_row_sum_lane matrix_A s_as_ntt r (ℓ / 16) (ℓ % 16)) + (lift_fe_mont (1353#i16 : Std.I16))) + ((lift_poly error_as_ntt.val[r]!).val[ℓ]!) := by + intro r hr ℓ hℓ + by_cases h0 : r = 0 + · subst h0 + rw [h_t2_at0] + exact h_a0_lane ℓ hℓ + · have hr1 : (1#usize : Std.Usize).val ≤ r := by + have : (1#usize : Std.Usize).val = 1 := rfl + rw [this]; omega + exact h_rows_done r hr1 hr ℓ hℓ + -- ── PART A: the hacspec equation. ── + have h_hacspec := Stage4MatrixAddFC.hacspec_compute_As_plus_e_eq_of_lane_eq + matrix_A s_as_ntt error_as_ntt t_as_ntt2 hAlen h_lane_eq + -- ── Package: reduce the impl do-block to .ok (t_as_ntt2, cache1, accumulator2). ── + apply triple_of_ok_fc (v := (t_as_ntt2, cache1, accumulator2)) + · unfold libcrux_iot_ml_kem.matrix.compute_As_plus_e + rw [h_loop0_eq]; simp only [Aeneas.Std.bind_tc_ok] + show (do + let s ← Aeneas.Std.lift (Aeneas.Std.Array.to_slice acc2) + let (pre, index_mut_back) ← Aeneas.Std.Array.index_mut_usize t_as_ntt 0#usize + let pre1 ← libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst s pre + let (pre2, index_mut_back1) ← Aeneas.Std.Array.index_mut_usize (index_mut_back pre1) 0#usize + let pre3 ← Aeneas.Std.Array.index_usize error_as_ntt 0#usize + let pre4 ← libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + portable_ops_inst pre2 pre3 + let (t2', accumulator2') ← libcrux_iot_ml_kem.matrix.compute_As_plus_e_loop1 + portable_ops_inst { start := 1#usize, «end» := K } (index_mut_back1 pre4) matrix_A + s_as_ntt error_as_ntt cache1 acc2 + Aeneas.Std.Result.ok (t2', cache1, accumulator2')) + = Aeneas.Std.Result.ok (t_as_ntt2, cache1, accumulator2) + -- Step through binds: invert h_fin_eq step by step to extract per-step equations. + simp only [Aeneas.Std.lift, Aeneas.Std.bind_tc_ok] at h_fin_eq + -- Step 0: index_mut_usize t_as_ntt 0 + cases h0 : Aeneas.Std.Array.index_mut_usize t_as_ntt (0#usize : Std.Usize) with + | fail e => rw [h0] at h_fin_eq; simp at h_fin_eq + | div => rw [h0] at h_fin_eq; simp at h_fin_eq + | ok v0 => + obtain ⟨pre0, imb0⟩ := v0 + simp only [h0, Aeneas.Std.bind_tc_ok] at h_fin_eq + -- Reduce the let-pair destructure (let (a, b) := (x, y)) to concrete form: + change (do + let pre1 ← libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst (Aeneas.Std.Array.to_slice acc2) pre0 + let (pre2', index_mut_back1) ← (imb0 pre1).index_mut_usize 0#usize + let pre3' ← Aeneas.Std.Array.index_usize error_as_ntt 0#usize + let pre4' ← libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + portable_ops_inst pre2' pre3' + Aeneas.Std.Result.ok (index_mut_back1 pre4')) = Aeneas.Std.Result.ok a at h_fin_eq + -- Step 1: reducing_from_i32_array + cases h1 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst (Aeneas.Std.Array.to_slice acc2) pre0 with + | fail e => rw [h1] at h_fin_eq; simp at h_fin_eq + | div => rw [h1] at h_fin_eq; simp at h_fin_eq + | ok t1 => + simp only [h1, Aeneas.Std.bind_tc_ok] at h_fin_eq + -- Step 2: index_mut_usize (imb0 t1) 0 + cases h2 : Aeneas.Std.Array.index_mut_usize (imb0 t1) (0#usize : Std.Usize) with + | fail e => rw [h2] at h_fin_eq; simp at h_fin_eq + | div => rw [h2] at h_fin_eq; simp at h_fin_eq + | ok v2 => + obtain ⟨pre2, imb1⟩ := v2 + simp only [h2, Aeneas.Std.bind_tc_ok] at h_fin_eq + -- Reduce the let-pair destructure for step 2: + change (do + let pre3' ← Aeneas.Std.Array.index_usize error_as_ntt 0#usize + let pre4' ← libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + portable_ops_inst pre2 pre3' + Aeneas.Std.Result.ok (imb1 pre4')) = Aeneas.Std.Result.ok a at h_fin_eq + -- Step 3: index_usize error_as_ntt 0 + cases h3 : Aeneas.Std.Array.index_usize error_as_ntt (0#usize : Std.Usize) with + | fail e => rw [h3] at h_fin_eq; simp at h_fin_eq + | div => rw [h3] at h_fin_eq; simp at h_fin_eq + | ok pre3 => + simp only [h3, Aeneas.Std.bind_tc_ok] at h_fin_eq + -- Step 4: add_standard_error_reduce + cases h4 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + portable_ops_inst pre2 pre3 with + | fail e => rw [h4] at h_fin_eq; simp at h_fin_eq + | div => rw [h4] at h_fin_eq; simp at h_fin_eq + | ok pre4 => + simp only [h4, Aeneas.Std.bind_tc_ok] at h_fin_eq + -- h_fin_eq : .ok (imb1 pre4) = .ok a → imb1 pre4 = a + have h_a_eq : imb1 pre4 = a := Aeneas.Std.Result.ok.inj h_fin_eq + -- Step through the goal using the same step equations: + simp [Aeneas.Std.lift, Aeneas.Std.bind_tc_ok, h1, h2, h4, h_a_eq, + h_loop1_eq] + · -- POST = the hacspec equation. p.1 = t_as_ntt2. + show hacspec_ml_kem.matrix.compute_As_plus_e + (lift_matrix_from_slice matrix_A K) + (lift_vec s_as_ntt) (lift_vec error_as_ntt) + = .ok (lift_vec t_as_ntt2) + exact h_hacspec + +/-- +info: 'libcrux_iot_ml_kem.Matrix.ComputeAsPlusE.compute_As_plus_e_fc' depends on axioms: [propext, + Classical.choice, + Quot.sound] +-/ +#guard_msgs in +#print axioms compute_As_plus_e_fc + +/- L7.2 — `matrix.compute_vector_u`: product `Aᵀ · r + e₁` of the + encryption step. Proven as + `libcrux_iot_ml_kem.Matrix.ComputeVectorU.FC.compute_vector_u_fc` in + `Matrix/ComputeVectorU/FC.lean`, axiom-clean modulo the + sanctioned `sample_matrix_entry_fc` / `Spec.sample_matrix_A_pure` + boundary. The proof lives downstream because the `L7/` bridge tree + imports `FCTargets`. + + POST: `hacspec_ml_kem.matrix.compute_vector_u (lift_matrix_from_seed seed K) + (lift_vec_slice r_as_ntt K) (lift_vec_slice error_1 K) + = .ok (lift_vec_slice p.2.1 K)`. -/ + +/- L7.3 — `matrix.compute_ring_element_v`: `t · r + e₂ + message` (the + decryption-side ring element `v`). Proven as + `libcrux_iot_ml_kem.Matrix.ComputeRingElementV.FC.compute_ring_element_v_fc` in + `Matrix/ComputeRingElementV/FC.lean` (PRE bounds `hK ≤ 4` / + per-lane `≤ 3328` on `r_as_ntt`/`error_2`/`message` + the + `accumulating_ntt_multiply_poly_cache_post` cache precondition + + zero accumulator). The proof lives downstream because the L7.3 + bridge tree imports `FCTargets`. Axiom-clean modulo the sanctioned + `deserialize_to_reduced_ring_element_fc` (A2) / + `Spec.t_as_ntt_from_public_key_pure` spec-stub boundary. + + POST: `hacspec_ml_kem.matrix.compute_ring_element_v + (lift_t_as_ntt_from_public_key public_key K) (lift_vec_slice r_as_ntt K) + (lift_poly error_2) (lift_poly message) = .ok (lift_poly p.2.1)`. -/ + +/- L7.4 — `matrix.compute_message`: `v - secret · u` then NTT-inverse. + Proven (with explicit PRE bounds `hK ≤ 4` + per-lane `≤ 3328`) as + `libcrux_iot_ml_kem.Matrix.ComputeMessage.FC.compute_message_fc` in + `Matrix/ComputeMessage/FC.lean`. The proof lives downstream + because the L7 bridge tree imports `FCTargets`. + + POST: `hacspec_ml_kem.matrix.compute_message (lift_poly v) (lift_vec secret_as_ntt) + (lift_vec u_as_ntt) = .ok (lift_poly p.1)`. -/ + +/-! ## Roll-up + + Theorems written by layer: + §L0 — 4 + §L1 — 10 + §L2 — 5 + §L3 — 4 (four PortableVector-specialised) + §L6 — 6 (L6.1, L6.2, L6.4, L6.5, L6.6, L6.7) + §L2.8 — 1 (NTT-multiply vector base case, scaffold) + §L6.3 — 1 (NTT-multiply polynomial wrapper, scaffold) + §L7 — 4 + + Total theorems: 35. + Open sorries: 6 proof-level (= 2 prior def stubs + + 4 L7 theorem bodies). scaffolds (L2.8, L6.3, helpers) all + closed at HEAD; the 4 L7 Triples remain open. +-/ + + +end libcrux_iot_ml_kem.Matrix.ComputeAsPlusE \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeMessage/Bridges.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeMessage/Bridges.lean new file mode 100644 index 00000000..1a41e1ee --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeMessage/Bridges.lean @@ -0,0 +1,409 @@ +/- + # `Matrix/ComputeMessage/Bridges.lean` — L7.4 bridge foundation. + + ZMod-domain bridge lemmas for the L7.4 `compute_message` decomposition. +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeAsPlusE +import LibcruxIotMlKem.Matrix.Common + +namespace libcrux_iot_ml_kem.Matrix.ComputeMessage.Bridges +open libcrux_iot_ml_kem.Matrix.Common +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeAsPlusE libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt + +/-! ## `zmodOfFE` distribution helpers (public re-derivations). + + FCTargets' `L2_8c.zmodOfFE_{mul,add}_pure` and `Common.lean`'s + `Impl.zmodOfFE_mul_pure` are `private`. We re-expose them publicly. The + `*_val_eq` lemmas re-derive the impl's `% 3329` value equation; `mul`/`add` + are unconditional, `sub` requires canonical inputs. -/ + +/-- Local copy of `Spec.Pure.uscalar_rem_ok_U32` (private there). -/ +private theorem uscalar_rem_ok_U32 (z m : Std.U32) (hm : m.val ≠ 0) : + ∃ w : Std.U32, (z % m : Result Std.U32) = .ok w ∧ w.val = z.val % m.val := by + have heq : (z % m : Result Std.U32) = Std.UScalar.rem z m := rfl + unfold Std.UScalar.rem at heq + simp [hm] at heq + refine ⟨_, heq, ?_⟩ + show (BitVec.umod z.bv m.bv).toNat = z.val % m.val + unfold BitVec.umod + simp only [BitVec.toNat_ofNatLT] + rfl + +private theorem mul_pure_val_eq + (a b : hacspec_ml_kem.parameters.FieldElement) : + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a b).val.val + = (a.val.val * b.val.val) % 3329 := by + have hmul : + hacspec_ml_kem.parameters.FieldElement.mul a b + = .ok (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a b) := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_eq_ok a b + unfold hacspec_ml_kem.parameters.FieldElement.mul at hmul + simp only [Aeneas.Std.lift, Aeneas.Std.bind_tc_ok] at hmul + have hA := a.val.hBounds; have hB := b.val.hBounds + simp [Aeneas.Std.UScalarTy.numBits] at hA hB + set x : Std.U32 := Std.UScalar.cast .U32 a.val + set y : Std.U32 := Std.UScalar.cast .U32 b.val + have hxval : x.val = a.val.val := Std.U16.cast_U32_val_eq a.val + have hyval : y.val = b.val.val := Std.U16.cast_U32_val_eq b.val + have hae := Std.UScalar.mul_equiv x y + have heqmul : (x * y : Result Std.U32) = Std.UScalar.mul x y := rfl + cases hxy : (x * y : Result Std.U32) with + | ok z => + rw [hxy] at hmul + rw [heqmul] at hxy; rw [hxy] at hae; simp at hae + obtain ⟨_, hzval, _⟩ := hae + simp only [Aeneas.Std.bind_tc_ok] at hmul + have hmod_val : + (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS).val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; simp + have hmod_ne : + (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS).val ≠ 0 := by + rw [hmod_val]; decide + set m : Std.U32 := Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS + obtain ⟨w, hw_eq, hwval⟩ := uscalar_rem_ok_U32 z m hmod_ne + rw [hw_eq] at hmul; simp only [Aeneas.Std.bind_tc_ok] at hmul + unfold hacspec_ml_kem.parameters.FieldElement.new at hmul + simp at hmul + have hwbnd : w.val < 3329 := by + rw [hwval, hmod_val]; exact Nat.mod_lt _ (by decide) + have hwcast : (Std.UScalar.cast .U16 w).val = w.val := by + apply Std.UScalar.cast_val_mod_pow_of_inBounds_eq + simp [Aeneas.Std.UScalarTy.numBits]; omega + rw [← hmul] + show (Std.UScalar.cast .U16 w).val = (a.val.val * b.val.val) % 3329 + rw [hwcast, hwval, hmod_val, hzval, hxval, hyval] + | fail _ => + rw [heqmul] at hxy; rw [hxy] at hae + simp only [Std.UScalar.max, Aeneas.Std.UScalarTy.numBits] at hae + rw [hxval, hyval] at hae + have : a.val.val * b.val.val < 2^32 := by + have h1 : a.val.val * b.val.val ≤ (2^16 - 1) * (2^16 - 1) := by + apply Nat.mul_le_mul <;> omega + have heq : (2^16 - 1) * (2^16 - 1) = 2^32 - 2*2^16 + 1 := by decide + omega + omega + | div => rw [heqmul] at hxy; rw [hxy] at hae; exact hae.elim + +/-- `zmodOfFE` distributes over `mul_pure` (public). -/ +theorem zmodOfFE_mul_pure + (a b : hacspec_ml_kem.parameters.FieldElement) : + zmodOfFE (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a b) + = zmodOfFE a * zmodOfFE b := by + unfold zmodOfFE + rw [mul_pure_val_eq, ZMod.natCast_mod]; push_cast; rfl + +private theorem add_pure_val_eq + (a b : hacspec_ml_kem.parameters.FieldElement) : + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure a b).val.val + = (a.val.val + b.val.val) % 3329 := by + have hadd : + hacspec_ml_kem.parameters.FieldElement.add a b + = .ok (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure a b) := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_eq_ok a b + unfold hacspec_ml_kem.parameters.FieldElement.add at hadd + simp only [Aeneas.Std.lift, Aeneas.Std.bind_tc_ok] at hadd + have hA := a.val.hBounds; have hB := b.val.hBounds + simp [Aeneas.Std.UScalarTy.numBits] at hA hB + set x : Std.U32 := Std.UScalar.cast .U32 a.val + set y : Std.U32 := Std.UScalar.cast .U32 b.val + have hxval : x.val = a.val.val := Std.U16.cast_U32_val_eq a.val + have hyval : y.val = b.val.val := Std.U16.cast_U32_val_eq b.val + have hae := Std.UScalar.add_equiv x y + cases hxy : (x + y) with + | ok z => + rw [hxy] at hae hadd; simp at hae + obtain ⟨_, hzval, _⟩ := hae + simp only [Aeneas.Std.bind_tc_ok] at hadd + have hmod_val : + (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS).val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; simp + have hmod_ne : + (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS).val ≠ 0 := by + rw [hmod_val]; decide + set m : Std.U32 := Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS + obtain ⟨w, hw_eq, hwval⟩ := uscalar_rem_ok_U32 z m hmod_ne + rw [hw_eq] at hadd; simp only [Aeneas.Std.bind_tc_ok] at hadd + unfold hacspec_ml_kem.parameters.FieldElement.new at hadd + simp at hadd + have hwbnd : w.val < 3329 := by + rw [hwval, hmod_val]; exact Nat.mod_lt _ (by decide) + have hwcast : (Std.UScalar.cast .U16 w).val = w.val := by + apply Std.UScalar.cast_val_mod_pow_of_inBounds_eq + simp [Aeneas.Std.UScalarTy.numBits]; omega + rw [← hadd] + show (Std.UScalar.cast .U16 w).val = (a.val.val + b.val.val) % 3329 + rw [hwcast, hwval, hmod_val, hzval, hxval, hyval] + | fail e => + rw [hxy] at hae; simp [Std.UScalar.inBounds] at hae + rw [hxval, hyval] at hae; omega + | div => rw [hxy] at hae; exact hae.elim + +/-- `zmodOfFE` distributes over `add_pure` (public). -/ +theorem zmodOfFE_add_pure + (a b : hacspec_ml_kem.parameters.FieldElement) : + zmodOfFE (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure a b) + = zmodOfFE a + zmodOfFE b := by + unfold zmodOfFE + rw [add_pure_val_eq, ZMod.natCast_mod]; push_cast; rfl + +private theorem sub_pure_val_eq + (a b : hacspec_ml_kem.parameters.FieldElement) + (ha : libcrux_iot_ml_kem.Spec.Pure.Canonical a) + (hb : libcrux_iot_ml_kem.Spec.Pure.Canonical b) : + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure a b).val.val + = (a.val.val + 3329 - b.val.val) % 3329 := by + have hsub : + hacspec_ml_kem.parameters.FieldElement.sub a b + = .ok (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure a b) := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_eq_ok a b ha hb + have ha' : a.val.val < 3329 := by + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at ha + unfold hacspec_ml_kem.parameters.FIELD_MODULUS at ha; simpa using ha + have hb' : b.val.val < 3329 := by + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at hb + unfold hacspec_ml_kem.parameters.FIELD_MODULUS at hb; simpa using hb + unfold hacspec_ml_kem.parameters.FieldElement.sub at hsub + simp only [Aeneas.Std.lift, Aeneas.Std.bind_tc_ok] at hsub + have hA := a.val.hBounds; have hB := b.val.hBounds + simp [Aeneas.Std.UScalarTy.numBits] at hA hB + set x : Std.U32 := Std.UScalar.cast .U32 a.val + set y : Std.U32 := Std.UScalar.cast .U32 b.val + set q : Std.U32 := Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS + have hxval : x.val = a.val.val := Std.U16.cast_U32_val_eq a.val + have hyval : y.val = b.val.val := Std.U16.cast_U32_val_eq b.val + have hqval : q.val = 3329 := by + show (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS).val = 3329 + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; simp + have hae := Std.UScalar.add_equiv x q + cases hxq : (x + q : Result Std.U32) with + | ok s => + rw [hxq] at hae hsub; simp at hae + obtain ⟨_, hsval, _⟩ := hae + simp only [Aeneas.Std.bind_tc_ok] at hsub + have hae2 := Std.UScalar.sub_equiv s y + cases hsy : (s - y : Result Std.U32) with + | ok u => + rw [hsy] at hae2 hsub; simp at hae2 + obtain ⟨_hyle, hsuy, _⟩ := hae2 + simp only [Aeneas.Std.bind_tc_ok] at hsub + have hq_ne : q.val ≠ 0 := by rw [hqval]; decide + obtain ⟨w, hw_eq, hwval⟩ := uscalar_rem_ok_U32 u q hq_ne + rw [hw_eq] at hsub; simp only [Aeneas.Std.bind_tc_ok] at hsub + unfold hacspec_ml_kem.parameters.FieldElement.new at hsub + simp at hsub + have hwbnd : w.val < 3329 := by + rw [hwval, hqval]; exact Nat.mod_lt _ (by decide) + have hwcast : (Std.UScalar.cast .U16 w).val = w.val := by + apply Std.UScalar.cast_val_mod_pow_of_inBounds_eq + simp [Aeneas.Std.UScalarTy.numBits]; omega + rw [← hsub] + show (Std.UScalar.cast .U16 w).val = (a.val.val + 3329 - b.val.val) % 3329 + rw [hwcast, hwval, hqval] + have hu_eq : u.val = a.val.val + 3329 - b.val.val := by + have h1 : s.val = u.val + y.val := hsuy + rw [hsval, hxval, hqval, hyval] at h1 + omega + rw [hu_eq] + | fail e => + rw [hsy] at hae2; simp at hae2 + rw [hsval, hxval, hqval, hyval] at hae2 + omega + | div => rw [hsy] at hae2; exact hae2.elim + | fail e => + rw [hxq] at hae; simp [Std.UScalar.inBounds] at hae + rw [hxval, hqval] at hae + omega + | div => rw [hxq] at hae; exact hae.elim + +/-- `zmodOfFE` distributes over `sub_pure` (public; requires canonical inputs, + which all `*_pure` outputs and `lift_fe`/`lift_fe_mont`/`feOfZMod`-built + lanes satisfy). -/ +theorem zmodOfFE_sub_pure + (a b : hacspec_ml_kem.parameters.FieldElement) + (ha : libcrux_iot_ml_kem.Spec.Pure.Canonical a) + (hb : libcrux_iot_ml_kem.Spec.Pure.Canonical b) : + zmodOfFE (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure a b) + = zmodOfFE a - zmodOfFE b := by + have hb' : b.val.val < 3329 := by + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at hb + unfold hacspec_ml_kem.parameters.FIELD_MODULUS at hb; simpa using hb + unfold zmodOfFE + rw [sub_pure_val_eq a b ha hb, ZMod.natCast_mod] + -- (a + 3329 - b : ℕ) cast to ZMod 3329 = a - b, using b < 3329 so the + -- Nat subtraction does not truncate. + have hcast : ((a.val.val + 3329 - b.val.val : ℕ) : ZMod 3329) + = (a.val.val : ZMod 3329) - (b.val.val : ZMod 3329) := by + have hle : b.val.val ≤ a.val.val + 3329 := by omega + rw [Nat.cast_sub hle] + push_cast + ring + rw [hcast] + +/-! ## `scaleZ` — per-lane `ZMod 3329` scale on FE-arrays. -/ + +/-- Per-lane scale of a 256-FE array by a `ZMod 3329` constant `c`: + lane `j` becomes `feOfZMod (c * zmodOfFE p[j])`. -/ +noncomputable def scaleZ (c : ZMod 3329) + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Std.Array.make 256#usize + ((List.range 256).map (fun j => feOfZMod (c * zmodOfFE (p.val[j]!)))) + (by simp only [List.length_map, List.length_range]; rfl) + +/-- Lane-access law for `scaleZ`: for `j < 256`, + `zmodOfFE ((scaleZ c p)[j]) = c * zmodOfFE (p[j])`. -/ +theorem scaleZ_lane (c : ZMod 3329) + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (j : Nat) (hj : j < 256) : + zmodOfFE ((scaleZ c p).val[j]!) = c * zmodOfFE (p.val[j]!) := by + unfold scaleZ + show zmodOfFE (((List.range 256).map + (fun k => feOfZMod (c * zmodOfFE (p.val[k]!))))[j]!) + = c * zmodOfFE (p.val[j]!) + have h_len : ((List.range 256).map + (fun k => feOfZMod (c * zmodOfFE (p.val[k]!)))).length = 256 := by simp + rw [getElem!_pos _ j (by rw [h_len]; exact hj)] + rw [List.getElem_map, List.getElem_range] + exact zmodOfFE_feOfZMod _ + +/-- Composition law: `scaleZ a (scaleZ b p) = scaleZ (a * b) p`. -/ +theorem scaleZ_compose (a b : ZMod 3329) + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + scaleZ a (scaleZ b p) = scaleZ (a * b) p := by + unfold scaleZ + apply Subtype.ext + show (List.range 256).map + (fun j => feOfZMod (a * zmodOfFE ((Std.Array.make 256#usize + ((List.range 256).map (fun k => feOfZMod (b * zmodOfFE (p.val[k]!)))) + (by simp only [List.length_map, List.length_range]; rfl)).val[j]!))) + = (List.range 256).map (fun j => feOfZMod (a * b * zmodOfFE (p.val[j]!))) + apply List.ext_getElem + · simp + · intro j hj1 _hj2 + have hj : j < 256 := by simpa using hj1 + simp only [List.getElem_map, List.getElem_range] + -- inner lane access + have h_len : ((List.range 256).map + (fun k => feOfZMod (b * zmodOfFE (p.val[k]!)))).length = 256 := by simp + have hinner : ((Std.Array.make 256#usize + ((List.range 256).map (fun k => feOfZMod (b * zmodOfFE (p.val[k]!)))) + (by simp only [List.length_map, List.length_range]; rfl)).val[j]!) + = feOfZMod (b * zmodOfFE (p.val[j]!)) := by + show ((List.range 256).map (fun k => feOfZMod (b * zmodOfFE (p.val[k]!))))[j]! + = feOfZMod (b * zmodOfFE (p.val[j]!)) + rw [getElem!_pos _ j (by rw [h_len]; exact hj)] + rw [List.getElem_map, List.getElem_range] + rw [hinner, zmodOfFE_feOfZMod] + congr 1 + ring + +/-! ## Glue arithmetic (all `decide` in `ZMod 3329`). -/ + +theorem glue_3303_2285 : (3303 * 2285 : ZMod 3329) = 512 := by decide +theorem glue_1441_169 : (1441 * 169 : ZMod 3329) = 512 := by decide +theorem glue_169_2285 : (169 * 2285 : ZMod 3329) = 1 := by decide + +/-! ## Chunk / flatten lane-access helpers. + + NB: a *statement type* containing a nested-array index `(xs : List (Std.Array + FE 16#usize))[j]!` fails to elaborate (the `private` FCTargets chunk-`Inhabited` + instance's `by simp` over-solves when forced during type elaboration). We work + around this by stating the generic lemma `mkN_map_lane` with an *abstract* element + type `α` (so no concrete nested type appears in the statement) and only applying + it in *tactic mode* (where the nested index is fine). -/ + +/-- Generic lane access for a `(List.range m).map f`-backed `Std.Array n`: + for `k < m`, `(make n ((range m).map f) hlen).val[k]! = f k`. -/ +private theorem mkN_map_lane {α : Type} [Inhabited α] {n : Std.Usize} {m : Nat} + (f : Nat → α) (k : Nat) (hk : k < m) + (hlen : ((List.range m).map f).length = n.val) : + (Std.Array.make n ((List.range m).map f) hlen).val[k]! = f k := by + show ((List.range m).map f)[k]! = f k + have h_len : ((List.range m).map f).length = m := by simp + rw [getElem!_pos _ k (by rw [h_len]; exact hk)] + simp + +/-- Lane access for `Spec.chunk_at` (flat statement, elaborates fine): for `ℓ < 16`, + `(chunk_at p k)[ℓ] = p[16*k + ℓ]`. -/ +private theorem chunk_at_lane + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (k ℓ : Nat) (hℓ : ℓ < 16) : + (Spec.chunk_at p k).val[ℓ]! = p.val[16 * k + ℓ]! := by + unfold Spec.chunk_at + exact mkN_map_lane _ ℓ hℓ _ + +/-- Lane access for a 16-chunk `Std.Array.make` of a mapped `List.range 16` + (`f`/`h` inferred from the goal; matches the call sites in the D proof). -/ +private theorem mk16_chunk_lane {α : Type} [Inhabited α] + (f : Nat → α) (k : Nat) (hk : k < 16) + {h : ((List.range 16).map f).length = (16#usize).val} : + (Std.Array.make 16#usize ((List.range 16).map f) h).val[k]! = f k := + mkN_map_lane f k hk h + +/-- `zmodOfFE (lift_fe_mont x) = x.val · 169` (public; re-derives the + `private` copies in FCTargets/Common). -/ +theorem zmodOfFE_lift_fe_mont (x : Std.I16) : + zmodOfFE (lift_fe_mont x) = (x.val : ZMod 3329) * 169 := by + unfold lift_fe_mont + rw [zmodOfFE_feOfZMod]; rfl + +/-! ## D / subtract bridge (factor 512 = 1441·169). -/ + +/-- Per-lane characterization of `Spec.subtract_reduce_pure`: for `j < 256` + and canonical `a[j]` (the actual `self`-poly lanes are always canonical), + `zmodOfFE ((subtract_reduce_pure a b)[j]) = zmodOfFE (a[j]) - 512 * zmodOfFE (b[j])`. + The impl's fused Montgomery `·1441` correction equals `·512` in `ZMod 3329` + since `1441 · 169 ≡ 512` (`glue_1441_169`). -/ +theorem zmodOfFE_subtract_reduce_pure_lane + (a b : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (j : Nat) (hj : j < 256) + (ha : libcrux_iot_ml_kem.Spec.Pure.Canonical (a.val[j]!)) : + zmodOfFE ((Spec.subtract_reduce_pure a b).val[j]!) + = zmodOfFE (a.val[j]!) - 512 * zmodOfFE (b.val[j]!) := by + have hk : j / 16 < 16 := by omega + have hℓ : j % 16 < 16 := Nat.mod_lt _ (by decide) + have hjeq : 16 * (j / 16) + j % 16 = j := by omega + unfold Spec.subtract_reduce_pure + -- Reduce `(flatten_chunks …).val[j]!` to the nested chunk lookup directly via the + -- generic `mkN_map_lane` (stating a standalone `flatten_chunks_lane` lemma fails: + -- a nested `[!]` index in a *theorem statement type* re-runs the `16#usize` + -- `(by decide)` proof and over-solves — see file header note). + unfold Spec.flatten_chunks + rw [mkN_map_lane _ j hj] + rw [mk16_chunk_lane _ (j / 16) hk] + -- now: chunk_subtract_reduce_pure (chunk_at a k) (chunk_at b k) [j%16] + unfold Spec.chunk_subtract_reduce_pure + rw [mk16_chunk_lane _ (j % 16) hℓ] + -- lane = sub_pure (chunk_at a k)[ℓ] (mul_pure (chunk_at b k)[ℓ] (lift_fe_mont 1441)) + rw [chunk_at_lane a (j / 16) (j % 16) hℓ, chunk_at_lane b (j / 16) (j % 16) hℓ] + rw [hjeq] + -- Need canonicity of the two sub_pure args. + have hcb : libcrux_iot_ml_kem.Spec.Pure.Canonical + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (b.val[j]!) (lift_fe_mont (1441#i16))) := + libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure _ _ + rw [zmodOfFE_sub_pure _ _ ha hcb] + rw [zmodOfFE_mul_pure] + -- zmodOfFE (lift_fe_mont 1441) = 1441 * 169 = 512 + rw [zmodOfFE_lift_fe_mont] + have h1441 : (((1441#i16 : Std.I16).val : ZMod 3329)) = 1441 := by decide + rw [h1441] + have h512 : (1441 : ZMod 3329) * 169 = 512 := glue_1441_169 + rw [show (zmodOfFE (b.val[j]!) * (1441 * 169) : ZMod 3329) + = 512 * zmodOfFE (b.val[j]!) by rw [h512]; ring] + +end libcrux_iot_ml_kem.Matrix.ComputeMessage.Bridges \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeMessage/FC.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeMessage/FC.lean new file mode 100644 index 00000000..95210506 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeMessage/FC.lean @@ -0,0 +1,259 @@ +/- + # `Matrix/ComputeMessage/FC.lean` — L7.4 FC theorem glue. + + Houses the L7.4 FC theorem `compute_message_fc`, gluing the direct + decomposition (impl walk via `triple_*_ok_fc` + the A/B/C/D chain). + + POST: + `hacspec_ml_kem.matrix.compute_message (lift_poly v) + (lift_vec secret_as_ntt) (lift_vec u_as_ntt) = .ok (lift_poly p.1)`. +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeAsPlusE +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeMessage.Impl +import LibcruxIotMlKem.Matrix.ComputeMessage.Hacspec + +namespace libcrux_iot_ml_kem.Matrix.ComputeMessage.FC +open libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeMessage.Bridges libcrux_iot_ml_kem.Matrix.ComputeMessage.Hacspec libcrux_iot_ml_kem.Matrix.ComputeMessage.Impl +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeAsPlusE libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt + +/-- Local copy of the `private triple_exists_ok_fc` helper (Impl/ComputeMessage): + a `True`-pre Triple yielding `.ok` with the post is an existential witness. -/ +private theorem triple_exists_ok_fc {α : Type} {x : Result α} {P : α → Prop} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) : + ∃ v, x = .ok v ∧ P v := by + match hx : x with + | .ok v => exact ⟨v, rfl, (by subst hx; simpa [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h)⟩ + | .fail _ => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +/-- Local copy of the `private triple_of_ok_fc` helper (Impl/ComputeMessage). -/ +private theorem triple_of_ok_fc {α : Type} {x : Result α} {v : α} + {P : α → Prop} (hx : x = .ok v) (hp : P v) : + ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄ := by + subst hx; simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +/-- `scaleZ c p` lanes are `feOfZMod _`, hence canonical (local copy of the + `private canonArr_scaleZ'` in ComputeMessage/Hacspec). -/ +private theorem scaleZ_canon (c : ZMod 3329) + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (j : Nat) (hj : j < 256) : + libcrux_iot_ml_kem.Spec.Pure.Canonical ((scaleZ c p).val[j]!) := by + unfold scaleZ + show libcrux_iot_ml_kem.Spec.Pure.Canonical + (((List.range 256).map (fun k => feOfZMod (c * zmodOfFE (p.val[k]!))))[j]!) + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj])] + rw [List.getElem_map, List.getElem_range] + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical feOfZMod + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] + show (BitVec.ofNat 16 ((c * zmodOfFE (p.val[j]!)).val)).toNat < 3329 + set z := c * zmodOfFE (p.val[j]!) + have h_lt16 : z.val < 2 ^ 16 := by have := ZMod.val_lt z; omega + rw [BitVec.toNat_ofNat, Nat.mod_eq_of_lt h_lt16] + exact ZMod.val_lt _ + +/-- `lift_poly x` lanes are `lift_fe _ = feOfZMod _`, hence canonical. -/ +private theorem lift_poly_canon + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (j : Nat) (hj : j < 256) : + libcrux_iot_ml_kem.Spec.Pure.Canonical ((lift_poly re).val[j]!) := by + unfold lift_poly + show libcrux_iot_ml_kem.Spec.Pure.Canonical + (((List.range 256).map (fun i => + lift_fe (re.coefficients.val[i / 16]!).elements.val[i % 16]!))[j]!) + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj])] + rw [List.getElem_map, List.getElem_range] + unfold lift_fe libcrux_iot_ml_kem.Spec.Pure.Canonical feOfZMod + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] + show (⟨BitVec.ofNat 16 ((i16_to_spec_fe_plain + (re.coefficients.val[j / 16]!).elements.val[j % 16]!).val)⟩ : Std.U16).val < 3329 + show (BitVec.ofNat 16 ((i16_to_spec_fe_plain + (re.coefficients.val[j / 16]!).elements.val[j % 16]!).val)).toNat < 3329 + set z := i16_to_spec_fe_plain (re.coefficients.val[j / 16]!).elements.val[j % 16]! + have h_lt16 : z.val < 2 ^ 16 := by + have := ZMod.val_lt z; omega + rw [BitVec.toNat_ofNat, Nat.mod_eq_of_lt h_lt16] + exact ZMod.val_lt _ + +/-! ## (iii) — L7.4 FC theorem (glue of (i) + (ii)). + + PRE: `hK : K.val ≤ 4` plus per-lane `≤ 3328` bounds on `secret_as_ntt`, + `u_as_ntt`, `v`. No `h_acc_bnd` (the impl re-zeros the accumulator). -/ +@[spec] +theorem compute_message_fc + {K : Std.Usize} + (v : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (secret_as_ntt u_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (result : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (accumulator : Std.Array Std.I32 256#usize) + (hK : K.val ≤ 4) + (h_secret_bnd : ∀ k : Nat, k < K.val → ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((secret_as_ntt.val[k]!.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3328) + (h_u_bnd : ∀ k : Nat, k < K.val → ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((u_as_ntt.val[k]!.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3328) + (h_v_bnd : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((v.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_message + (vectortraitsOperationsInst := portable_ops_inst) + v secret_as_ntt u_as_ntt result scratch accumulator + ⦃ ⇓ p => ⌜ hacspec_ml_kem.matrix.compute_message + (lift_poly v) + (lift_vec secret_as_ntt) (lift_vec u_as_ntt) + = .ok (lift_poly p.1) ⌝ ⦄ := by + -- Fin-form bounds for the loop lemma. + have h_secret_fin : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((secret_as_ntt.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328 := + fun k i j => h_secret_bnd k.val k.isLt i.val i.isLt j.val j.isLt + have h_u_fin : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((u_as_ntt.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328 := + fun k i j => h_u_bnd k.val k.isLt i.val i.isLt j.val j.isLt + -- Step 0: classify 0#i32 = .ok 0#i32; acc1 = repeat 256 (0#i32) (all-zero). + set acc1 : Std.Array Std.I32 256#usize := + Std.Array.repeat (256#usize : Std.Usize) (0#i32 : Std.I32) with h_acc1_def + have h_acc1_zero : ∀ n : Nat, n < 256 → (acc1.val[n]!).val = 0 := by + intro n hn + rw [h_acc1_def, Std.Array.repeat_val] + rw [getElem!_pos _ n (by rw [List.length_replicate]; exact hn)] + rw [List.getElem_replicate]; rfl + -- Acc budget for the loop: acc1[n] = 0, K ≤ 4, so K·2^25 ≤ 2^30. + have h_acc_budget : ∀ n : Fin 256, + (acc1.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30 := by + intro n + have h0 : (acc1.val[n.val]!).val.natAbs = 0 := by rw [h_acc1_zero n.val n.isLt]; rfl + rw [h0] + have : K.val * 2^25 ≤ 4 * 2^25 := Nat.mul_le_mul_right _ hK + omega + -- S1: run the accumulation loop; get acc2 with the loop invariant. + obtain ⟨acc2, h_acc2_eq, h_char⟩ := triple_exists_ok_fc + (compute_message_loop_fc secret_as_ntt u_as_ntt acc1 + h_secret_fin h_u_fin h_acc_budget) + -- Accumulator bound: acc2[n].natAbs ≤ K·2^25 ≤ 2^27 (from loop_inv conjunct 2). + have h_acc2_bnd : ∀ n : Nat, n < 256 → (acc2.val[n]!).val.natAbs ≤ 2^27 := by + intro n hn + obtain ⟨_, h_inv_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_char + have hb := h_inv_bnd n hn + have h0 : (acc1.val[n]!).val.natAbs = 0 := by rw [h_acc1_zero n hn]; rfl + rw [h0] at hb + have hK4 : K.val * 2^25 ≤ 4 * 2^25 := Nat.mul_le_mul_right _ hK + have : (2 ^ 27 : Nat) = 4 * 2^25 := by norm_num + omega + -- reducing step: result1. + set s := Aeneas.Std.Array.to_slice acc2 with h_s_def + have h_s_len : s.length = 256 := by + rw [h_s_def, Aeneas.Std.Array.length_to_slice]; rfl + have h_s_bnd : ∀ i : Nat, i < 256 → (s.val[i]!).val.natAbs ≤ 2^16 * 3328 := by + intro i hi + rw [h_s_def, Aeneas.Std.Array.val_to_slice] + have := h_acc2_bnd i hi + have h27 : (2 ^ 27 : Nat) ≤ 2^16 * 3328 := by norm_num + omega + obtain ⟨result1, h_result1_eq, h_result1_mont, h_result1_lane_bnd⟩ := + triple_exists_ok_fc + (poly_reducing_from_i32_array_fc s result h_s_len h_s_bnd) + -- lift_poly result1 = mont_strip (poly_reducing s). + have h_result1_lift : lift_poly result1 + = Impl.mont_strip_pure (Spec.poly_reducing_from_i32_array_pure s) := by + rw [← h_result1_mont, Impl.mont_strip_lift_poly_mont_eq_lift_poly] + -- invert step. PRE ≤13312 from result1 ≤4993. + have h_result1_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((result1.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 13312 := by + intro chunk hchunk k hk + have := h_result1_lane_bnd chunk hchunk k hk + omega + obtain ⟨⟨result2, scratch1⟩, h_inv_eq, h_result2_lift, h_result2_bnd⟩ := + triple_exists_ok_fc + (invert_ntt_montgomery_fc (K := K) result1 scratch h_result1_bnd) + dsimp only at h_inv_eq h_result2_lift h_result2_bnd + -- subtract step. PRE: v ≤29439, result2 ≤32767. + have h_v_self_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((v.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439 := by + intro chunk hchunk ℓ hℓ + have := h_v_bnd chunk hchunk ℓ hℓ; omega + have h_result2_b_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((result2.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro chunk hchunk ℓ hℓ + have := h_result2_bnd chunk hchunk ℓ hℓ; omega + obtain ⟨result3, h_sub_eq, h_result3_lift⟩ := + triple_exists_ok_fc + (subtract_reduce_fc v result2 h_v_self_bnd h_result2_b_bnd) + -- Reduce the impl do-block to `.ok (result3, scratch1, acc2)`. + apply triple_of_ok_fc + (v := (result3, scratch1, acc2)) + · unfold libcrux_iot_ml_kem.matrix.compute_message + simp only [libcrux_secrets.traits.Classify.Blanket.classify, Aeneas.Std.lift, + Aeneas.Std.bind_tc_ok] + rw [show (Std.Array.repeat (256#usize : Std.Usize) (0#i32 : Std.I32)) = acc1 from rfl] + rw [h_acc2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [← h_s_def, h_result1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_inv_eq]; simp only [Aeneas.Std.bind_tc_ok] + show (do + let result3 ← polynomial.PolynomialRingElement.subtract_reduce portable_ops_inst v result2 + Aeneas.Std.Result.ok (result3, scratch1, acc2)) = Aeneas.Std.Result.ok (result3, scratch1, acc2) + rw [h_sub_eq]; simp only [Aeneas.Std.bind_tc_ok] + · -- Chain A/B/C/D: prove the hacspec spec = .ok (lift_poly result3). + show hacspec_ml_kem.matrix.compute_message (lift_poly v) + (lift_vec secret_as_ntt) (lift_vec u_as_ntt) = .ok (lift_poly result3) + unfold hacspec_ml_kem.matrix.compute_message + -- A: multiply_vectors = .ok (scaleZ 2285 (lift_poly result1)). + have hA := compute_message_acc_bridge secret_as_ntt u_as_ntt acc1 acc2 + h_acc1_zero h_secret_fin h_u_fin h_char + rw [← h_result1_lift] at hA + rw [hA]; simp only [Aeneas.Std.bind_tc_ok] + -- C: ntt_inverse (scaleZ 2285 (lift_poly result1)) + -- = .ok (scaleZ 3303 (invert_pure (scaleZ 2285 (lift_poly result1)))). + have hCanon_s : ∀ j : Nat, j < 256 → + libcrux_iot_ml_kem.Spec.Pure.Canonical + ((scaleZ 2285 (lift_poly result1)).val[j]!) := + fun j hj => scaleZ_canon 2285 (lift_poly result1) j hj + rw [ntt_inverse_eq_scaleZ_invert_pure (scaleZ 2285 (lift_poly result1)) hCanon_s] + simp only [Aeneas.Std.bind_tc_ok] + -- B: invert_pure (scaleZ 2285 x) = scaleZ 2285 (invert_pure x). + rw [invert_ntt_montgomery_pure_scaleZ 2285 (lift_poly result1) + (fun j hj => lift_poly_canon result1 j hj)] + -- scaleZ 3303 (scaleZ 2285 y) = scaleZ 512 y. + rw [scaleZ_compose 3303 2285 (Spec.invert_ntt_montgomery_pure (lift_poly result1)), + glue_3303_2285] + -- invert_pure (lift_poly result1) = lift_poly result2. + rw [← h_result2_lift] + -- D: sub_polynomials (lift_poly v) (scaleZ 512 (lift_poly result2)) + -- = .ok (subtract_reduce_pure (lift_poly v) (lift_poly result2)). + rw [sub_polynomials_scaleZ_eq (lift_poly v) (lift_poly result2) + (fun j hj => lift_poly_canon v j hj)] + -- subtract_reduce_pure (lift_poly v) (lift_poly result2) = lift_poly result3. + rw [← h_result3_lift] + +/-- +info: 'libcrux_iot_ml_kem.Matrix.ComputeMessage.FC.compute_message_fc' depends on axioms: [propext, + Classical.choice, + Quot.sound] +-/ +#guard_msgs in +#print axioms compute_message_fc + +end libcrux_iot_ml_kem.Matrix.ComputeMessage.FC \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeMessage/Hacspec.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeMessage/Hacspec.lean new file mode 100644 index 00000000..762a8fe9 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeMessage/Hacspec.lean @@ -0,0 +1,2243 @@ +/- + # `Matrix/ComputeMessage/Hacspec.lean` — L7.4 pure/spec bridges. + + Pure-side and pure↔hacspec equational bridges for the direct + `compute_message` decomposition: + + * **B** — `invert_ntt_montgomery_pure` is `scaleZ`-equivariant (pure + scalar-linearity; the constant passes through all 7 inverse layers). + * **C** — `hacspec ntt_inverse p = scaleZ 3303 (invert_ntt_montgomery_pure p)` + (`3303 = 512·169`, the fixed Montgomery factor between the hacspec inverse + NTT and the Mont-domain pure inverse). + * **D (hacspec side)** — `sub_polynomials a (scaleZ 512 b) + = subtract_reduce_pure a b` for canonical `a` (the `·512` cancels the + surplus `R` the invert path carries; `Bridges.zmodOfFE_subtract_reduce_pure_lane` + supplies the per-lane characterization, `1441·169 ≡ 512`). +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeAsPlusE +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeMessage.Bridges + +namespace libcrux_iot_ml_kem.Matrix.ComputeMessage.Hacspec +open libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeMessage.Bridges +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeAsPlusE libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt + +/-! ## B — `invert_ntt_montgomery_pure` is `scaleZ`-equivariant. + + The 7-layer inverse NTT is `ZMod`-linear: each Gentleman–Sande butterfly + is `add_pure` / `sub_pure` / `mul_pure`-by-constant on the lanes, all of + which are linear in the lane value. Scaling the input by `c` scales every + output lane by `c`. Both sides are canonical (`Canonical_{add,sub,mul}_pure`), + so the equality holds strictly (not merely per-lane in `ZMod`). + + Proof strategy (per-layer equivariance, compose 7×): + * Per-chunk butterflies: `zmodOfFE (butterfly (scaled inputs)) = + c * zmodOfFE (butterfly (inputs))` via `zmodOfFE_{add,sub,mul}_pure` + + ring; `scaleZ_lane` to expose the `c *` on the input side. + * Each `Spec.invert_ntt_layer_*_pure` preserves the "lane `j` is `c`-scaled" + relation (chunk_at / flatten_chunks lane bookkeeping). + * `Spec.invert_ntt_montgomery_pure` = compose of the 7 layers; the relation + threads through. Finish via canonical determination (two canonical FEs + with equal `zmodOfFE` are equal). + + **Canonical precondition (required).** `Spec.Pure.FieldElement.sub_pure` is + only the linear `a − b` when the subtrahend is canonical (< 3329); on a + non-canonical subtrahend the underlying U32 subtraction underflows and + `sub_pure` saturates to `defaultFE` (residue 0). The layer-1 inverse + butterfly subtracts a *raw input lane*, so equivariance is FALSE for + non-canonical `p` (counterexample: `p[0] = 60000` vs its canonical residue + `78` give different layer-1 outputs). `hp` rules this out; it is discharged + at every L7.4 call site because the inputs are `lift_poly` / `scaleZ` + outputs (canonical by construction). Canonicity is preserved across the + layers by `Canonical_{add,sub,mul}_pure`. -/ +section InvertScaleZ + +open libcrux_iot_ml_kem.Spec.Pure (Canonical) +/-- 256-lane "every lane canonical" predicate. -/ +private def CanonArr (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : Prop := + ∀ j : Nat, j < 256 → Canonical (p.val[j]!) + +/-- 256-lane "`q` is the per-lane `c`-scale of `p` in `ZMod 3329`" predicate. -/ +private def ScaledArr (c : ZMod 3329) + (q p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : Prop := + ∀ j : Nat, j < 256 → zmodOfFE (q.val[j]!) = c * zmodOfFE (p.val[j]!) + +/-- 16-lane "every lane canonical" predicate. -/ +private def CanonChunk (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) : Prop := + ∀ ℓ : Nat, ℓ < 16 → Canonical (a.val[ℓ]!) + +/-- 16-lane "`q` is the per-lane `c`-scale of `p`" predicate. -/ +private def ScaledChunk (c : ZMod 3329) + (q p : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) : Prop := + ∀ ℓ : Nat, ℓ < 16 → zmodOfFE (q.val[ℓ]!) = c * zmodOfFE (p.val[ℓ]!) + +/-- `chunk_at` lane access (public re-derivation of the `private` Bridges copy). -/ +private theorem chunk_at_lane' + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (k ℓ : Nat) (hℓ : ℓ < 16) : + (Spec.chunk_at p k).val[ℓ]! = p.val[16 * k + ℓ]! := by + unfold Spec.chunk_at + show ((List.range 16).map (fun j => p.val[16 * k + j]!))[ℓ]! = p.val[16 * k + ℓ]! + have h_len : ((List.range 16).map (fun j => p.val[16 * k + j]!)).length = 16 := by simp + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map, List.getElem_range] + +/-- Generic `Std.Array.make … (range m).map f` lane access (local copy). -/ +private theorem mkN_map_lane' {α : Type} [Inhabited α] {n : Std.Usize} {m : Nat} + (f : Nat → α) (k : Nat) (hk : k < m) + (hlen : ((List.range m).map f).length = n.val) : + (Std.Array.make n ((List.range m).map f) hlen).val[k]! = f k := by + show ((List.range m).map f)[k]! = f k + have h_len : ((List.range m).map f).length = m := by simp + rw [getElem!_pos _ k (by rw [h_len]; exact hk)] + simp + +/-- Canonical round-trip (local copy of the `private` FCTargets lemma). -/ +private theorem feOfZMod_zmodOfFE_of_canon + (fe : hacspec_ml_kem.parameters.FieldElement) (h : Canonical fe) : + feOfZMod (zmodOfFE fe) = fe := by + have h' : fe.val.val < 3329 := by + unfold Canonical hacspec_ml_kem.parameters.FIELD_MODULUS at h; simpa using h + unfold feOfZMod zmodOfFE + have hzval : ((fe.val.val : ZMod 3329)).val = fe.val.val := ZMod.val_natCast_of_lt h' + rw [hzval] + have hfeval : fe.val.val < 2 ^ 16 := by + have h_p : (3329 : Nat) ≤ 2 ^ 16 := by decide + omega + have hfebv : BitVec.ofNat 16 fe.val.val = fe.val.bv := by + apply BitVec.eq_of_toNat_eq + rw [BitVec.toNat_ofNat] + show fe.val.val % 2 ^ 16 = fe.val.bv.toNat + rw [Nat.mod_eq_of_lt hfeval]; rfl + show ({ val := ⟨BitVec.ofNat 16 fe.val.val⟩ } : + hacspec_ml_kem.parameters.FieldElement) = fe + rw [hfebv] + +/-! ### Chunk-level equivariance of `chunk_inv_ntt_step_pure`. -/ + +/-- Lane formula for one inverse-NTT chunk step: only lanes `i`/`j` change. -/ +private theorem chunk_inv_ntt_step_lane + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (zeta : hacspec_ml_kem.parameters.FieldElement) (i j : Std.Usize) + (hi : i.val < 16) (hj : j.val < 16) (ℓ : Nat) (hℓ : ℓ < 16) : + (Spec.chunk_inv_ntt_step_pure a zeta i j).val[ℓ]! + = if ℓ = j.val then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (a.val[j.val]!) (a.val[i.val]!)) zeta + else if ℓ = i.val then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (a.val[j.val]!) (a.val[i.val]!) + else a.val[ℓ]! := by + unfold Spec.chunk_inv_ntt_step_pure + simp only [] -- expose the let-bindings + rw [← Aeneas.Std.Array.getElem!_Nat_eq] + have hlen : a.length = 16 := Aeneas.Std.Array.length_eq a + by_cases hℓj : ℓ = j.val + · rw [if_pos hℓj] + subst hℓj + have hbnd : j.val < (a.set i (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (a.val[j.val]!) (a.val[i.val]!))).length := by + rw [Aeneas.Std.Array.set_length]; rw [hlen]; exact hj + rw [Aeneas.Std.Array.getElem!_Nat_set_eq _ _ _ _ ⟨rfl, hbnd⟩] + · rw [if_neg hℓj, + Aeneas.Std.Array.getElem!_Nat_set_ne _ _ _ _ (fun h => hℓj h.symm)] + by_cases hℓi : ℓ = i.val + · rw [if_pos hℓi] + subst hℓi + rw [Aeneas.Std.Array.getElem!_Nat_set_eq _ _ _ _ ⟨rfl, by rw [hlen]; exact hi⟩] + · rw [if_neg hℓi, + Aeneas.Std.Array.getElem!_Nat_set_ne _ _ _ _ (fun h => hℓi h.symm), + Aeneas.Std.Array.getElem!_Nat_eq] + +/-- `chunk_inv_ntt_step_pure` preserves `CanonChunk`. -/ +private theorem chunk_inv_ntt_step_canon + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (zeta : hacspec_ml_kem.parameters.FieldElement) (i j : Std.Usize) + (hi : i.val < 16) (hj : j.val < 16) (ha : CanonChunk a) : + CanonChunk (Spec.chunk_inv_ntt_step_pure a zeta i j) := by + intro ℓ hℓ + rw [chunk_inv_ntt_step_lane a zeta i j hi hj ℓ hℓ] + by_cases hℓj : ℓ = j.val + · simp only [if_pos hℓj] + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure _ _ + · simp only [if_neg hℓj] + by_cases hℓi : ℓ = i.val + · simp only [if_pos hℓi] + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure _ _ + · simp only [if_neg hℓi]; exact ha ℓ hℓ + +/-- `chunk_inv_ntt_step_pure` preserves `ScaledChunk` (given canonical inputs on + both sides, so the `sub_pure` lanes are genuinely linear). -/ +private theorem chunk_inv_ntt_step_scaled (c : ZMod 3329) + (q p : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (zeta : hacspec_ml_kem.parameters.FieldElement) (i j : Std.Usize) + (hi : i.val < 16) (hj : j.val < 16) + (hq : CanonChunk q) (hpc : CanonChunk p) + (hs : ScaledChunk c q p) : + ScaledChunk c (Spec.chunk_inv_ntt_step_pure q zeta i j) + (Spec.chunk_inv_ntt_step_pure p zeta i j) := by + intro ℓ hℓ + rw [chunk_inv_ntt_step_lane q zeta i j hi hj ℓ hℓ, + chunk_inv_ntt_step_lane p zeta i j hi hj ℓ hℓ] + by_cases hℓj : ℓ = j.val + · simp only [if_pos hℓj] + rw [zmodOfFE_mul_pure, zmodOfFE_mul_pure, + zmodOfFE_sub_pure _ _ (hq j.val hj) (hq i.val hi), + zmodOfFE_sub_pure _ _ (hpc j.val hj) (hpc i.val hi), + hs j.val hj, hs i.val hi] + ring + · simp only [if_neg hℓj] + by_cases hℓi : ℓ = i.val + · simp only [if_pos hℓi] + rw [zmodOfFE_add_pure, zmodOfFE_add_pure, hs j.val hj, hs i.val hi] + ring + · simp only [if_neg hℓi]; exact hs ℓ hℓ + +/-! ### Layer-step (chunk) equivariance for layers 1/2/3. -/ + +/-- `chunk_inv_ntt_layer_1_step_pure` preserves `CanonChunk`. -/ +private theorem chunk_inv_ntt_layer_1_step_canon + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z0 z1 z2 z3 : hacspec_ml_kem.parameters.FieldElement) (ha : CanonChunk a) : + CanonChunk (Spec.chunk_inv_ntt_layer_1_step_pure a z0 z1 z2 z3) := by + unfold Spec.chunk_inv_ntt_layer_1_step_pure + exact chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) ha))))))) + +/-- `chunk_inv_ntt_layer_2_step_pure` preserves `CanonChunk`. -/ +private theorem chunk_inv_ntt_layer_2_step_canon + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z0 z1 : hacspec_ml_kem.parameters.FieldElement) (ha : CanonChunk a) : + CanonChunk (Spec.chunk_inv_ntt_layer_2_step_pure a z0 z1) := by + unfold Spec.chunk_inv_ntt_layer_2_step_pure + exact chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) ha))))))) + +/-- `chunk_inv_ntt_layer_3_step_pure` preserves `CanonChunk`. -/ +private theorem chunk_inv_ntt_layer_3_step_canon + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z : hacspec_ml_kem.parameters.FieldElement) (ha : CanonChunk a) : + CanonChunk (Spec.chunk_inv_ntt_layer_3_step_pure a z) := by + unfold Spec.chunk_inv_ntt_layer_3_step_pure + exact chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) + (chunk_inv_ntt_step_canon _ _ _ _ (by decide) (by decide) ha))))))) + +/-- `chunk_inv_ntt_layer_1_step_pure` preserves `ScaledChunk`. -/ +private theorem chunk_inv_ntt_layer_1_step_scaled (c : ZMod 3329) + (q p : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z0 z1 z2 z3 : hacspec_ml_kem.parameters.FieldElement) + (hq : CanonChunk q) (hpc : CanonChunk p) (hs : ScaledChunk c q p) : + ScaledChunk c (Spec.chunk_inv_ntt_layer_1_step_pure q z0 z1 z2 z3) + (Spec.chunk_inv_ntt_layer_1_step_pure p z0 z1 z2 z3) := by + unfold Spec.chunk_inv_ntt_layer_1_step_pure + -- thread canon + scaled through the 8 steps + have c1q := chunk_inv_ntt_step_canon q z0 0#usize 2#usize (by decide) (by decide) hq + have c1p := chunk_inv_ntt_step_canon p z0 0#usize 2#usize (by decide) (by decide) hpc + have s1 := chunk_inv_ntt_step_scaled c q p z0 0#usize 2#usize (by decide) (by decide) hq hpc hs + have c2q := chunk_inv_ntt_step_canon _ z0 1#usize 3#usize (by decide) (by decide) c1q + have c2p := chunk_inv_ntt_step_canon _ z0 1#usize 3#usize (by decide) (by decide) c1p + have s2 := chunk_inv_ntt_step_scaled c _ _ z0 1#usize 3#usize (by decide) (by decide) c1q c1p s1 + have c3q := chunk_inv_ntt_step_canon _ z1 4#usize 6#usize (by decide) (by decide) c2q + have c3p := chunk_inv_ntt_step_canon _ z1 4#usize 6#usize (by decide) (by decide) c2p + have s3 := chunk_inv_ntt_step_scaled c _ _ z1 4#usize 6#usize (by decide) (by decide) c2q c2p s2 + have c4q := chunk_inv_ntt_step_canon _ z1 5#usize 7#usize (by decide) (by decide) c3q + have c4p := chunk_inv_ntt_step_canon _ z1 5#usize 7#usize (by decide) (by decide) c3p + have s4 := chunk_inv_ntt_step_scaled c _ _ z1 5#usize 7#usize (by decide) (by decide) c3q c3p s3 + have c5q := chunk_inv_ntt_step_canon _ z2 8#usize 10#usize (by decide) (by decide) c4q + have c5p := chunk_inv_ntt_step_canon _ z2 8#usize 10#usize (by decide) (by decide) c4p + have s5 := chunk_inv_ntt_step_scaled c _ _ z2 8#usize 10#usize (by decide) (by decide) c4q c4p s4 + have c6q := chunk_inv_ntt_step_canon _ z2 9#usize 11#usize (by decide) (by decide) c5q + have c6p := chunk_inv_ntt_step_canon _ z2 9#usize 11#usize (by decide) (by decide) c5p + have s6 := chunk_inv_ntt_step_scaled c _ _ z2 9#usize 11#usize (by decide) (by decide) c5q c5p s5 + have c7q := chunk_inv_ntt_step_canon _ z3 12#usize 14#usize (by decide) (by decide) c6q + have c7p := chunk_inv_ntt_step_canon _ z3 12#usize 14#usize (by decide) (by decide) c6p + have s7 := chunk_inv_ntt_step_scaled c _ _ z3 12#usize 14#usize (by decide) (by decide) c6q c6p s6 + exact chunk_inv_ntt_step_scaled c _ _ z3 13#usize 15#usize (by decide) (by decide) c7q c7p s7 + +/-- `chunk_inv_ntt_layer_2_step_pure` preserves `ScaledChunk`. -/ +private theorem chunk_inv_ntt_layer_2_step_scaled (c : ZMod 3329) + (q p : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z0 z1 : hacspec_ml_kem.parameters.FieldElement) + (hq : CanonChunk q) (hpc : CanonChunk p) (hs : ScaledChunk c q p) : + ScaledChunk c (Spec.chunk_inv_ntt_layer_2_step_pure q z0 z1) + (Spec.chunk_inv_ntt_layer_2_step_pure p z0 z1) := by + unfold Spec.chunk_inv_ntt_layer_2_step_pure + have c1q := chunk_inv_ntt_step_canon q z0 0#usize 4#usize (by decide) (by decide) hq + have c1p := chunk_inv_ntt_step_canon p z0 0#usize 4#usize (by decide) (by decide) hpc + have s1 := chunk_inv_ntt_step_scaled c q p z0 0#usize 4#usize (by decide) (by decide) hq hpc hs + have c2q := chunk_inv_ntt_step_canon _ z0 1#usize 5#usize (by decide) (by decide) c1q + have c2p := chunk_inv_ntt_step_canon _ z0 1#usize 5#usize (by decide) (by decide) c1p + have s2 := chunk_inv_ntt_step_scaled c _ _ z0 1#usize 5#usize (by decide) (by decide) c1q c1p s1 + have c3q := chunk_inv_ntt_step_canon _ z0 2#usize 6#usize (by decide) (by decide) c2q + have c3p := chunk_inv_ntt_step_canon _ z0 2#usize 6#usize (by decide) (by decide) c2p + have s3 := chunk_inv_ntt_step_scaled c _ _ z0 2#usize 6#usize (by decide) (by decide) c2q c2p s2 + have c4q := chunk_inv_ntt_step_canon _ z0 3#usize 7#usize (by decide) (by decide) c3q + have c4p := chunk_inv_ntt_step_canon _ z0 3#usize 7#usize (by decide) (by decide) c3p + have s4 := chunk_inv_ntt_step_scaled c _ _ z0 3#usize 7#usize (by decide) (by decide) c3q c3p s3 + have c5q := chunk_inv_ntt_step_canon _ z1 8#usize 12#usize (by decide) (by decide) c4q + have c5p := chunk_inv_ntt_step_canon _ z1 8#usize 12#usize (by decide) (by decide) c4p + have s5 := chunk_inv_ntt_step_scaled c _ _ z1 8#usize 12#usize (by decide) (by decide) c4q c4p s4 + have c6q := chunk_inv_ntt_step_canon _ z1 9#usize 13#usize (by decide) (by decide) c5q + have c6p := chunk_inv_ntt_step_canon _ z1 9#usize 13#usize (by decide) (by decide) c5p + have s6 := chunk_inv_ntt_step_scaled c _ _ z1 9#usize 13#usize (by decide) (by decide) c5q c5p s5 + have c7q := chunk_inv_ntt_step_canon _ z1 10#usize 14#usize (by decide) (by decide) c6q + have c7p := chunk_inv_ntt_step_canon _ z1 10#usize 14#usize (by decide) (by decide) c6p + have s7 := chunk_inv_ntt_step_scaled c _ _ z1 10#usize 14#usize (by decide) (by decide) c6q c6p s6 + exact chunk_inv_ntt_step_scaled c _ _ z1 11#usize 15#usize (by decide) (by decide) c7q c7p s7 + +/-- `chunk_inv_ntt_layer_3_step_pure` preserves `ScaledChunk`. -/ +private theorem chunk_inv_ntt_layer_3_step_scaled (c : ZMod 3329) + (q p : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z : hacspec_ml_kem.parameters.FieldElement) + (hq : CanonChunk q) (hpc : CanonChunk p) (hs : ScaledChunk c q p) : + ScaledChunk c (Spec.chunk_inv_ntt_layer_3_step_pure q z) + (Spec.chunk_inv_ntt_layer_3_step_pure p z) := by + unfold Spec.chunk_inv_ntt_layer_3_step_pure + have c1q := chunk_inv_ntt_step_canon q z 0#usize 8#usize (by decide) (by decide) hq + have c1p := chunk_inv_ntt_step_canon p z 0#usize 8#usize (by decide) (by decide) hpc + have s1 := chunk_inv_ntt_step_scaled c q p z 0#usize 8#usize (by decide) (by decide) hq hpc hs + have c2q := chunk_inv_ntt_step_canon _ z 1#usize 9#usize (by decide) (by decide) c1q + have c2p := chunk_inv_ntt_step_canon _ z 1#usize 9#usize (by decide) (by decide) c1p + have s2 := chunk_inv_ntt_step_scaled c _ _ z 1#usize 9#usize (by decide) (by decide) c1q c1p s1 + have c3q := chunk_inv_ntt_step_canon _ z 2#usize 10#usize (by decide) (by decide) c2q + have c3p := chunk_inv_ntt_step_canon _ z 2#usize 10#usize (by decide) (by decide) c2p + have s3 := chunk_inv_ntt_step_scaled c _ _ z 2#usize 10#usize (by decide) (by decide) c2q c2p s2 + have c4q := chunk_inv_ntt_step_canon _ z 3#usize 11#usize (by decide) (by decide) c3q + have c4p := chunk_inv_ntt_step_canon _ z 3#usize 11#usize (by decide) (by decide) c3p + have s4 := chunk_inv_ntt_step_scaled c _ _ z 3#usize 11#usize (by decide) (by decide) c3q c3p s3 + have c5q := chunk_inv_ntt_step_canon _ z 4#usize 12#usize (by decide) (by decide) c4q + have c5p := chunk_inv_ntt_step_canon _ z 4#usize 12#usize (by decide) (by decide) c4p + have s5 := chunk_inv_ntt_step_scaled c _ _ z 4#usize 12#usize (by decide) (by decide) c4q c4p s4 + have c6q := chunk_inv_ntt_step_canon _ z 5#usize 13#usize (by decide) (by decide) c5q + have c6p := chunk_inv_ntt_step_canon _ z 5#usize 13#usize (by decide) (by decide) c5p + have s6 := chunk_inv_ntt_step_scaled c _ _ z 5#usize 13#usize (by decide) (by decide) c5q c5p s5 + have c7q := chunk_inv_ntt_step_canon _ z 6#usize 14#usize (by decide) (by decide) c6q + have c7p := chunk_inv_ntt_step_canon _ z 6#usize 14#usize (by decide) (by decide) c6p + have s7 := chunk_inv_ntt_step_scaled c _ _ z 6#usize 14#usize (by decide) (by decide) c6q c6p s6 + exact chunk_inv_ntt_step_scaled c _ _ z 7#usize 15#usize (by decide) (by decide) c7q c7p s7 + +/-! ### Lift chunk preservation to the 256-array layers 1/2/3. -/ + +/-- Lane access for a 16-chunk flatten shape: lane `j` of + `flatten_chunks (make 16 ((range 16).map H))` is `(H (j/16)).val[j%16]!`. -/ +private theorem flatten_chunk_map_lane + (H : Nat → Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (j : Nat) (hj : j < 256) + (h : ((List.range 16).map H).length = (16#usize).val) : + (Spec.flatten_chunks (Std.Array.make 16#usize ((List.range 16).map H) h)).val[j]! + = (H (j / 16)).val[j % 16]! := by + have hk : j / 16 < 16 := by omega + unfold Spec.flatten_chunks + rw [mkN_map_lane' _ j hj] + rw [mkN_map_lane' H (j / 16) hk] + +/-- A `chunk_step`-mapped layer preserves `CanonArr`, given the chunk step + preserves `CanonChunk` (with zetas possibly depending on chunk index `k`). -/ +private theorem layer_canon_of_chunk_canon + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (G : Nat → Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize → + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (hp : CanonArr p) + (hG : ∀ (k : Nat) (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize), + CanonChunk a → CanonChunk (G k a)) + (h : ((List.range 16).map (fun k => G k (Spec.chunk_at p k))).length = (16#usize).val) : + CanonArr (Spec.flatten_chunks (Std.Array.make 16#usize + ((List.range 16).map (fun k => G k (Spec.chunk_at p k))) h)) := by + intro j hj + rw [flatten_chunk_map_lane (fun k => G k (Spec.chunk_at p k)) j hj h] + apply hG (j / 16) _ _ (j % 16) (Nat.mod_lt _ (by decide)) + intro ℓ hℓ + rw [chunk_at_lane' p (j / 16) ℓ hℓ] + apply hp + have hk : j / 16 < 16 := by omega + omega + +/-- `chunk_at` of a `CanonArr` is `CanonChunk`. -/ +private theorem canonChunk_chunk_at + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hp : CanonArr p) (k : Nat) (hk : k < 16) : CanonChunk (Spec.chunk_at p k) := by + intro ℓ hℓ + rw [chunk_at_lane' p k ℓ hℓ] + exact hp _ (by omega) + +/-- `chunk_at` is scale-compatible: `ScaledArr c q p → ScaledChunk c (chunk_at q k) (chunk_at p k)`. -/ +private theorem scaledChunk_chunk_at (c : ZMod 3329) + (q p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hs : ScaledArr c q p) (k : Nat) (hk : k < 16) : + ScaledChunk c (Spec.chunk_at q k) (Spec.chunk_at p k) := by + intro ℓ hℓ + rw [chunk_at_lane' q k ℓ hℓ, chunk_at_lane' p k ℓ hℓ] + exact hs _ (by omega) + +/-- A `chunk_step`-mapped layer preserves `ScaledArr`, given the chunk step + preserves `ScaledChunk` (using canonicity of both `chunk_at` sides). -/ +private theorem layer_scaled_of_chunk_scaled (c : ZMod 3329) + (q p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (G : Nat → Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize → + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (hq : CanonArr q) (hpc : CanonArr p) (hs : ScaledArr c q p) + (hG : ∀ (k : Nat) (qc pc : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize), + CanonChunk qc → CanonChunk pc → ScaledChunk c qc pc → ScaledChunk c (G k qc) (G k pc)) + (hlq : ((List.range 16).map (fun k => G k (Spec.chunk_at q k))).length = (16#usize).val) + (hlp : ((List.range 16).map (fun k => G k (Spec.chunk_at p k))).length = (16#usize).val) : + ScaledArr c + (Spec.flatten_chunks (Std.Array.make 16#usize + ((List.range 16).map (fun k => G k (Spec.chunk_at q k))) hlq)) + (Spec.flatten_chunks (Std.Array.make 16#usize + ((List.range 16).map (fun k => G k (Spec.chunk_at p k))) hlp)) := by + intro j hj + rw [flatten_chunk_map_lane (fun k => G k (Spec.chunk_at q k)) j hj hlq, + flatten_chunk_map_lane (fun k => G k (Spec.chunk_at p k)) j hj hlp] + have hk : j / 16 < 16 := by omega + exact hG (j / 16) _ _ + (canonChunk_chunk_at q hq (j / 16) hk) (canonChunk_chunk_at p hpc (j / 16) hk) + (scaledChunk_chunk_at c q p hs (j / 16) hk) (j % 16) (Nat.mod_lt _ (by decide)) + +/-! ### Array-level layers 1/2/3: canon + scaled preservation. -/ + +private theorem invert_ntt_layer_1_canon + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i : Std.Usize) (hp : CanonArr p) : + CanonArr (Spec.invert_ntt_layer_1_pure p zeta_i) := by + unfold Spec.invert_ntt_layer_1_pure + exact layer_canon_of_chunk_canon p + (fun k a => Spec.chunk_inv_ntt_layer_1_step_pure a + (Spec.zeta_at (zeta_i.val - 4 * k - 1)) (Spec.zeta_at (zeta_i.val - 4 * k - 2)) + (Spec.zeta_at (zeta_i.val - 4 * k - 3)) (Spec.zeta_at (zeta_i.val - 4 * k - 4))) + hp (fun k a ha => chunk_inv_ntt_layer_1_step_canon a _ _ _ _ ha) _ + +private theorem invert_ntt_layer_2_canon + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i : Std.Usize) (hp : CanonArr p) : + CanonArr (Spec.invert_ntt_layer_2_pure p zeta_i) := by + unfold Spec.invert_ntt_layer_2_pure + exact layer_canon_of_chunk_canon p + (fun k a => Spec.chunk_inv_ntt_layer_2_step_pure a + (Spec.zeta_at (zeta_i.val - 2 * k - 1)) (Spec.zeta_at (zeta_i.val - 2 * k - 2))) + hp (fun k a ha => chunk_inv_ntt_layer_2_step_canon a _ _ ha) _ + +private theorem invert_ntt_layer_3_canon + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i : Std.Usize) (hp : CanonArr p) : + CanonArr (Spec.invert_ntt_layer_3_pure p zeta_i) := by + unfold Spec.invert_ntt_layer_3_pure + exact layer_canon_of_chunk_canon p + (fun k a => Spec.chunk_inv_ntt_layer_3_step_pure a (Spec.zeta_at (zeta_i.val - k - 1))) + hp (fun k a ha => chunk_inv_ntt_layer_3_step_canon a _ ha) _ + +private theorem invert_ntt_layer_1_scaled (c : ZMod 3329) + (q p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i : Std.Usize) (hq : CanonArr q) (hpc : CanonArr p) (hs : ScaledArr c q p) : + ScaledArr c (Spec.invert_ntt_layer_1_pure q zeta_i) (Spec.invert_ntt_layer_1_pure p zeta_i) := by + unfold Spec.invert_ntt_layer_1_pure + exact layer_scaled_of_chunk_scaled c q p + (fun k a => Spec.chunk_inv_ntt_layer_1_step_pure a + (Spec.zeta_at (zeta_i.val - 4 * k - 1)) (Spec.zeta_at (zeta_i.val - 4 * k - 2)) + (Spec.zeta_at (zeta_i.val - 4 * k - 3)) (Spec.zeta_at (zeta_i.val - 4 * k - 4))) + hq hpc hs (fun k qc pc hqc hpc' hsc => + chunk_inv_ntt_layer_1_step_scaled c qc pc _ _ _ _ hqc hpc' hsc) _ _ + +private theorem invert_ntt_layer_2_scaled (c : ZMod 3329) + (q p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i : Std.Usize) (hq : CanonArr q) (hpc : CanonArr p) (hs : ScaledArr c q p) : + ScaledArr c (Spec.invert_ntt_layer_2_pure q zeta_i) (Spec.invert_ntt_layer_2_pure p zeta_i) := by + unfold Spec.invert_ntt_layer_2_pure + exact layer_scaled_of_chunk_scaled c q p + (fun k a => Spec.chunk_inv_ntt_layer_2_step_pure a + (Spec.zeta_at (zeta_i.val - 2 * k - 1)) (Spec.zeta_at (zeta_i.val - 2 * k - 2))) + hq hpc hs (fun k qc pc hqc hpc' hsc => + chunk_inv_ntt_layer_2_step_scaled c qc pc _ _ hqc hpc' hsc) _ _ + +private theorem invert_ntt_layer_3_scaled (c : ZMod 3329) + (q p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i : Std.Usize) (hq : CanonArr q) (hpc : CanonArr p) (hs : ScaledArr c q p) : + ScaledArr c (Spec.invert_ntt_layer_3_pure q zeta_i) (Spec.invert_ntt_layer_3_pure p zeta_i) := by + unfold Spec.invert_ntt_layer_3_pure + exact layer_scaled_of_chunk_scaled c q p + (fun k a => Spec.chunk_inv_ntt_layer_3_step_pure a (Spec.zeta_at (zeta_i.val - k - 1))) + hq hpc hs (fun k qc pc hqc hpc' hsc => + chunk_inv_ntt_layer_3_step_scaled c qc pc _ hqc hpc' hsc) _ _ + +/-! ### Cross-chunk butterflies (layers 4-7). -/ + +/-- Lane formula for `chunk_inv_pair_butterfly_a_pure`. -/ +private theorem chunk_inv_pair_butterfly_a_lane + (ca cb : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (ℓ : Nat) (hℓ : ℓ < 16) : + (Spec.chunk_inv_pair_butterfly_a_pure ca cb).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (ca.val[ℓ]!) (cb.val[ℓ]!) := by + unfold Spec.chunk_inv_pair_butterfly_a_pure + exact mkN_map_lane' _ ℓ hℓ _ + +/-- Lane formula for `chunk_inv_pair_butterfly_b_pure`. -/ +private theorem chunk_inv_pair_butterfly_b_lane + (ca cb : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z : hacspec_ml_kem.parameters.FieldElement) (ℓ : Nat) (hℓ : ℓ < 16) : + (Spec.chunk_inv_pair_butterfly_b_pure ca cb z).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (cb.val[ℓ]!) (ca.val[ℓ]!)) z := by + unfold Spec.chunk_inv_pair_butterfly_b_pure + exact mkN_map_lane' _ ℓ hℓ _ + +private theorem chunk_inv_pair_butterfly_a_canon + (ca cb : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) : + CanonChunk (Spec.chunk_inv_pair_butterfly_a_pure ca cb) := by + intro ℓ hℓ + rw [chunk_inv_pair_butterfly_a_lane ca cb ℓ hℓ] + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure _ _ + +private theorem chunk_inv_pair_butterfly_b_canon + (ca cb : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z : hacspec_ml_kem.parameters.FieldElement) : + CanonChunk (Spec.chunk_inv_pair_butterfly_b_pure ca cb z) := by + intro ℓ hℓ + rw [chunk_inv_pair_butterfly_b_lane ca cb z ℓ hℓ] + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure _ _ + +private theorem chunk_inv_pair_butterfly_a_scaled (c : ZMod 3329) + (qa qb pa pb : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (hsa : ScaledChunk c qa pa) (hsb : ScaledChunk c qb pb) : + ScaledChunk c (Spec.chunk_inv_pair_butterfly_a_pure qa qb) + (Spec.chunk_inv_pair_butterfly_a_pure pa pb) := by + intro ℓ hℓ + rw [chunk_inv_pair_butterfly_a_lane qa qb ℓ hℓ, + chunk_inv_pair_butterfly_a_lane pa pb ℓ hℓ, + zmodOfFE_add_pure, zmodOfFE_add_pure, hsa ℓ hℓ, hsb ℓ hℓ] + ring + +private theorem chunk_inv_pair_butterfly_b_scaled (c : ZMod 3329) + (qa qb pa pb : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z : hacspec_ml_kem.parameters.FieldElement) + (hqa : CanonChunk qa) (hqb : CanonChunk qb) (hpa : CanonChunk pa) (hpb : CanonChunk pb) + (hsa : ScaledChunk c qa pa) (hsb : ScaledChunk c qb pb) : + ScaledChunk c (Spec.chunk_inv_pair_butterfly_b_pure qa qb z) + (Spec.chunk_inv_pair_butterfly_b_pure pa pb z) := by + intro ℓ hℓ + rw [chunk_inv_pair_butterfly_b_lane qa qb z ℓ hℓ, + chunk_inv_pair_butterfly_b_lane pa pb z ℓ hℓ, + zmodOfFE_mul_pure, zmodOfFE_mul_pure, + zmodOfFE_sub_pure _ _ (hqb ℓ hℓ) (hqa ℓ hℓ), + zmodOfFE_sub_pure _ _ (hpb ℓ hℓ) (hpa ℓ hℓ), + hsa ℓ hℓ, hsb ℓ hℓ] + ring + +/-! ### Array-level layer 4+ preservation. -/ + +/-- `chunks0` lane access: `(make 16 (map (chunk_at p))).val[k]! = chunk_at p k`. -/ +private theorem chunks0_lane + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (k : Nat) (hk : k < 16) + (h : ((List.range 16).map (Spec.chunk_at p)).length = (16#usize).val) : + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at p)) h).val[k]! + = Spec.chunk_at p k := + mkN_map_lane' (Spec.chunk_at p) k hk h + +/-- `chunk_inv_at_layer_4_plus_pure` preserves `CanonChunk` per output chunk. -/ +private theorem chunk_inv_at_layer_4_plus_canon + (chunks : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize) + (layer : Std.Usize) (zeta_fn : Nat → hacspec_ml_kem.parameters.FieldElement) (cc : Nat) : + CanonChunk (Spec.chunk_inv_at_layer_4_plus_pure chunks layer zeta_fn cc) := by + unfold Spec.chunk_inv_at_layer_4_plus_pure + by_cases h : cc % (2 * ((1 <<< layer.val) / 16)) < (1 <<< layer.val) / 16 + · simp only [h, if_true] + exact chunk_inv_pair_butterfly_a_canon _ _ + · simp only [h, if_false] + exact chunk_inv_pair_butterfly_b_canon _ _ _ + +/-- Body equation for one output chunk of layer 4+ (avoids any nested chunk + `[!]` in a statement type by phrasing over the `chunks0 = make 16 (map chunk_at)` + array built from a 256-array `p`). Reduces `chunks0.val[k]!` to `chunk_at p k`. + Requires `cc < 16` and, in the a-branch, `cc + step < 16` (always true by + construction; passed in as `hub`). -/ +private theorem chunk_inv_at_layer_4_plus_chunks0_eq + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (layer : Std.Usize) (zeta_fn : Nat → hacspec_ml_kem.parameters.FieldElement) (cc : Nat) + (hcc : cc < 16) + (hub : cc % (2 * ((1 <<< layer.val) / 16)) < (1 <<< layer.val) / 16 → + cc + (1 <<< layer.val) / 16 < 16) + (h : ((List.range 16).map (Spec.chunk_at p)).length = (16#usize).val) : + Spec.chunk_inv_at_layer_4_plus_pure + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at p)) h) layer zeta_fn cc + = if cc % (2 * ((1 <<< layer.val) / 16)) < (1 <<< layer.val) / 16 then + Spec.chunk_inv_pair_butterfly_a_pure (Spec.chunk_at p cc) + (Spec.chunk_at p (cc + (1 <<< layer.val) / 16)) + else + Spec.chunk_inv_pair_butterfly_b_pure (Spec.chunk_at p (cc - (1 <<< layer.val) / 16)) + (Spec.chunk_at p cc) (zeta_fn (cc / (2 * ((1 <<< layer.val) / 16)))) := by + unfold Spec.chunk_inv_at_layer_4_plus_pure + by_cases hb : cc % (2 * ((1 <<< layer.val) / 16)) < (1 <<< layer.val) / 16 + · rw [if_pos hb, if_pos hb] + have hub' : cc + (1 <<< layer.val) / 16 < 16 := hub hb + rw [chunks0_lane p cc hcc h, chunks0_lane p (cc + (1 <<< layer.val) / 16) hub' h] + · rw [if_neg hb, if_neg hb] + have hsub : cc - (1 <<< layer.val) / 16 < 16 := Nat.lt_of_le_of_lt (Nat.sub_le _ _) hcc + rw [chunks0_lane p cc hcc h, chunks0_lane p (cc - (1 <<< layer.val) / 16) hsub h] + +/-- For `cc < 16`, `2*step ∣ 16`, and `cc % (2*step) < step`, we have + `cc + step < 16` (the partner chunk stays in range). -/ +private theorem layer4_partner_lt + (cc step : Nat) (hcc : cc < 16) (_hstep : 0 < step) (hdvd : (2 * step) ∣ 16) + (hoff : cc % (2 * step) < step) : cc + step < 16 := by + obtain ⟨t, ht⟩ := hdvd + set Q := cc / (2 * step) with hQ + set r := cc % (2 * step) with hr + have ht16 : (2 * step) * t = 16 := ht.symm + have hblock : Q < t := by + apply Nat.div_lt_of_lt_mul; rw [ht16]; exact hcc + have hdm : cc = (2 * step) * Q + r := (Nat.div_add_mod cc (2 * step)).symm + calc cc + step < (2 * step) * Q + 2 * step := by omega + _ = (2 * step) * (Q + 1) := by ring + _ ≤ (2 * step) * t := by apply Nat.mul_le_mul_left; omega + _ = 16 := ht16 + +/-- `invert_ntt_layer_4_plus_pure` preserves `CanonArr`. -/ +private theorem invert_ntt_layer_4_plus_canon + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i layer : Std.Usize) (_hp : CanonArr p) : + CanonArr (Spec.invert_ntt_layer_4_plus_pure p zeta_i layer) := by + intro j hj + unfold Spec.invert_ntt_layer_4_plus_pure + rw [flatten_chunk_map_lane (fun cc => Spec.chunk_inv_at_layer_4_plus_pure + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at p)) (by simp)) + layer (fun group => Spec.zeta_at (zeta_i.val - 1 - group)) cc) j hj (by simp)] + apply chunk_inv_at_layer_4_plus_canon _ _ _ _ (j % 16) (Nat.mod_lt _ (by decide)) + +/-- Scaled-preservation of one output chunk of layer 4+, phrased over 256-arrays + `q p` and chunk index `cc` (uses `chunk_inv_at_layer_4_plus_chunks0_eq`). -/ +private theorem chunk_inv_at_layer_4_plus_scaled_chunks0 (c : ZMod 3329) + (q p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (layer : Std.Usize) (zfn : Nat → hacspec_ml_kem.parameters.FieldElement) (cc : Nat) + (hcc : cc < 16) + (hub : cc % (2 * ((1 <<< layer.val) / 16)) < (1 <<< layer.val) / 16 → + cc + (1 <<< layer.val) / 16 < 16) + (hq : CanonArr q) (hpc : CanonArr p) (hs : ScaledArr c q p) + (hqlen : ((List.range 16).map (Spec.chunk_at q)).length = (16#usize).val) + (hplen : ((List.range 16).map (Spec.chunk_at p)).length = (16#usize).val) : + ScaledChunk c + (Spec.chunk_inv_at_layer_4_plus_pure + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at q)) hqlen) layer zfn cc) + (Spec.chunk_inv_at_layer_4_plus_pure + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at p)) hplen) layer zfn cc) := by + rw [chunk_inv_at_layer_4_plus_chunks0_eq q layer zfn cc hcc hub hqlen, + chunk_inv_at_layer_4_plus_chunks0_eq p layer zfn cc hcc hub hplen] + by_cases hbr : cc % (2 * ((1 <<< layer.val) / 16)) < (1 <<< layer.val) / 16 + · rw [if_pos hbr, if_pos hbr] + have hub' := hub hbr + exact chunk_inv_pair_butterfly_a_scaled c _ _ _ _ + (scaledChunk_chunk_at c q p hs cc hcc) + (scaledChunk_chunk_at c q p hs _ hub') + · rw [if_neg hbr, if_neg hbr] + have hsub : cc - (1 <<< layer.val) / 16 < 16 := Nat.lt_of_le_of_lt (Nat.sub_le _ _) hcc + exact chunk_inv_pair_butterfly_b_scaled c _ _ _ _ _ + (canonChunk_chunk_at q hq _ hsub) (canonChunk_chunk_at q hq cc hcc) + (canonChunk_chunk_at p hpc _ hsub) (canonChunk_chunk_at p hpc cc hcc) + (scaledChunk_chunk_at c q p hs _ hsub) + (scaledChunk_chunk_at c q p hs cc hcc) + +/-- `invert_ntt_layer_4_plus_pure` preserves `ScaledArr`, given the step + `(1<< Spec.chunk_inv_at_layer_4_plus_pure + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at q)) (by simp)) layer + (fun group => Spec.zeta_at (zeta_i.val - 1 - group)) cc) j hj (by simp), + flatten_chunk_map_lane (fun cc => Spec.chunk_inv_at_layer_4_plus_pure + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at p)) (by simp)) layer + (fun group => Spec.zeta_at (zeta_i.val - 1 - group)) cc) j hj (by simp)] + have hcc : j / 16 < 16 := by omega + have hub : (j / 16) % (2 * ((1 <<< layer.val) / 16)) < (1 <<< layer.val) / 16 → + (j / 16) + (1 <<< layer.val) / 16 < 16 := + fun hoff => layer4_partner_lt (j / 16) _ hcc hstep hdvd hoff + exact chunk_inv_at_layer_4_plus_scaled_chunks0 c q p layer _ (j / 16) hcc hub hq hpc hs _ _ + (j % 16) (Nat.mod_lt _ (by decide)) + +/-! ### Final assembly. -/ + +/-- `feOfZMod z` is canonical. -/ +private theorem canon_feOfZMod (z : ZMod 3329) : Canonical (feOfZMod z) := by + unfold Canonical feOfZMod hacspec_ml_kem.parameters.FIELD_MODULUS + show (BitVec.ofNat 16 z.val).toNat < _ + rw [BitVec.toNat_ofNat] + have hz : z.val < 3329 := ZMod.val_lt z + have : z.val % 2 ^ 16 = z.val := Nat.mod_eq_of_lt (by omega) + simp only [this]; simpa using hz + +/-- `scaleZ c p` is canonical (every lane is `feOfZMod _`). -/ +private theorem canonArr_scaleZ (c : ZMod 3329) + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + CanonArr (scaleZ c p) := by + intro j hj + unfold scaleZ + rw [mkN_map_lane' (fun k => feOfZMod (c * zmodOfFE (p.val[k]!))) j hj _] + exact canon_feOfZMod _ + +/-- `scaleZ c p` is the per-lane `c`-scale of `p`. -/ +private theorem scaledArr_scaleZ (c : ZMod 3329) + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + ScaledArr c (scaleZ c p) p := fun j hj => scaleZ_lane c p j hj + +/-- Two canonical 256-arrays that are both the per-lane `c`-scale of the same `p` + are equal. -/ +private theorem eq_of_scaledArr_canon (c : ZMod 3329) + (a b p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hca : CanonArr a) (hcb : CanonArr b) + (hsa : ScaledArr c a p) (hsb : ScaledArr c b p) : a = b := by + apply Subtype.ext + apply List.ext_getElem + · rw [Aeneas.Std.Array.length_eq a, Aeneas.Std.Array.length_eq b] + · intro j hj1 _hj2 + have hj : j < 256 := by rw [Aeneas.Std.Array.length_eq a] at hj1; simpa using hj1 + have hca' := hca j hj + have hcb' := hcb j hj + have hzeq : zmodOfFE (a.val[j]!) = zmodOfFE (b.val[j]!) := by + rw [hsa j hj, hsb j hj] + have ha := feOfZMod_zmodOfFE_of_canon (a.val[j]!) hca' + have hb := feOfZMod_zmodOfFE_of_canon (b.val[j]!) hcb' + have : a.val[j]! = b.val[j]! := by rw [← ha, ← hb, hzeq] + have haj : a.val[j]! = a.val[j] := getElem!_pos a.val j (by rw [Aeneas.Std.Array.length_eq a]; exact hj) + have hbj : b.val[j]! = b.val[j] := getElem!_pos b.val j (by rw [Aeneas.Std.Array.length_eq b]; exact hj) + rw [← haj, ← hbj]; exact this + +theorem invert_ntt_montgomery_pure_scaleZ (c : ZMod 3329) + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hp : ∀ j : Nat, j < 256 → + libcrux_iot_ml_kem.Spec.Pure.Canonical (p.val[j]!)) : + Spec.invert_ntt_montgomery_pure (scaleZ c p) + = scaleZ c (Spec.invert_ntt_montgomery_pure p) := by + -- divisibility / positivity facts for layers 4-7 (step = 2^layer/16) + have hl4 : (0 < (1 <<< (4#usize).val) / 16) ∧ ((2 * ((1 <<< (4#usize).val) / 16)) ∣ 16) := by + constructor <;> decide + have hl5 : (0 < (1 <<< (5#usize).val) / 16) ∧ ((2 * ((1 <<< (5#usize).val) / 16)) ∣ 16) := by + constructor <;> decide + have hl6 : (0 < (1 <<< (6#usize).val) / 16) ∧ ((2 * ((1 <<< (6#usize).val) / 16)) ∣ 16) := by + constructor <;> decide + have hl7 : (0 < (1 <<< (7#usize).val) / 16) ∧ ((2 * ((1 <<< (7#usize).val) / 16)) ∣ 16) := by + constructor <;> decide + -- canonicity of the scaleZ side input + have hpq : CanonArr (scaleZ c p) := canonArr_scaleZ c p + have hpp : CanonArr p := hp + have hs0 : ScaledArr c (scaleZ c p) p := scaledArr_scaleZ c p + -- Unfold the 7-layer composition on both sides simultaneously. + unfold Spec.invert_ntt_montgomery_pure + -- q-side intermediates (scaled input), p-side intermediates (plain input). + -- Thread canon + scaled through layers 1,2,3,4,5,6,7. + have c1q := invert_ntt_layer_1_canon (scaleZ c p) 128#usize hpq + have c1p := invert_ntt_layer_1_canon p 128#usize hpp + have s1 := invert_ntt_layer_1_scaled c (scaleZ c p) p 128#usize hpq hpp hs0 + have c2q := invert_ntt_layer_2_canon _ 64#usize c1q + have c2p := invert_ntt_layer_2_canon _ 64#usize c1p + have s2 := invert_ntt_layer_2_scaled c _ _ 64#usize c1q c1p s1 + have c3q := invert_ntt_layer_3_canon _ 32#usize c2q + have c3p := invert_ntt_layer_3_canon _ 32#usize c2p + have s3 := invert_ntt_layer_3_scaled c _ _ 32#usize c2q c2p s2 + have c4q := invert_ntt_layer_4_plus_canon _ 16#usize 4#usize c3q + have c4p := invert_ntt_layer_4_plus_canon _ 16#usize 4#usize c3p + have s4 := invert_ntt_layer_4_plus_scaled c _ _ 16#usize 4#usize hl4.1 hl4.2 c3q c3p s3 + have c5q := invert_ntt_layer_4_plus_canon _ 8#usize 5#usize c4q + have c5p := invert_ntt_layer_4_plus_canon _ 8#usize 5#usize c4p + have s5 := invert_ntt_layer_4_plus_scaled c _ _ 8#usize 5#usize hl5.1 hl5.2 c4q c4p s4 + have c6q := invert_ntt_layer_4_plus_canon _ 4#usize 6#usize c5q + have c6p := invert_ntt_layer_4_plus_canon _ 4#usize 6#usize c5p + have s6 := invert_ntt_layer_4_plus_scaled c _ _ 4#usize 6#usize hl6.1 hl6.2 c5q c5p s5 + have c7q := invert_ntt_layer_4_plus_canon _ 2#usize 7#usize c6q + have c7p := invert_ntt_layer_4_plus_canon _ 2#usize 7#usize c6p + have s7 := invert_ntt_layer_4_plus_scaled c _ _ 2#usize 7#usize hl7.1 hl7.2 c6q c6p s6 + -- The RHS `scaleZ c (invert p)` is also `c`-scaled vs `invert p`, and canonical. + exact eq_of_scaledArr_canon c _ _ _ c7q (canonArr_scaleZ c _) s7 (scaledArr_scaleZ c _) + +end InvertScaleZ + +/-! ## C — hacspec `ntt_inverse` ≡ `scaleZ 3303 ∘ invert_ntt_montgomery_pure`. + + `3303 = INVERSE_OF_128 = 512·169`. Decomposed into: + + * **C1** — `ntt_inverse p = scaleZ 3303 (ntt_inverse_butterflies p)`: the + `reduce_polynomial` createi wrapper multiplying every lane by + `INVERSE_OF_128 = 3303`. + * **C2** — `ntt_inverse_butterflies p = .ok (invert_ntt_montgomery_pure p)`: + the hacspec 7-layer plain-field Gentleman–Sande inverse equals the + Mont-domain pure 7-layer inverse (the Montgomery `R`-factors net to 1). + + Canonical precondition (same mechanism as `invert_ntt_montgomery_pure_scaleZ`: + `sub_pure` saturates on non-canonical lanes). Discharged at the L7.4 call + site: C is applied at `p = multiply_vectors …`, whose lanes are + `add_pure`/`mul_pure` results (canonical). -/ +section InvertReduceC1 + +open libcrux_iot_ml_kem.Spec.Pure (Canonical) +/-- The `INVERSE_OF_128 = FieldElement.new 3303` FE has `zmodOfFE = 3303`. -/ +private theorem zmodOfFE_inverse_of_128 : + zmodOfFE (⟨3303#u16⟩ : hacspec_ml_kem.parameters.FieldElement) = (3303 : ZMod 3329) := by + unfold zmodOfFE + show ((3303#u16 : Std.U16).val : ZMod 3329) = 3303 + norm_num + +set_option maxHeartbeats 1000000 in +/-- `reduce_polynomial a` reduces to the per-lane `mul_pure a[k] 3303` array. -/ +private theorem reduce_polynomial_eq_ok + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + hacspec_ml_kem.invert_ntt.reduce_polynomial a + = .ok ⟨(List.range 256).map (fun k => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (a.val[k]!) ⟨3303#u16⟩), + by simp [List.length_map, List.length_range]⟩ := by + set f : Nat → hacspec_ml_kem.parameters.FieldElement := + fun k => libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (a.val[k]!) ⟨3303#u16⟩ with hf_def + have hpure : ∀ k : Nat, k < (256#usize : Std.Usize).val → + (hacspec_ml_kem.invert_ntt.reduce_polynomial.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement + : CoreModels.core.ops.function.Fn _ _ _).FnMutInst.call_mut + a ⟨BitVec.ofNat _ k⟩ + = .ok (f k, a) := by + intro k hk + have hk' : k < 256 := hk + show hacspec_ml_kem.invert_ntt.reduce_polynomial.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + a ⟨BitVec.ofNat _ k⟩ = .ok (f k, a) + unfold hacspec_ml_kem.invert_ntt.reduce_polynomial.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + unfold hacspec_ml_kem.invert_ntt.reduce_polynomial.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement.call + have hk_us : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := by + show (BitVec.ofNat _ k).toNat = k + apply Nat.mod_eq_of_lt + have : k < 2^System.Platform.numBits := by + have hbits : 2^16 ≤ 2^System.Platform.numBits := + Nat.pow_le_pow_right (by decide) (by + cases System.Platform.numBits_eq with + | inl h => rw [h]; decide + | inr h => rw [h]; decide) + omega + exact this + have ha_len : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val < a.length := by + rw [hk_us]; show k < a.val.length + rw [a.property]; exact hk + have h_a_idx : + Std.Array.index_usize a (⟨BitVec.ofNat _ k⟩ : Std.Usize) + = .ok (a.val[k]!) := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a + (⟨BitVec.ofNat _ k⟩ : Std.Usize) ha_len + rw [hk_us] at this; exact this + have h_inv : hacspec_ml_kem.invert_ntt.INVERSE_OF_128 + = .ok (⟨3303#u16⟩ : hacspec_ml_kem.parameters.FieldElement) := by + unfold hacspec_ml_kem.invert_ntt.INVERSE_OF_128 + hacspec_ml_kem.parameters.FieldElement.new + rfl + have h_mul := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_eq_ok + (a.val[k]!) (⟨3303#u16⟩ : hacspec_ml_kem.parameters.FieldElement) + change (do + let fe ← (do + let fe ← Std.Array.index_usize a ⟨BitVec.ofNat _ k⟩ + let fe1 ← hacspec_ml_kem.invert_ntt.INVERSE_OF_128 + hacspec_ml_kem.parameters.FieldElement.mul fe fe1) + Result.ok (fe, a)) = Result.ok (f k, a) + rw [h_a_idx]; simp only [bind_tc_ok] + rw [h_inv]; simp only [bind_tc_ok] + rw [h_mul]; simp only [bind_tc_ok, hf_def] + unfold hacspec_ml_kem.invert_ntt.reduce_polynomial + exact libcrux_iot_ml_kem.Util.CreateI.createi_pure_eq 256#usize + hacspec_ml_kem.invert_ntt.reduce_polynomial.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement + a f hpure + +/-- **C1.** Given the butterflies output `q`, `ntt_inverse p = scaleZ 3303 q` + (the `reduce_polynomial` wrapper multiplies every lane by + `INVERSE_OF_128 = 3303`). Requires `q` canonical (so `scaleZ 3303 q` matches + the `mul_pure`-reduced lanes; the reduced lanes are canonical by `mul_pure`). -/ +private theorem ntt_inverse_reduce_eq + (p q : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (_hq : ∀ j : Nat, j < 256 → Canonical (q.val[j]!)) + (hbut : hacspec_ml_kem.invert_ntt.ntt_inverse_butterflies p = .ok q) : + hacspec_ml_kem.invert_ntt.ntt_inverse p = .ok (scaleZ 3303 q) := by + unfold hacspec_ml_kem.invert_ntt.ntt_inverse + rw [hbut]; simp only [bind_tc_ok] + rw [reduce_polynomial_eq_ok q] + congr 1 + set L : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + ⟨(List.range 256).map (fun k => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (q.val[k]!) ⟨3303#u16⟩), + by simp [List.length_map, List.length_range]⟩ with hL_def + have hL_lane : ∀ j : Nat, j < 256 → + L.val[j]! = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (q.val[j]!) ⟨3303#u16⟩ := by + intro j hj + show ((List.range 256).map (fun k => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (q.val[k]!) ⟨3303#u16⟩))[j]! = _ + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj])] + rw [List.getElem_map, List.getElem_range] + -- Both `L` and `scaleZ 3303 q` are the per-lane `3303`-scale of `q`; finish via + -- `eq_of_scaledArr_canon` (both canonical, both `ScaledArr 3303 · q`). + refine eq_of_scaledArr_canon 3303 L (scaleZ 3303 q) q ?_ (canonArr_scaleZ 3303 q) ?_ + (scaledArr_scaleZ 3303 q) + · -- L lanes canonical + intro j hj + rw [hL_lane j hj] + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure _ _ + · -- L is `ScaledArr 3303 L q` + intro j hj + rw [hL_lane j hj, zmodOfFE_mul_pure, zmodOfFE_inverse_of_128] + ring + +end InvertReduceC1 + +/-! ## C2 — hacspec `ntt_inverse_butterflies` ≡ `Spec.invert_ntt_montgomery_pure` + (factor 1). + + Closed by `ntt_inverse_butterflies_eq_invert_pure`: per-layer flat↔chunk + match (`ntt_inverse_layer` createi reduction + flat-lane + `i ↔ chunk (i/16, i%16)` correspondence + the cross-chunk layers 4-7), then + composing 7 layers; lemma C (`ntt_inverse_eq_scaleZ_invert_pure`) is C1 ∘ C2. + + `zetas_bridge_zmod` below is the gating arithmetic shared by every layer: + the hacspec plain-domain zeta table `ntt.ZETAS[i]` matches the Mont-domain + `Spec.zeta_at i` in `ZMod 3329` (`ZETAS[i] ≡ ZETAS_TIMES_MONTGOMERY_R[i]·R⁻¹`, + `R⁻¹ = 169`); proven by `interval_cases … <;> rfl` after unfolding both + tables. -/ +section InvertButterfliesC2 + +open libcrux_iot_ml_kem.Spec.Pure (Canonical) +/-- The hacspec `ntt.ZETAS` table, unwrapped from `Result`. `ntt.ZETAS` is a + pure `do`-chain of `FieldElement.new` (all `.ok`), so this is `.ok`-total; + the fallback branch is unreachable. -/ +private noncomputable def zetasArr : + Std.Array hacspec_ml_kem.parameters.FieldElement 128#usize := + match hacspec_ml_kem.ntt.ZETAS with + | .ok a => a + | _ => Std.Array.make 128#usize (List.replicate 128 ⟨0#u16⟩) (by simp) + +set_option maxRecDepth 20000 in +private theorem ntt_zetas_eq_ok : hacspec_ml_kem.ntt.ZETAS = .ok zetasArr := by + unfold zetasArr + unfold hacspec_ml_kem.ntt.ZETAS + rfl + +set_option maxRecDepth 20000 in +/-- **C2 gating arithmetic.** Per-entry equality (in `ZMod 3329`) of the hacspec + plain-domain zeta table and the Mont-domain `Spec.zeta_at`. Proven over all + 128 entries by case split + `rfl` (each entry: `ZETAS[i] = ZTMR[i]·169 mod q`). -/ +private theorem zetas_bridge_zmod (i : Nat) (hi : i < 128) : + zmodOfFE (zetasArr.val[i]!) = zmodOfFE (Spec.zeta_at i) := by + unfold zetasArr Spec.zeta_at lift_fe_mont + rw [zmodOfFE_feOfZMod] + unfold zmodOfFE i16_to_spec_fe_mont + unfold hacspec_ml_kem.ntt.ZETAS + unfold hacspec_ml_kem.parameters.FieldElement.new + simp only [bind_tc_ok] + unfold libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R + interval_cases i <;> rfl + +/-! ### C2 : flat per-lane reduction of `ntt_inverse_layer_n`. + + The monadic usize arithmetic in the hacspec body is discharged via + `*_ok'` helpers (Aeneas `*_spec` Triples → `.ok`-exists), and the + `inv_butterfly` is reduced to its `add_pure`/`mul_pure`-of-`sub_pure` + pure projection (`inv_butterfly_eq`). The per-lane closure value + (`layer_n_at_eq`) feeds `from_fn_pure_eq` to give the full layer's + explicit per-lane array (`ntt_inverse_layer_n_eq_ok`). -/ + +private theorem umul_ok' (a b : Std.Usize) (h : a.val * b.val ≤ Std.Usize.max) : + ∃ c : Std.Usize, (a * b : Result Std.Usize) = .ok c ∧ c.val = a.val * b.val := by + obtain ⟨v, hv, hpv⟩ := Aeneas.Std.WP.spec_imp_exists (Std.Usize.mul_spec h) + exact ⟨v, hv, by simpa using hpv⟩ + +private theorem uadd_ok' (a b : Std.Usize) (h : a.val + b.val ≤ Std.Usize.max) : + ∃ c : Std.Usize, (a + b : Result Std.Usize) = .ok c ∧ c.val = a.val + b.val := by + obtain ⟨v, hv, hpv⟩ := Aeneas.Std.WP.spec_imp_exists (Std.Usize.add_spec h) + exact ⟨v, hv, by simpa using hpv⟩ + +private theorem usub_ok' (a b : Std.Usize) (h : b.val ≤ a.val) : + ∃ c : Std.Usize, (a - b : Result Std.Usize) = .ok c ∧ c.val = a.val - b.val := by + obtain ⟨v, hv, hpv⟩ := Aeneas.Std.WP.spec_imp_exists (Std.Usize.sub_spec h) + refine ⟨v, hv, ?_⟩ + have := hpv.1; simpa using this + +private theorem udiv_ok' (a b : Std.Usize) (h : b.val ≠ 0) : + ∃ c : Std.Usize, (a / b : Result Std.Usize) = .ok c ∧ c.val = a.val / b.val := by + obtain ⟨v, hv, hpv⟩ := Aeneas.Std.WP.spec_imp_exists (Std.Usize.div_spec a h) + exact ⟨v, hv, by simpa using hpv⟩ + +private theorem umod_ok' (a b : Std.Usize) (h : b.val ≠ 0) : + ∃ c : Std.Usize, (a % b : Result Std.Usize) = .ok c ∧ c.val = a.val % b.val := by + obtain ⟨v, hv, hpv⟩ := Aeneas.Std.WP.spec_imp_exists (Std.Usize.rem_spec a h) + exact ⟨v, hv, by simpa using hpv⟩ + +/-- `inv_butterfly z a b` pure projection: `(add a b, mul z (sub b a))` + (requires `a`, `b` canonical for the `sub`). -/ +private theorem inv_butterfly_eq (z a b : hacspec_ml_kem.parameters.FieldElement) + (ha : Canonical a) (hb : Canonical b) : + hacspec_ml_kem.invert_ntt.inv_butterfly z a b + = .ok (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure a b, + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure z + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure b a)) := by + unfold hacspec_ml_kem.invert_ntt.inv_butterfly + rw [libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_eq_ok a b] + simp only [bind_tc_ok] + rw [libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_eq_ok b a hb ha] + simp only [bind_tc_ok] + rw [libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_eq_ok z _] + simp only [bind_tc_ok] + +set_option maxRecDepth 8000 in +/-- Per-lane value of `ntt_inverse_layer_n_at p len s i`: the flat + Gentleman–Sande butterfly. `group = i/(2·len)`, `idx = i%(2·len)`; + a-side (`idx < len`) = `add p[i] p[i+len]`, b-side = + `mul s[group] (sub p[i] p[i−len])`. -/ +private theorem layer_n_at_eq + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (len : Std.Usize) (s : Slice hacspec_ml_kem.parameters.FieldElement) + (i : Std.Usize) + (hlen : 0 < len.val) (h2len : 2 * len.val ≤ Std.Usize.max) + (hi : i.val < 256) (hil : i.val + len.val ≤ Std.Usize.max) + (hcanon : ∀ j : Nat, j < 256 → Canonical (p.val[j]!)) + (hsg : i.val / (2 * len.val) < s.val.length) + (hapart : i.val % (2*len.val) < len.val → i.val + len.val < 256) + (hbpart : ¬ (i.val % (2*len.val) < len.val) → len.val ≤ i.val) : + hacspec_ml_kem.invert_ntt.ntt_inverse_layer_n_at p len s i + = .ok (if i.val % (2*len.val) < len.val then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (p.val[i.val]!) (p.val[i.val + len.val]!) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (s.val[i.val / (2*len.val)]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (p.val[i.val]!) (p.val[i.val - len.val]!))) := by + unfold hacspec_ml_kem.invert_ntt.ntt_inverse_layer_n_at + obtain ⟨i1, hi1, hi1v⟩ := umul_ok' 2#usize len (by simpa using h2len) + rw [hi1]; simp only [bind_tc_ok] + have hi1ne : i1.val ≠ 0 := by rw [hi1v]; simp; omega + obtain ⟨grp, hgrp, hgrpv⟩ := udiv_ok' i i1 hi1ne + rw [hgrp]; simp only [bind_tc_ok] + obtain ⟨idx, hidx, hidxv⟩ := umod_ok' i i1 hi1ne + rw [hidx]; simp only [bind_tc_ok] + have hidxlen : (idx.val < len.val) = (i.val % (2*len.val) < len.val) := by + rw [hidxv, hi1v]; simp + have hdec : (idx < len) = (idx.val < len.val) := by + simp [Std.UScalar.lt_equiv] + by_cases hbr : idx.val < len.val + · rw [if_pos (by rw [hdec]; exact hbr : idx < len)] + have hbr' : i.val % (2*len.val) < len.val := by rw [← hidxlen]; exact hbr + rw [if_pos hbr'] + rw [libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq s grp (by rw [hgrpv, hi1v]; exact hsg)] + simp only [bind_tc_ok] + rw [libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq p i + (by show i.val < p.val.length; rw [p.property]; exact hi)] + simp only [bind_tc_ok] + obtain ⟨i2, hi2, hi2v⟩ := uadd_ok' i len (by simpa using hil) + rw [hi2]; simp only [bind_tc_ok] + have hi2lt : i2.val < 256 := by rw [hi2v]; exact hapart hbr' + rw [libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq p i2 + (by show i2.val < p.val.length; rw [p.property]; exact hi2lt)] + simp only [bind_tc_ok] + rw [inv_butterfly_eq _ _ _ (hcanon i.val hi) (hcanon i2.val hi2lt)] + simp only [bind_tc_ok, hi2v] + rfl + · rw [if_neg (by rw [hdec]; exact hbr : ¬ (idx < len))] + have hbr' : ¬ (i.val % (2*len.val) < len.val) := by rw [← hidxlen]; exact hbr + rw [if_neg hbr'] + rw [libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq s grp (by rw [hgrpv, hi1v]; exact hsg)] + simp only [bind_tc_ok] + obtain ⟨i2, hi2, hi2v⟩ := usub_ok' i len (hbpart hbr') + rw [hi2]; simp only [bind_tc_ok] + have hi2lt : i2.val < 256 := by rw [hi2v]; omega + rw [libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq p i2 + (by show i2.val < p.val.length; rw [p.property]; exact hi2lt)] + simp only [bind_tc_ok] + rw [libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq p i + (by show i.val < p.val.length; rw [p.property]; exact hi)] + simp only [bind_tc_ok] + rw [inv_butterfly_eq _ _ _ (hcanon i2.val hi2lt) (hcanon i.val hi)] + simp only [bind_tc_ok, hgrpv, hi1v, hi2v] + norm_num + +set_option maxHeartbeats 1000000 in +/-- Full per-layer reduction of `ntt_inverse_layer_n p len s` to an explicit + per-lane array (createi → `createi_pure_eq` over `layer_n_at_eq`). The + `hslen`/`hpart` hypotheses package the slice-bound + partner-in-range facts + for all 256 lanes (discharged per concrete layer downstream). -/ +private theorem ntt_inverse_layer_n_eq_ok + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (len : Std.Usize) (s : Slice hacspec_ml_kem.parameters.FieldElement) + (hlen : 0 < len.val) (h2len : 2 * len.val ≤ Std.Usize.max) + (hlen256 : len.val ≤ 128) + (hcanon : ∀ j : Nat, j < 256 → Canonical (p.val[j]!)) + (hslen : ∀ i : Nat, i < 256 → i / (2 * len.val) < s.val.length) + (hpart : ∀ i : Nat, i < 256 → (i % (2*len.val) < len.val → i + len.val < 256) + ∧ (¬ (i % (2*len.val) < len.val) → len.val ≤ i)) : + hacspec_ml_kem.invert_ntt.ntt_inverse_layer_n p len s + = .ok ⟨(List.range 256).map (fun i => + if i % (2*len.val) < len.val then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (p.val[i]!) (p.val[i + len.val]!) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (s.val[i / (2*len.val)]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (p.val[i]!) (p.val[i - len.val]!))), + by simp [List.length_map, List.length_range]⟩ := by + set f : Nat → hacspec_ml_kem.parameters.FieldElement := + fun i => + if i % (2*len.val) < len.val then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (p.val[i]!) (p.val[i + len.val]!) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (s.val[i / (2*len.val)]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (p.val[i]!) (p.val[i - len.val]!)) with hf_def + have hpure : ∀ k : Nat, k < (256#usize : Std.Usize).val → + (hacspec_ml_kem.invert_ntt.ntt_inverse_layer_n.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement + 256#usize : CoreModels.core.ops.function.Fn _ _ _).FnMutInst.call_mut + (p, len, s) ⟨BitVec.ofNat _ k⟩ + = .ok (f k, (p, len, s)) := by + intro k hk + have hk' : k < 256 := hk + show hacspec_ml_kem.invert_ntt.ntt_inverse_layer_n.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + (p, len, s) ⟨BitVec.ofNat _ k⟩ = .ok (f k, (p, len, s)) + unfold hacspec_ml_kem.invert_ntt.ntt_inverse_layer_n.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + unfold hacspec_ml_kem.invert_ntt.ntt_inverse_layer_n.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement.call + have hk_us : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := by + show (BitVec.ofNat _ k).toNat = k + apply Nat.mod_eq_of_lt + rw [Std.UScalarTy.Usize_numBits_eq] + have hbits : 2^16 ≤ 2^System.Platform.numBits := + Nat.pow_le_pow_right (by decide) (by + cases System.Platform.numBits_eq with + | inl h => rw [h]; decide + | inr h => rw [h]; decide) + omega + have hmaxb : (384:Nat) ≤ Std.Usize.max := by scalar_tac + have hilk : k + len.val ≤ Std.Usize.max := by omega + show (do + let fe ← hacspec_ml_kem.invert_ntt.ntt_inverse_layer_n_at p len s ⟨BitVec.ofNat _ k⟩ + Result.ok (fe, (p, len, s))) = _ + have hlat := layer_n_at_eq p len s ⟨BitVec.ofNat _ k⟩ hlen h2len (by rw [hk_us]; exact hk') + (by rw [hk_us]; exact hilk) hcanon + (by rw [hk_us]; exact hslen k hk') (by rw [hk_us]; exact (hpart k hk').1) + (by rw [hk_us]; exact (hpart k hk').2) + rw [hlat] + simp only [bind_tc_ok, hk_us, hf_def] + rfl + unfold hacspec_ml_kem.invert_ntt.ntt_inverse_layer_n + exact libcrux_iot_ml_kem.Util.CreateI.createi_pure_eq 256#usize + (hacspec_ml_kem.invert_ntt.ntt_inverse_layer_n.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement 256#usize) + (p, len, s) f hpure + +/-! ### C2 : `ntt_inverse_layer` (zetas createi + range-slice) reduction. + + `ntt_inverse_layer p layer` builds `zetas = createi 128 (closure) groups` + (per-index `g`, `ntt_inverse_layer_zeta groups g`), slices `s = zetas[0:groups]`, + and feeds `s` to `ntt_inverse_layer_n p len s`. We reduce the whole thing to an + explicit flat per-lane array using the global `zetasArr` table: + b-side lane `i` reads `zetasArr[2·groups − 1 − i/(2·len)]`. -/ + +/-- `BitVec.ofNat _ k` round-trips through `Usize.val` when `k < 256` (local copy). -/ +private theorem usize_ofNat_val_eq_self_of_lt_256' (k : Nat) (h : k < 256) : + (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := by + show (BitVec.ofNat System.Platform.numBits k).toNat = k + rw [BitVec.toNat_ofNat] + apply Nat.mod_eq_of_lt + have h_max : k ≤ Std.Usize.max := by scalar_tac + have h_max_def : Std.Usize.max + 1 = 2 ^ System.Platform.numBits := by scalar_tac + omega + +set_option maxRecDepth 8000 in +/-- The `zetas` createi inside `ntt_inverse_layer` reduces to an explicit array: + lane `g` is `ntt_inverse_layer_zeta groups ⟨g⟩` (i.e. `zetasArr[2·groups−1−g]` + when `g < groups`). Requires `groups ≤ 64` (true for layers 1-7, `len ≥ 2`). -/ +private theorem ntt_inverse_layer_zetas_eq_ok (groups : Std.Usize) + (hgroups : groups.val ≤ 64) : + (hacspec_ml_kem.parameters.createi 128#usize + hacspec_ml_kem.invert_ntt.ntt_inverse_layer.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement + groups) + = .ok ⟨(List.range 128).map (fun g => + if g < groups.val then zetasArr.val[2 * groups.val - 1 - g]! + else (⟨0#u16⟩ : hacspec_ml_kem.parameters.FieldElement)), + by simp [List.length_map, List.length_range]⟩ := by + set f : Nat → hacspec_ml_kem.parameters.FieldElement := + fun g => if g < groups.val then zetasArr.val[2 * groups.val - 1 - g]! + else (⟨0#u16⟩ : hacspec_ml_kem.parameters.FieldElement) with hf_def + have hpure : ∀ g : Nat, g < (128#usize : Std.Usize).val → + (hacspec_ml_kem.invert_ntt.ntt_inverse_layer.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement + : CoreModels.core.ops.function.Fn _ _ _).FnMutInst.call_mut + groups ⟨BitVec.ofNat _ g⟩ + = .ok (f g, groups) := by + intro g hg + have hg' : g < 128 := hg + show hacspec_ml_kem.invert_ntt.ntt_inverse_layer.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + groups ⟨BitVec.ofNat _ g⟩ = .ok (f g, groups) + unfold hacspec_ml_kem.invert_ntt.ntt_inverse_layer.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + unfold hacspec_ml_kem.invert_ntt.ntt_inverse_layer.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement.call + have hg_us : (⟨BitVec.ofNat _ g⟩ : Std.Usize).val = g := + usize_ofNat_val_eq_self_of_lt_256' g (by omega) + show (do + let fe ← hacspec_ml_kem.invert_ntt.ntt_inverse_layer_zeta groups ⟨BitVec.ofNat _ g⟩ + Result.ok (fe, groups)) = _ + unfold hacspec_ml_kem.invert_ntt.ntt_inverse_layer_zeta + by_cases hbr : (⟨BitVec.ofNat _ g⟩ : Std.Usize) < groups + · have hbr' : g < groups.val := by + have := (Std.UScalar.lt_equiv (⟨BitVec.ofNat _ g⟩ : Std.Usize) groups).mp hbr + rw [hg_us] at this; exact this + rw [if_pos hbr] + rw [ntt_zetas_eq_ok]; simp only [bind_tc_ok] + have h2us : (2#usize : Std.Usize).val = 2 := by scalar_tac + have h1us : (1#usize : Std.Usize).val = 1 := by scalar_tac + obtain ⟨i1, hi1, hi1v⟩ := umul_ok' 2#usize groups (by + show (2#usize : Std.Usize).val * groups.val ≤ Std.Usize.max + have hm : (128:Nat) ≤ Std.Usize.max := by scalar_tac + rw [h2us]; omega) + rw [hi1]; simp only [bind_tc_ok] + obtain ⟨i2, hi2, hi2v⟩ := usub_ok' i1 1#usize (by + show (1#usize : Std.Usize).val ≤ i1.val; rw [hi1v, h1us, h2us]; omega) + rw [hi2]; simp only [bind_tc_ok] + obtain ⟨i3, hi3, hi3v⟩ := usub_ok' i2 ⟨BitVec.ofNat _ g⟩ (by + show (⟨BitVec.ofNat _ g⟩ : Std.Usize).val ≤ i2.val + rw [hi2v, hi1v, h1us, h2us, hg_us]; omega) + rw [hi3]; simp only [bind_tc_ok] + have h128us : (128#usize : Std.Usize).val = 128 := by scalar_tac + have hi3lt : i3.val < zetasArr.val.length := by + rw [zetasArr.property, h128us, hi3v, hi2v, hi1v, h1us, h2us, hg_us]; omega + rw [libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq zetasArr i3 hi3lt] + simp only [bind_tc_ok, hf_def, if_pos hbr', hi3v, hi2v, hi1v, h1us, h2us, hg_us] + · have hbr' : ¬ (g < groups.val) := by + intro hc + exact hbr ((Std.UScalar.lt_equiv (⟨BitVec.ofNat _ g⟩ : Std.Usize) groups).mpr (by + rw [hg_us]; exact hc)) + rw [if_neg hbr] + show (do + let fe ← hacspec_ml_kem.parameters.FieldElement.new 0#u16 + Result.ok (fe, groups)) = _ + unfold hacspec_ml_kem.parameters.FieldElement.new + simp only [bind_tc_ok, hf_def, if_neg hbr'] + have h_from_fn := + libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq + (T := hacspec_ml_kem.parameters.FieldElement) + (F := hacspec_ml_kem.invert_ntt.ntt_inverse_layer.closure) + (N := 128#usize) + (inst := hacspec_ml_kem.invert_ntt.ntt_inverse_layer.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement) + (c := groups) + (f := f) + hpure + unfold hacspec_ml_kem.parameters.createi + show core.array.from_fn 128#usize _ groups = _ + exact h_from_fn + +/-- The hacspec slice-by-range extraction `zetas[0..groups]` reduces to + `List.slice 0 groups zetas.val` (local copy of `slice_zetas_succeeds`). -/ +private theorem slice_zetas0_succeeds + (zs : Std.Array hacspec_ml_kem.parameters.FieldElement 128#usize) + (groups : Std.Usize) (hgroups : groups.val ≤ 128) : + core.Array.Insts.CoreOpsIndexIndex.index + (core.Slice.Insts.CoreOpsIndexIndex + (core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice + hacspec_ml_kem.parameters.FieldElement)) zs + { start := 0#usize, «end» := groups } + = .ok (⟨List.slice 0 groups.val zs.val, by + unfold List.slice + have h : zs.val.length = 128 := zs.property + simp only [List.drop_zero, List.length_take, h] + scalar_tac⟩ : + Aeneas.Std.Slice hacspec_ml_kem.parameters.FieldElement) := by + unfold core.Array.Insts.CoreOpsIndexIndex.index + core.slice.index.Slice.index + core.Slice.Insts.CoreOpsIndexIndex + core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice + show core.slice.index.SliceIndexRangeUsizeSlice.index + (core.cmRangeUsizeToAeneas _) zs.to_slice = _ + unfold core.slice.index.SliceIndexRangeUsizeSlice.index + core.cmRangeUsizeToAeneas + have h_alen : zs.val.length = 128 := zs.property + have h_cond : (0#usize : Std.Usize) ≤ groups ∧ + groups.val ≤ zs.to_slice.val.length := by + refine ⟨by scalar_tac, by + show groups.val ≤ zs.val.length; omega⟩ + rw [if_pos h_cond] + rfl + +/-- `List.slice 0 b l [k]! = l[k]!` when `b ≤ l.length` and `k < b` + (local specialization of `slice_getElem_at`). -/ +private theorem slice0_getElem_at {α} [Inhabited α] + (l : List α) (b : Nat) (_h_le_b : b ≤ l.length) (k : Nat) (hk : k < b) : + (List.slice 0 b l)[k]! = l[k]! := by + unfold List.slice + rw [List.drop_zero, Nat.sub_zero] + have h_take_idx : (l.take b)[k]? = l[k]? := by + rw [List.getElem?_take, if_pos hk] + rw [List.getElem!_eq_getElem?_getD, List.getElem!_eq_getElem?_getD, h_take_idx] + +set_option maxRecDepth 8000 in +set_option maxHeartbeats 1000000 in +/-- **Full reduction of `ntt_inverse_layer p layer`** to an explicit flat per-lane + array. b-side lane `i` reads `zetasArr[2·groups − 1 − i/(2·len)]` (plain domain), + a-side `add p[i] p[i+len]`. Parametrized by the resolved `len`/`groups` values. -/ +private theorem ntt_inverse_layer_eq_ok + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (layer len groups : Std.Usize) + (hlen_def : (1#usize <<< layer : Result Std.Usize) = .ok len) + (hgroups_def : (128#usize / len : Result Std.Usize) = .ok groups) + (hlenpos : 0 < len.val) (hlen2 : 2 ≤ len.val) (hlen128 : len.val ≤ 128) + (hgroups : groups.val = 128 / len.val) + (hpart_slice : ∀ i : Nat, i < 256 → i / (2 * len.val) < groups.val) + (hpart : ∀ i : Nat, i < 256 → (i % (2*len.val) < len.val → i + len.val < 256) + ∧ (¬ (i % (2*len.val) < len.val) → len.val ≤ i)) + (hcanon : ∀ j : Nat, j < 256 → Canonical (p.val[j]!)) : + hacspec_ml_kem.invert_ntt.ntt_inverse_layer p layer + = .ok ⟨(List.range 256).map (fun i => + if i % (2*len.val) < len.val then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (p.val[i]!) (p.val[i + len.val]!) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (zetasArr.val[2 * groups.val - 1 - i / (2*len.val)]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (p.val[i]!) (p.val[i - len.val]!))), + by simp [List.length_map, List.length_range]⟩ := by + unfold hacspec_ml_kem.invert_ntt.ntt_inverse_layer + rw [hlen_def]; simp only [bind_tc_ok] + rw [hgroups_def]; simp only [bind_tc_ok] + have hgroups64 : groups.val ≤ 64 := by + rw [hgroups] + calc 128 / len.val ≤ 128 / 2 := Nat.div_le_div_left hlen2 (by omega) + _ = 64 := by norm_num + rw [ntt_inverse_layer_zetas_eq_ok groups hgroups64]; simp only [bind_tc_ok] + set zs : Std.Array hacspec_ml_kem.parameters.FieldElement 128#usize := + ⟨(List.range 128).map (fun g => + if g < groups.val then zetasArr.val[2 * groups.val - 1 - g]! + else (⟨0#u16⟩ : hacspec_ml_kem.parameters.FieldElement)), + by simp [List.length_map, List.length_range]⟩ with hzs_def + rw [slice_zetas0_succeeds zs groups (by omega)]; simp only [bind_tc_ok] + set s : Slice hacspec_ml_kem.parameters.FieldElement := + ⟨List.slice 0 groups.val zs.val, by + unfold List.slice + have h : zs.val.length = 128 := zs.property + simp only [List.drop_zero, List.length_take, h]; scalar_tac⟩ with hs_def + have hslen_eq : s.val.length = groups.val := by + show (List.slice 0 groups.val zs.val).length = groups.val + unfold List.slice + have h : zs.val.length = 128 := zs.property + simp only [List.drop_zero, List.length_take, h]; omega + have h2lenmax : 2 * len.val ≤ Std.Usize.max := by + have : (256:Nat) ≤ Std.Usize.max := by scalar_tac + omega + -- s lane access: for g < groups, s[g]! = zs[g]! = zetasArr[2groups-1-g] + have hslane : ∀ g : Nat, g < groups.val → + s.val[g]! = zetasArr.val[2 * groups.val - 1 - g]! := by + intro g hg + show (List.slice 0 groups.val zs.val)[g]! = _ + rw [slice0_getElem_at zs.val groups.val (by + rw [zs.property]; have : (128#usize : Std.Usize).val = 128 := by scalar_tac + omega) g hg] + show ((List.range 128).map (fun g => + if g < groups.val then zetasArr.val[2 * groups.val - 1 - g]! + else (⟨0#u16⟩ : hacspec_ml_kem.parameters.FieldElement)))[g]! = _ + rw [getElem!_pos _ g (by simp [List.length_map, List.length_range]; omega)] + rw [List.getElem_map, List.getElem_range, if_pos hg] + rw [ntt_inverse_layer_n_eq_ok p len s hlenpos h2lenmax hlen128 hcanon + (fun i hi => by rw [hslen_eq]; exact hpart_slice i hi) hpart] + -- the flat array from the layer_n reduction uses `s[i/(2len)]!`; rewrite to zetasArr. + congr 1 + apply Subtype.ext + simp only [] + apply List.map_congr_left + intro i hi + have hi256 : i < 256 := by rw [List.mem_range] at hi; exact hi + by_cases hbr : i % (2*len.val) < len.val + · rw [if_pos hbr, if_pos hbr] + · rw [if_neg hbr, if_neg hbr] + rw [hslane (i / (2*len.val)) (hpart_slice i hi256)] + +/-! ### C2 : per-layer flat↔chunk match. + + For each layer N we prove `ntt_inverse_layer q N = .ok (Spec.invert_ntt_layer_..._pure + q zeta_i)` for canonical `q`, by reducing the LHS to the flat array + (`ntt_inverse_layer_eq_ok`) and the RHS lane to the same flat butterfly (in + `ZMod 3329`), then concluding via canonical determination (`eq_of_zmod_lane_canon'`). + The per-lane match uses commutativity of `add`/`mul` in `ZMod` and + `zetas_bridge_zmod` (zeta-index correspondence). -/ + +/-- Local copy of `eq_of_zmod_lane_canon` (the SubPolyScaleZ one is defined later). -/ +private theorem eq_of_zmod_lane_canon' + (u v : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hcu : ∀ j : Nat, j < 256 → Canonical (u.val[j]!)) + (hcv : ∀ j : Nat, j < 256 → Canonical (v.val[j]!)) + (hz : ∀ j : Nat, j < 256 → zmodOfFE (u.val[j]!) = zmodOfFE (v.val[j]!)) : + u = v := by + apply Subtype.ext + apply List.ext_getElem + · rw [Aeneas.Std.Array.length_eq u, Aeneas.Std.Array.length_eq v] + · intro j hj1 _hj2 + have hj : j < 256 := by rw [Aeneas.Std.Array.length_eq u] at hj1; simpa using hj1 + have heq : u.val[j]! = v.val[j]! := by + rw [← feOfZMod_zmodOfFE_of_canon (u.val[j]!) (hcu j hj), + ← feOfZMod_zmodOfFE_of_canon (v.val[j]!) (hcv j hj), hz j hj] + have huj : u.val[j]! = u.val[j] := + getElem!_pos u.val j (by rw [Aeneas.Std.Array.length_eq u]; exact hj) + have hvj : v.val[j]! = v.val[j] := + getElem!_pos v.val j (by rw [Aeneas.Std.Array.length_eq v]; exact hj) + rw [← huj, ← hvj]; exact heq + +set_option maxRecDepth 8000 in +set_option maxHeartbeats 2000000 in +/-- Lane `ℓ` of `chunk_inv_ntt_layer_1_step_pure a z0 z1 z2 z3`: each lane is + written by exactly one of the 8 disjoint butterfly steps; pairs + `(0,2)(1,3)(4,6)(5,7)(8,10)(9,11)(12,14)(13,15)` with zetas `z0 z0 z1 z1 z2 z2 z3 z3`. -/ +private theorem chunk_inv_ntt_layer_1_step_lane + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z0 z1 z2 z3 : hacspec_ml_kem.parameters.FieldElement) (ℓ : Nat) (hℓ : ℓ < 16) : + (Spec.chunk_inv_ntt_layer_1_step_pure a z0 z1 z2 z3).val[ℓ]! + = if ℓ % 4 < 2 then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (a.val[ℓ + 2]!) (a.val[ℓ]!) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (a.val[ℓ]!) (a.val[ℓ - 2]!)) + (if ℓ / 4 = 0 then z0 else if ℓ / 4 = 1 then z1 else if ℓ / 4 = 2 then z2 else z3) := by + unfold Spec.chunk_inv_ntt_layer_1_step_pure Spec.chunk_inv_ntt_step_pure + -- each lane is written once (disjoint pairs); resolve nested `.set` getElem. + interval_cases ℓ <;> + simp only [Aeneas.Std.Array.set_val_eq] <;> norm_num + +set_option maxRecDepth 8000 in +set_option maxHeartbeats 2000000 in +/-- Lane `ℓ` of `chunk_inv_ntt_layer_2_step_pure a z0 z1`: pairs + `(0,4)(1,5)(2,6)(3,7)(8,12)(9,13)(10,14)(11,15)` with zetas `z0×4, z1×4`. -/ +private theorem chunk_inv_ntt_layer_2_step_lane + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z0 z1 : hacspec_ml_kem.parameters.FieldElement) (ℓ : Nat) (hℓ : ℓ < 16) : + (Spec.chunk_inv_ntt_layer_2_step_pure a z0 z1).val[ℓ]! + = if ℓ % 8 < 4 then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (a.val[ℓ + 4]!) (a.val[ℓ]!) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (a.val[ℓ]!) (a.val[ℓ - 4]!)) + (if ℓ / 8 = 0 then z0 else z1) := by + unfold Spec.chunk_inv_ntt_layer_2_step_pure Spec.chunk_inv_ntt_step_pure + interval_cases ℓ <;> + simp only [Aeneas.Std.Array.set_val_eq] <;> norm_num + +set_option maxRecDepth 8000 in +set_option maxHeartbeats 2000000 in +/-- Lane `ℓ` of `chunk_inv_ntt_layer_3_step_pure a z`: pairs + `(0,8)(1,9)…(7,15)` with single zeta `z`. -/ +private theorem chunk_inv_ntt_layer_3_step_lane + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z : hacspec_ml_kem.parameters.FieldElement) (ℓ : Nat) (hℓ : ℓ < 16) : + (Spec.chunk_inv_ntt_layer_3_step_pure a z).val[ℓ]! + = if ℓ < 8 then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (a.val[ℓ + 8]!) (a.val[ℓ]!) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (a.val[ℓ]!) (a.val[ℓ - 8]!)) z := by + unfold Spec.chunk_inv_ntt_layer_3_step_pure Spec.chunk_inv_ntt_step_pure + interval_cases ℓ <;> + simp only [Aeneas.Std.Array.set_val_eq] <;> norm_num + +/-! ### C2 (cont.): per-layer match `ntt_inverse_layer ≡ spec layer`. -/ + +/-- `(1#usize <<< n)` succeeds with value `2^n.val` (for `n.val < numBits`). -/ +private theorem shl_one_ok (n : Std.Usize) (hn : n.val < UScalarTy.Usize.numBits) : + ∃ len : Std.Usize, (1#usize <<< n : Result Std.Usize) = .ok len ∧ len.val = 2 ^ n.val := by + have h_one_shl_pow : ((1#usize : Std.Usize).val <<< n.val) < 2 ^ System.Platform.numBits := by + have h_one_eq : (1#usize : Std.Usize).val = 1 := rfl + rw [h_one_eq, Nat.shiftLeft_eq, Nat.one_mul] + have hnb : n.val < System.Platform.numBits := by + rwa [Std.UScalarTy.Usize_numBits_eq] at hn + rcases System.Platform.numBits_eq with h32 | h64 + · rw [h32]; rw [h32] at hnb; exact Nat.pow_lt_pow_right (by decide) hnb + · rw [h64]; rw [h64] at hnb; exact Nat.pow_lt_pow_right (by decide) hnb + have hT := Aeneas.Std.UScalar.ShiftLeft_spec (1#usize : Std.Usize) n + (Aeneas.Std.UScalar.size Aeneas.Std.UScalarTy.Usize) hn rfl + obtain ⟨z, h_eq, h_v_mod, _h_bv⟩ := Std.WP.spec_imp_exists hT + refine ⟨z, h_eq, ?_⟩ + have h_one_eq : (1#usize : Std.Usize).val = 1 := rfl + have h_size_eq : (Aeneas.Std.UScalar.size Aeneas.Std.UScalarTy.Usize) = 2 ^ System.Platform.numBits := by + rw [Aeneas.Std.UScalar.size]; rw [Std.UScalarTy.Usize_numBits_eq] + rw [h_v_mod, h_one_eq, h_size_eq, Nat.shiftLeft_eq, Nat.one_mul, Nat.mod_eq_of_lt] + rw [h_one_eq, Nat.shiftLeft_eq, Nat.one_mul] at h_one_shl_pow + exact h_one_shl_pow + +private theorem numbits_ge (n : Nat) (hn : n ≤ 7) : n < UScalarTy.Usize.numBits := by + rw [Std.UScalarTy.Usize_numBits_eq] + rcases System.Platform.numBits_eq with h | h <;> (rw [h]; omega) + +/-- `128#usize / len` succeeds with value `128 / len.val` (for `len.val ≠ 0`). -/ +private theorem div128_ok (len : Std.Usize) (hlen : len.val ≠ 0) : + ∃ g : Std.Usize, (128#usize / len : Result Std.Usize) = .ok g ∧ g.val = 128 / len.val := by + obtain ⟨v, hv, hpv⟩ := Aeneas.Std.WP.spec_imp_exists (Std.Usize.div_spec 128#usize hlen) + exact ⟨v, hv, by simpa using hpv⟩ + +/-- Spec layer-1 lane in flat form: `(invert_ntt_layer_1_pure q 128).val[i]!` for `i<256`. -/ +private theorem spec_inv_layer_1_lane + (q : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) (i : Nat) (hi : i < 256) : + (Spec.invert_ntt_layer_1_pure q 128#usize).val[i]! + = if i % 4 < 2 then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (q.val[i + 2]!) (q.val[i]!) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (q.val[i]!) (q.val[i - 2]!)) + (Spec.zeta_at (128 - 4 * (i / 16) - (i % 16) / 4 - 1)) := by + unfold Spec.invert_ntt_layer_1_pure + simp only [show (128#usize : Std.Usize).val = 128 from by rfl] + rw [flatten_chunk_map_lane (fun k => Spec.chunk_inv_ntt_layer_1_step_pure (Spec.chunk_at q k) + (Spec.zeta_at (128 - 4 * k - 1)) (Spec.zeta_at (128 - 4 * k - 2)) + (Spec.zeta_at (128 - 4 * k - 3)) (Spec.zeta_at (128 - 4 * k - 4))) i hi (by simp)] + rw [chunk_inv_ntt_layer_1_step_lane _ _ _ _ _ (i % 16) (Nat.mod_lt _ (by decide))] + have hk : i / 16 < 16 := by omega + have hmod4 : (i % 16) % 4 = i % 4 := by omega + have him : i % 16 < 16 := by omega + by_cases hbr : i % 4 < 2 + · rw [if_pos (by rw [hmod4]; exact hbr)] + have ha2 : (i % 16) + 2 < 16 := by + have : (i % 16) % 4 < 2 := by rw [hmod4]; exact hbr + omega + rw [if_pos hbr] + rw [chunk_at_lane' q (i / 16) (i % 16) (by omega), + chunk_at_lane' q (i / 16) ((i % 16) + 2) ha2] + congr 2 <;> omega + · rw [if_neg (by rw [hmod4]; exact hbr), if_neg hbr] + have hsub : 2 ≤ i % 16 := by omega + rw [chunk_at_lane' q (i / 16) (i % 16) (by omega), + chunk_at_lane' q (i / 16) ((i % 16) - 2) (by omega)] + have hidx1 : 16 * (i / 16) + (i % 16) = i := by omega + have hidx2 : 16 * (i / 16) + ((i % 16) - 2) = i - 2 := by omega + rw [hidx1, hidx2] + -- zeta selector: (i%16)/4 picks z_{(i%16)/4} = zeta_at (128 - 4k - ((i%16)/4) - 1) + have hzsel : (if (i % 16) / 4 = 0 then Spec.zeta_at (128 - 4 * (i / 16) - 1) + else if (i % 16) / 4 = 1 then Spec.zeta_at (128 - 4 * (i / 16) - 2) + else if (i % 16) / 4 = 2 then Spec.zeta_at (128 - 4 * (i / 16) - 3) + else Spec.zeta_at (128 - 4 * (i / 16) - 4)) + = Spec.zeta_at (128 - 4 * (i / 16) - (i % 16) / 4 - 1) := by + have hq : (i % 16) / 4 < 4 := by omega + interval_cases h : ((i % 16) / 4) <;> simp <;> congr 1 + rw [hzsel] + +/-- The explicit flat layer array (from `ntt_inverse_layer_eq_ok`) has canonical lanes. -/ +private theorem flat_layer_canon + (q : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) (len groups : Std.Usize) + (j : Nat) (hj : j < 256) : + Canonical ((⟨(List.range 256).map (fun i => + if i % (2*len.val) < len.val then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (q.val[i]!) (q.val[i + len.val]!) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (zetasArr.val[2 * groups.val - 1 - i / (2*len.val)]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (q.val[i]!) (q.val[i - len.val]!))), + by simp [List.length_map, List.length_range]⟩ : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize).val[j]!) := by + show Canonical (((List.range 256).map (fun i => + if i % (2*len.val) < len.val then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (q.val[i]!) (q.val[i + len.val]!) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (zetasArr.val[2 * groups.val - 1 - i / (2*len.val)]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (q.val[i]!) (q.val[i - len.val]!))))[j]!) + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj])] + rw [List.getElem_map, List.getElem_range] + by_cases hbr : j % (2*len.val) < len.val + · rw [if_pos hbr]; exact libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure _ _ + · rw [if_neg hbr]; exact libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure _ _ + +set_option maxHeartbeats 1000000 in +/-- **Layer-1 match:** `ntt_inverse_layer q 1 = .ok (invert_ntt_layer_1_pure q 128)`. -/ +private theorem ntt_inverse_layer_1_match + (q : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hq : ∀ j : Nat, j < 256 → Canonical (q.val[j]!)) : + hacspec_ml_kem.invert_ntt.ntt_inverse_layer q 1#usize + = .ok (Spec.invert_ntt_layer_1_pure q 128#usize) := by + obtain ⟨len, hlen_def, hlenv⟩ := shl_one_ok 1#usize (numbits_ge 1 (by omega)) + have hlenv2 : len.val = 2 := by rw [hlenv]; decide + obtain ⟨groups, hgroups_def, hgroupsv⟩ := div128_ok len (by omega) + have hgroupsv2 : groups.val = 64 := by rw [hgroupsv, hlenv2] + rw [ntt_inverse_layer_eq_ok q 1#usize len groups hlen_def hgroups_def + (by omega) (by omega) (by omega) (by rw [hgroupsv, hlenv2]) + (fun i hi => by rw [hlenv2, hgroupsv2]; omega) + (fun i hi => ⟨fun hc => by rw [hlenv2] at *; omega, fun hc => by rw [hlenv2] at *; omega⟩) hq] + congr 1 + refine eq_of_zmod_lane_canon' _ _ (flat_layer_canon q len groups) + (invert_ntt_layer_1_canon q 128#usize hq) ?_ + intro j hj + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj]), + List.getElem_map, List.getElem_range] + rw [spec_inv_layer_1_lane q j hj] + rw [hlenv2, hgroupsv2] + by_cases hbr : j % 4 < 2 + · rw [show (2 * 2) = 4 from by norm_num, if_pos hbr, if_pos hbr] + rw [zmodOfFE_add_pure, zmodOfFE_add_pure]; ring + · rw [show (2 * 2) = 4 from by norm_num, if_neg hbr, if_neg hbr] + have hsub : 2 ≤ j := by omega + rw [zmodOfFE_mul_pure, zmodOfFE_mul_pure] + have hzidx : 2 * 64 - 1 - j / 4 = 128 - 4 * (j / 16) - (j % 16) / 4 - 1 := by omega + rw [hzidx, ← zetas_bridge_zmod (128 - 4 * (j / 16) - (j % 16) / 4 - 1) (by omega)] + ring + +/-- Spec layer-2 lane in flat form. -/ +private theorem spec_inv_layer_2_lane + (q : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) (i : Nat) (hi : i < 256) : + (Spec.invert_ntt_layer_2_pure q 64#usize).val[i]! + = if i % 8 < 4 then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (q.val[i + 4]!) (q.val[i]!) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (q.val[i]!) (q.val[i - 4]!)) + (Spec.zeta_at (64 - 2 * (i / 16) - (i % 16) / 8 - 1)) := by + unfold Spec.invert_ntt_layer_2_pure + simp only [show (64#usize : Std.Usize).val = 64 from by rfl] + rw [flatten_chunk_map_lane (fun k => Spec.chunk_inv_ntt_layer_2_step_pure (Spec.chunk_at q k) + (Spec.zeta_at (64 - 2 * k - 1)) (Spec.zeta_at (64 - 2 * k - 2))) i hi (by simp)] + rw [chunk_inv_ntt_layer_2_step_lane _ _ _ (i % 16) (Nat.mod_lt _ (by decide))] + have hk : i / 16 < 16 := by omega + have him : i % 16 < 16 := by omega + have hmod8 : (i % 16) % 8 = i % 8 := by omega + by_cases hbr : i % 8 < 4 + · rw [if_pos (by rw [hmod8]; exact hbr)] + have ha : (i % 16) + 4 < 16 := by + have : (i % 16) % 8 < 4 := by rw [hmod8]; exact hbr + omega + rw [if_pos hbr] + rw [chunk_at_lane' q (i / 16) (i % 16) (by omega), + chunk_at_lane' q (i / 16) ((i % 16) + 4) ha] + congr 2 <;> omega + · rw [if_neg (by rw [hmod8]; exact hbr), if_neg hbr] + have hsub : 4 ≤ i % 16 := by + have : ¬ ((i % 16) % 8 < 4) := by rw [hmod8]; exact hbr + omega + rw [chunk_at_lane' q (i / 16) (i % 16) (by omega), + chunk_at_lane' q (i / 16) ((i % 16) - 4) (by omega)] + have hidx1 : 16 * (i / 16) + (i % 16) = i := by omega + have hidx2 : 16 * (i / 16) + ((i % 16) - 4) = i - 4 := by omega + rw [hidx1, hidx2] + have hzsel : (if (i % 16) / 8 = 0 then Spec.zeta_at (64 - 2 * (i / 16) - 1) + else Spec.zeta_at (64 - 2 * (i / 16) - 2)) + = Spec.zeta_at (64 - 2 * (i / 16) - (i % 16) / 8 - 1) := by + have hq : (i % 16) / 8 < 2 := by omega + interval_cases h : ((i % 16) / 8) <;> simp ; congr 1 + rw [hzsel] + +set_option maxHeartbeats 1000000 in +/-- **Layer-2 match.** -/ +private theorem ntt_inverse_layer_2_match + (q : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hq : ∀ j : Nat, j < 256 → Canonical (q.val[j]!)) : + hacspec_ml_kem.invert_ntt.ntt_inverse_layer q 2#usize + = .ok (Spec.invert_ntt_layer_2_pure q 64#usize) := by + obtain ⟨len, hlen_def, hlenv⟩ := shl_one_ok 2#usize (numbits_ge 2 (by omega)) + have hlenv2 : len.val = 4 := by rw [hlenv]; decide + obtain ⟨groups, hgroups_def, hgroupsv⟩ := div128_ok len (by omega) + have hgroupsv2 : groups.val = 32 := by rw [hgroupsv, hlenv2] + rw [ntt_inverse_layer_eq_ok q 2#usize len groups hlen_def hgroups_def + (by omega) (by omega) (by omega) (by rw [hgroupsv, hlenv2]) + (fun i hi => by rw [hlenv2, hgroupsv2]; omega) + (fun i hi => ⟨fun hc => by rw [hlenv2] at *; omega, fun hc => by rw [hlenv2] at *; omega⟩) hq] + congr 1 + refine eq_of_zmod_lane_canon' _ _ (flat_layer_canon q len groups) + (invert_ntt_layer_2_canon q 64#usize hq) ?_ + intro j hj + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj]), + List.getElem_map, List.getElem_range] + rw [spec_inv_layer_2_lane q j hj] + rw [hlenv2, hgroupsv2] + by_cases hbr : j % 8 < 4 + · rw [show (2 * 4) = 8 from by norm_num, if_pos hbr, if_pos hbr] + rw [zmodOfFE_add_pure, zmodOfFE_add_pure]; ring + · rw [show (2 * 4) = 8 from by norm_num, if_neg hbr, if_neg hbr] + have hsub : 4 ≤ j := by omega + rw [zmodOfFE_mul_pure, zmodOfFE_mul_pure] + have hzidx : 2 * 32 - 1 - j / 8 = 64 - 2 * (j / 16) - (j % 16) / 8 - 1 := by omega + rw [hzidx, ← zetas_bridge_zmod (64 - 2 * (j / 16) - (j % 16) / 8 - 1) (by omega)] + ring + +/-- Spec layer-3 lane in flat form. -/ +private theorem spec_inv_layer_3_lane + (q : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) (i : Nat) (hi : i < 256) : + (Spec.invert_ntt_layer_3_pure q 32#usize).val[i]! + = if i % 16 < 8 then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (q.val[i + 8]!) (q.val[i]!) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (q.val[i]!) (q.val[i - 8]!)) + (Spec.zeta_at (32 - (i / 16) - 1)) := by + unfold Spec.invert_ntt_layer_3_pure + simp only [show (32#usize : Std.Usize).val = 32 from by rfl] + rw [flatten_chunk_map_lane (fun k => Spec.chunk_inv_ntt_layer_3_step_pure (Spec.chunk_at q k) + (Spec.zeta_at (32 - k - 1))) i hi (by simp)] + rw [chunk_inv_ntt_layer_3_step_lane _ _ (i % 16) (Nat.mod_lt _ (by decide))] + have hk : i / 16 < 16 := by omega + have him : i % 16 < 16 := by omega + by_cases hbr : i % 16 < 8 + · rw [if_pos hbr, if_pos hbr] + rw [chunk_at_lane' q (i / 16) (i % 16) (by omega), + chunk_at_lane' q (i / 16) ((i % 16) + 8) (by omega)] + congr 2 <;> omega + · rw [if_neg hbr, if_neg hbr] + rw [chunk_at_lane' q (i / 16) (i % 16) (by omega), + chunk_at_lane' q (i / 16) ((i % 16) - 8) (by omega)] + have hidx1 : 16 * (i / 16) + (i % 16) = i := by omega + have hidx2 : 16 * (i / 16) + ((i % 16) - 8) = i - 8 := by omega + rw [hidx1, hidx2] + +set_option maxHeartbeats 1000000 in +/-- **Layer-3 match.** -/ +private theorem ntt_inverse_layer_3_match + (q : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hq : ∀ j : Nat, j < 256 → Canonical (q.val[j]!)) : + hacspec_ml_kem.invert_ntt.ntt_inverse_layer q 3#usize + = .ok (Spec.invert_ntt_layer_3_pure q 32#usize) := by + obtain ⟨len, hlen_def, hlenv⟩ := shl_one_ok 3#usize (numbits_ge 3 (by omega)) + have hlenv2 : len.val = 8 := by rw [hlenv]; decide + obtain ⟨groups, hgroups_def, hgroupsv⟩ := div128_ok len (by omega) + have hgroupsv2 : groups.val = 16 := by rw [hgroupsv, hlenv2] + rw [ntt_inverse_layer_eq_ok q 3#usize len groups hlen_def hgroups_def + (by omega) (by omega) (by omega) (by rw [hgroupsv, hlenv2]) + (fun i hi => by rw [hlenv2, hgroupsv2]; omega) + (fun i hi => ⟨fun hc => by rw [hlenv2] at *; omega, fun hc => by rw [hlenv2] at *; omega⟩) hq] + congr 1 + refine eq_of_zmod_lane_canon' _ _ (flat_layer_canon q len groups) + (invert_ntt_layer_3_canon q 32#usize hq) ?_ + intro j hj + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj]), + List.getElem_map, List.getElem_range] + rw [spec_inv_layer_3_lane q j hj] + rw [hlenv2, hgroupsv2] + by_cases hbr : j % 16 < 8 + · rw [show (2 * 8) = 16 from by norm_num, if_pos hbr, if_pos hbr] + rw [zmodOfFE_add_pure, zmodOfFE_add_pure]; ring + · rw [show (2 * 8) = 16 from by norm_num, if_neg hbr, if_neg hbr] + have hsub : 8 ≤ j := by omega + rw [zmodOfFE_mul_pure, zmodOfFE_mul_pure] + have hzidx : 2 * 16 - 1 - j / 16 = 32 - (j / 16) - 1 := by omega + rw [hzidx, ← zetas_bridge_zmod (32 - (j / 16) - 1) (by omega)] + ring + +/-- Mod-chunk identity: `i % (2·(16·step)) = 16·((i/16) % (2·step)) + i%16`. -/ +private theorem mod_chunk_eq (i step : Nat) (hstep : 0 < step) : + i % (2 * (16 * step)) = 16 * ((i / 16) % (2 * step)) + i % 16 := by + have h1 : i = 16 * (i / 16) + i % 16 := (Nat.div_add_mod i 16).symm + have key : (16 * (i / 16)) % (16 * (2 * step)) = 16 * ((i / 16) % (2 * step)) := + Nat.mul_mod_mul_left 16 (i / 16) (2 * step) + have h16 : i % 16 < 16 := Nat.mod_lt _ (by decide) + have hxlt : (i / 16) % (2 * step) + 1 ≤ 2 * step := Nat.mod_lt _ (by omega) + have hml : 16 * ((i / 16) % (2 * step) + 1) ≤ 16 * (2 * step) := + Nat.mul_le_mul (Nat.le_refl 16) hxlt + have hbound : 16 * ((i / 16) % (2 * step)) + i % 16 < 16 * (2 * step) := by + rw [Nat.mul_add] at hml; omega + have hstep_eq : 2 * (16 * step) = 16 * (2 * step) := by ring + have h16' : i % 16 < 16 * (2 * step) := by + have hge : 16 * 1 ≤ 16 * (2 * step) := Nat.mul_le_mul (Nat.le_refl 16) (by omega) + omega + rw [hstep_eq] + conv_lhs => rw [h1] + rw [Nat.add_mod, key, Nat.mod_eq_of_lt h16', Nat.mod_eq_of_lt hbound] + +/-- Spec layer-4+ lane in flat form, parametrized by `layer`/`zeta_i`, with `step` + the chunk step `(1<< Spec.chunk_inv_at_layer_4_plus_pure + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at q)) (by simp)) + layer (fun group => Spec.zeta_at (zeta_i.val - 1 - group)) c) i hi (by simp)] + have hc : i / 16 < 16 := by omega + have hℓ : i % 16 < 16 := Nat.mod_lt _ (by decide) + have hstep16 : (1 <<< layer.val) / 16 = step := hstep.symm + -- partner-in-range for the a-branch + have hub : (i / 16) % (2 * ((1 <<< layer.val) / 16)) < (1 <<< layer.val) / 16 → + (i / 16) + (1 <<< layer.val) / 16 < 16 := by + rw [hstep16]; exact fun hoff => layer4_partner_lt (i / 16) step hc hstep_pos hdvd hoff + rw [chunk_inv_at_layer_4_plus_chunks0_eq q layer + (fun group => Spec.zeta_at (zeta_i.val - 1 - group)) (i / 16) hc hub (by simp)] + rw [hstep16] + -- relate flat a/b decision to chunk a/b decision via mod_chunk_eq + have hmce : i % (2 * len) = 16 * ((i / 16) % (2 * step)) + i % 16 := by + rw [hlen]; exact mod_chunk_eq i step hstep_pos + have hdecf : (i % (2 * len) < len) ↔ ((i / 16) % (2 * step) < step) := by + rw [hmce, hlen] + constructor + · intro h; by_contra hco; push Not at hco + have : 16 * step ≤ 16 * ((i / 16) % (2 * step)) := Nat.mul_le_mul (Nat.le_refl 16) hco + omega + · intro h + have : 16 * ((i / 16) % (2 * step) + 1) ≤ 16 * step := Nat.mul_le_mul (Nat.le_refl 16) (by omega) + rw [Nat.mul_add] at this; have h16 : i % 16 < 16 := Nat.mod_lt _ (by decide); omega + by_cases hbr : (i / 16) % (2 * step) < step + · rw [if_pos hbr, if_pos (hdecf.mpr hbr)] + have hub' : (i / 16) + step < 16 := layer4_partner_lt (i / 16) step hc hstep_pos hdvd hbr + rw [chunk_inv_pair_butterfly_a_lane _ _ (i % 16) hℓ] + rw [chunk_at_lane' q (i / 16) (i % 16) hℓ, + chunk_at_lane' q ((i / 16) + step) (i % 16) hℓ] + have e1 : 16 * (i / 16) + (i % 16) = i := by omega + have e2 : 16 * ((i / 16) + step) + (i % 16) = i + len := by + rw [hlen, Nat.mul_add]; omega + rw [e1, e2] + · rw [if_neg hbr, if_neg (fun hc2 => hbr (hdecf.mp hc2))] + rw [chunk_inv_pair_butterfly_b_lane _ _ _ (i % 16) hℓ] + rw [chunk_at_lane' q (i / 16) (i % 16) hℓ, + chunk_at_lane' q ((i / 16) - step) (i % 16) hℓ] + have e1 : 16 * (i / 16) + (i % 16) = i := by omega + have hstep_le : step ≤ i / 16 := by + have hr : (i / 16) % (2 * step) ≤ i / 16 := Nat.mod_le _ _ + omega + have e2 : 16 * ((i / 16) - step) + (i % 16) = i - len := by + rw [hlen, Nat.mul_sub]; omega + rw [e1, e2] + +set_option maxHeartbeats 1000000 in +/-- **Layer-4+ match** for a concrete `layer`/`zeta_i`, `step`/`len`. -/ +private theorem ntt_inverse_layer_4_plus_match + (q : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i layer : Std.Usize) (lenu groups : Std.Usize) (step : Nat) + (hlen_def : (1#usize <<< layer : Result Std.Usize) = .ok lenu) + (hgroups_def : (128#usize / lenu : Result Std.Usize) = .ok groups) + (hstep : step = (1 <<< layer.val) / 16) (hlenv : lenu.val = 16 * step) + (hstep_pos : 0 < step) (hdvd : (2 * step) ∣ 16) + (hlen2 : 2 ≤ lenu.val) (hlen128 : lenu.val ≤ 128) + (hgroupsv : groups.val = 128 / lenu.val) (hlg : lenu.val * groups.val = 128) + (hzeta : ∀ i : Nat, i < 256 → 2 * groups.val - 1 - i / (2 * lenu.val) + = zeta_i.val - 1 - (i / 16) / (2 * step)) + (hzeta_lt : ∀ i : Nat, i < 256 → ¬ (i % (2 * lenu.val) < lenu.val) → + zeta_i.val - 1 - (i / 16) / (2 * step) < 128) + (hq : ∀ j : Nat, j < 256 → Canonical (q.val[j]!)) : + hacspec_ml_kem.invert_ntt.ntt_inverse_layer q layer + = .ok (Spec.invert_ntt_layer_4_plus_pure q zeta_i layer) := by + have h256 : 2 * lenu.val * groups.val = 256 := by rw [show 2 * lenu.val * groups.val = 2 * (lenu.val * groups.val) from by ring, hlg] + rw [ntt_inverse_layer_eq_ok q layer lenu groups hlen_def hgroups_def + (by omega) (by omega) hlen128 hgroupsv + (fun i hi => Nat.div_lt_of_lt_mul (by omega)) + (fun i hi => ⟨fun hc => by + -- a-side: i % (2*len) < len ⇒ i + len < 256 + have hdm : i = 2 * lenu.val * (i / (2 * lenu.val)) + i % (2 * lenu.val) := + (Nat.div_add_mod i (2 * lenu.val)).symm + have hblk : i / (2 * lenu.val) < groups.val := Nat.div_lt_of_lt_mul (by omega) + have hbk1 : i / (2 * lenu.val) + 1 ≤ groups.val := by omega + have hmul : 2 * lenu.val * (i / (2 * lenu.val) + 1) ≤ 2 * lenu.val * groups.val := + Nat.mul_le_mul (Nat.le_refl _) hbk1 + rw [Nat.mul_add] at hmul; omega, + fun hc => by + have hle : i % (2 * lenu.val) ≤ i := Nat.mod_le _ _ + omega⟩) hq] + congr 1 + refine eq_of_zmod_lane_canon' _ _ (flat_layer_canon q lenu groups) + (invert_ntt_layer_4_plus_canon q zeta_i layer hq) ?_ + intro j hj + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj]), + List.getElem_map, List.getElem_range] + rw [spec_inv_layer_4_plus_lane q zeta_i layer step lenu.val hstep hlenv hstep_pos hdvd j hj] + by_cases hbr : j % (2 * lenu.val) < lenu.val + · rw [if_pos hbr, if_pos hbr] + · rw [if_neg hbr, if_neg hbr] + rw [zmodOfFE_mul_pure, zmodOfFE_mul_pure] + rw [hzeta j hj, ← zetas_bridge_zmod (zeta_i.val - 1 - (j / 16) / (2 * step)) (hzeta_lt j hj hbr)] + ring + +/-- Concrete cross-chunk-layer match, instantiating `ntt_inverse_layer_4_plus_match` + for a layer `N ∈ {4,5,6,7}` (given as `layer : Std.Usize` with `4 ≤ N ≤ 7`). + The Spec zeta-index `zeta_i = 128 / 2^N` and `step = 2^N / 16`. -/ +private theorem ntt_inverse_layer_4plus_match' + (q : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (layer zeta_i : Std.Usize) (N : Nat) + (hN_def : layer.val = N) (hN_lo : 4 ≤ N) (hN_hi : N ≤ 7) + (hzi : zeta_i.val = 256 / 2 ^ N) + (hq : ∀ j : Nat, j < 256 → Canonical (q.val[j]!)) : + hacspec_ml_kem.invert_ntt.ntt_inverse_layer q layer + = .ok (Spec.invert_ntt_layer_4_plus_pure q zeta_i layer) := by + obtain ⟨lenu, hlen_def, hlenv⟩ := shl_one_ok layer (by rw [hN_def]; exact numbits_ge N (by omega)) + have hlenval : lenu.val = 2 ^ N := by rw [hlenv, hN_def] + obtain ⟨groups, hgroups_def, hgroupsv⟩ := div128_ok lenu (by rw [hlenval]; positivity) + set step : Nat := 2 ^ N / 16 with hstep_def + have hstep_pos : 0 < step := by + rw [hstep_def]; have : 16 ≤ 2 ^ N := by calc 16 = 2 ^ 4 := by norm_num + _ ≤ 2 ^ N := Nat.pow_le_pow_right (by omega) hN_lo + omega + have h16dvd : (16 : Nat) ∣ 2 ^ N := by + calc (16 : Nat) = 2 ^ 4 := by norm_num + _ ∣ 2 ^ N := pow_dvd_pow 2 hN_lo + have hlenv16 : lenu.val = 16 * step := by + rw [hlenval, hstep_def]; exact (Nat.mul_div_cancel' h16dvd).symm + have hstep_eq : step = (1 <<< layer.val) / 16 := by + rw [hstep_def, hN_def, Nat.shiftLeft_eq, Nat.one_mul] + have hdvd : (2 * step) ∣ 16 := by + rw [hstep_def]; interval_cases N <;> decide + have hlen2 : 2 ≤ lenu.val := by rw [hlenval]; calc 2 = 2 ^ 1 := by norm_num + _ ≤ 2 ^ N := Nat.pow_le_pow_right (by omega) (by omega) + have hlen128 : lenu.val ≤ 128 := by + rw [hlenval]; calc 2 ^ N ≤ 2 ^ 7 := Nat.pow_le_pow_right (by omega) hN_hi + _ = 128 := by norm_num + have hgroupsv' : groups.val = 128 / lenu.val := by rw [hgroupsv, hlenval] + have hlg : lenu.val * groups.val = 128 := by + rw [hgroupsv', hlenval]; rw [Nat.mul_div_cancel'] + calc (2 : Nat) ^ N ∣ 2 ^ 7 := pow_dvd_pow 2 hN_hi + _ = 128 := by norm_num + have h2lenmax : 2 * lenu.val ≤ Std.Usize.max := by + have : 2 * lenu.val ≤ 256 := by omega + have h256 : (256 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + -- `zeta_i = 2 * groups` and `2 * lenu = 32 * step` + have hzi_eq : zeta_i.val = 2 * groups.val := by + rw [hzi, hgroupsv', hlenval] + have h128dvd : (2 : Nat) ^ N ∣ 128 := by + calc (2 : Nat) ^ N ∣ 2 ^ 7 := pow_dvd_pow 2 hN_hi + _ = 128 := by norm_num + -- `256 / 2^N = 2 * (128 / 2^N)` since `2^N ∣ 128` + rw [show (256 : Nat) = 2 * 128 from by norm_num, Nat.mul_div_assoc 2 h128dvd] + have h2len32 : 2 * lenu.val = 32 * step := by rw [hlenv16]; ring + rw [ntt_inverse_layer_4_plus_match q zeta_i layer lenu groups step + hlen_def hgroups_def hstep_eq hlenv16 hstep_pos hdvd hlen2 hlen128 hgroupsv' hlg + (fun i hi => by + rw [hzi_eq, h2len32] + have hnest : (i / 16) / (2 * step) = i / (32 * step) := by + rw [Nat.div_div_eq_div_mul]; ring_nf + rw [hnest]) + (fun i hi hbr => by + -- `zeta_i - 1 - (i/16)/(2*step) ≤ zeta_i - 1 < 128` since `zeta_i ≤ 128` + have h2g : 2 * groups.val ≤ lenu.val * groups.val := + Nat.mul_le_mul_right _ hlen2 + have hgle : 2 * groups.val ≤ 128 := by rw [hlg] at h2g; exact h2g + have hzle : zeta_i.val ≤ 128 := hzi_eq ▸ hgle + calc zeta_i.val - 1 - (i / 16) / (2 * step) + ≤ zeta_i.val - 1 := Nat.sub_le _ _ + _ ≤ 128 - 1 := Nat.sub_le_sub_right hzle 1 + _ < 128 := by omega) + hq] + +/-- **C2.** The hacspec `ntt_inverse_butterflies` (7-layer `let`-chain) matches the + Spec `invert_ntt_montgomery_pure` (the same 7 layers with zeta-thread + `128 → 64 → 32 → 16 → 8 → 4 → 2`), for canonical input. Each layer match is the + corresponding `_match` lemma; canonicity threads via the `_canon` lemmas. -/ +private theorem ntt_inverse_butterflies_eq_invert_pure + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hp : ∀ j : Nat, j < 256 → Canonical (p.val[j]!)) : + hacspec_ml_kem.invert_ntt.ntt_inverse_butterflies p + = .ok (Spec.invert_ntt_montgomery_pure p) := by + unfold hacspec_ml_kem.invert_ntt.ntt_inverse_butterflies Spec.invert_ntt_montgomery_pure + -- Layer 1 + rw [ntt_inverse_layer_1_match p hp]; simp only [bind_tc_ok] + have hc1 : CanonArr (Spec.invert_ntt_layer_1_pure p 128#usize) := + invert_ntt_layer_1_canon p 128#usize hp + -- Layer 2 + rw [ntt_inverse_layer_2_match _ hc1]; simp only [bind_tc_ok] + have hc2 : CanonArr (Spec.invert_ntt_layer_2_pure (Spec.invert_ntt_layer_1_pure p 128#usize) 64#usize) := + invert_ntt_layer_2_canon _ 64#usize hc1 + -- Layer 3 + rw [ntt_inverse_layer_3_match _ hc2]; simp only [bind_tc_ok] + have hc3 : CanonArr (Spec.invert_ntt_layer_3_pure _ 32#usize) := + invert_ntt_layer_3_canon _ 32#usize hc2 + -- Layer 4 (cross-chunk; zeta_i = 16, step = 1) + rw [ntt_inverse_layer_4plus_match' _ 4#usize 16#usize 4 rfl (by omega) (by omega) (by decide) hc3] + simp only [bind_tc_ok] + have hc4 : CanonArr (Spec.invert_ntt_layer_4_plus_pure _ 16#usize 4#usize) := + invert_ntt_layer_4_plus_canon _ 16#usize 4#usize hc3 + -- Layer 5 (zeta_i = 8, step = 2) + rw [ntt_inverse_layer_4plus_match' _ 5#usize 8#usize 5 rfl (by omega) (by omega) (by decide) hc4] + simp only [bind_tc_ok] + have hc5 : CanonArr (Spec.invert_ntt_layer_4_plus_pure _ 8#usize 5#usize) := + invert_ntt_layer_4_plus_canon _ 8#usize 5#usize hc4 + -- Layer 6 (zeta_i = 4, step = 4) + rw [ntt_inverse_layer_4plus_match' _ 6#usize 4#usize 6 rfl (by omega) (by omega) (by decide) hc5] + simp only [bind_tc_ok] + have hc6 : CanonArr (Spec.invert_ntt_layer_4_plus_pure _ 4#usize 6#usize) := + invert_ntt_layer_4_plus_canon _ 4#usize 6#usize hc5 + -- Layer 7 (zeta_i = 2, step = 8) + rw [ntt_inverse_layer_4plus_match' _ 7#usize 2#usize 7 rfl (by omega) (by omega) (by decide) hc6] + +end InvertButterfliesC2 + +theorem ntt_inverse_eq_scaleZ_invert_pure + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hp : ∀ j : Nat, j < 256 → + libcrux_iot_ml_kem.Spec.Pure.Canonical (p.val[j]!)) : + hacspec_ml_kem.invert_ntt.ntt_inverse p + = .ok (scaleZ 3303 (Spec.invert_ntt_montgomery_pure p)) := by + have hcanon : ∀ j : Nat, j < 256 → + libcrux_iot_ml_kem.Spec.Pure.Canonical + ((Spec.invert_ntt_montgomery_pure p).val[j]!) := by + unfold Spec.invert_ntt_montgomery_pure + exact invert_ntt_layer_4_plus_canon _ 2#usize 7#usize + (invert_ntt_layer_4_plus_canon _ 4#usize 6#usize + (invert_ntt_layer_4_plus_canon _ 8#usize 5#usize + (invert_ntt_layer_4_plus_canon _ 16#usize 4#usize + (invert_ntt_layer_3_canon _ 32#usize + (invert_ntt_layer_2_canon _ 64#usize + (invert_ntt_layer_1_canon p 128#usize hp)))))) + exact ntt_inverse_reduce_eq p _ hcanon (ntt_inverse_butterflies_eq_invert_pure p hp) + +/-! ## D (hacspec side) — `sub_polynomials a (scaleZ 512 b) ≡ subtract_reduce_pure a b`. + + `Bridges.zmodOfFE_subtract_reduce_pure_lane` gives, for canonical `a[j]`, + `zmodOfFE (subtract_reduce_pure a b)[j] = zmodOfFE a[j] - 512 * zmodOfFE b[j]`. + `sub_polynomials a c` is per-lane `sub_pure a[j] c[j]`; with `c = scaleZ 512 b`, + `zmodOfFE c[j] = 512 * zmodOfFE b[j]` (`scaleZ_lane`), so both sides have the + same canonical lanes. -/ +section SubPolyScaleZ + +open libcrux_iot_ml_kem.Spec.Pure (Canonical) +/-- `feOfZMod z` is canonical (local copy; the `InvertScaleZ` one is `private`). -/ +private theorem canon_feOfZMod' (z : ZMod 3329) : Canonical (feOfZMod z) := by + unfold Canonical feOfZMod hacspec_ml_kem.parameters.FIELD_MODULUS + show (BitVec.ofNat 16 z.val).toNat < _ + rw [BitVec.toNat_ofNat] + have hz : z.val < 3329 := ZMod.val_lt z + have : z.val % 2 ^ 16 = z.val := Nat.mod_eq_of_lt (by omega) + simp only [this]; simpa using hz + +/-- `scaleZ c p` is canonical per lane. -/ +private theorem canonArr_scaleZ' (c : ZMod 3329) + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (j : Nat) (hj : j < 256) : Canonical ((scaleZ c p).val[j]!) := by + unfold scaleZ + rw [mkN_map_lane' (fun k => feOfZMod (c * zmodOfFE (p.val[k]!))) j hj _] + exact canon_feOfZMod' _ + +/-- Canonical round-trip (local copy). -/ +private theorem feOfZMod_zmodOfFE_of_canon' + (fe : hacspec_ml_kem.parameters.FieldElement) (h : Canonical fe) : + feOfZMod (zmodOfFE fe) = fe := by + have h' : fe.val.val < 3329 := by + unfold Canonical hacspec_ml_kem.parameters.FIELD_MODULUS at h; simpa using h + unfold feOfZMod zmodOfFE + have hzval : ((fe.val.val : ZMod 3329)).val = fe.val.val := ZMod.val_natCast_of_lt h' + rw [hzval] + have hfeval : fe.val.val < 2 ^ 16 := by + have h_p : (3329 : Nat) ≤ 2 ^ 16 := by decide + omega + have hfebv : BitVec.ofNat 16 fe.val.val = fe.val.bv := by + apply BitVec.eq_of_toNat_eq + rw [BitVec.toNat_ofNat] + show fe.val.val % 2 ^ 16 = fe.val.bv.toNat + rw [Nat.mod_eq_of_lt hfeval]; rfl + show ({ val := ⟨BitVec.ofNat 16 fe.val.val⟩ } : + hacspec_ml_kem.parameters.FieldElement) = fe + rw [hfebv] + +/-- Two canonical 256-arrays with equal `zmodOfFE` lanes are equal. -/ +private theorem eq_of_zmod_lane_canon + (u v : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hcu : ∀ j : Nat, j < 256 → Canonical (u.val[j]!)) + (hcv : ∀ j : Nat, j < 256 → Canonical (v.val[j]!)) + (hz : ∀ j : Nat, j < 256 → zmodOfFE (u.val[j]!) = zmodOfFE (v.val[j]!)) : + u = v := by + apply Subtype.ext + apply List.ext_getElem + · rw [Aeneas.Std.Array.length_eq u, Aeneas.Std.Array.length_eq v] + · intro j hj1 _hj2 + have hj : j < 256 := by rw [Aeneas.Std.Array.length_eq u] at hj1; simpa using hj1 + have heq : u.val[j]! = v.val[j]! := by + rw [← feOfZMod_zmodOfFE_of_canon' (u.val[j]!) (hcu j hj), + ← feOfZMod_zmodOfFE_of_canon' (v.val[j]!) (hcv j hj), hz j hj] + have huj : u.val[j]! = u.val[j] := + getElem!_pos u.val j (by rw [Aeneas.Std.Array.length_eq u]; exact hj) + have hvj : v.val[j]! = v.val[j] := + getElem!_pos v.val j (by rw [Aeneas.Std.Array.length_eq v]; exact hj) + rw [← huj, ← hvj]; exact heq + +-- The 8-bind monadic do-block in the `sub_polynomials` closure needs a deeper +-- elaboration recursion limit than the default (512). Mirrors the createi/from_fn +-- proof family in `FCTargets.lean` (e.g. `set_option maxRecDepth 4000 in` at the +-- ntt closure proofs); the `add_polynomials` template stays under 512 only because +-- its closure body is one bind shorter. +-- `createi_pure_eq` over a 2-tuple closure state `(a, c)` still needs a deeper +-- recursion limit for the `parameters.createi`→`from_fn` defeq (sanctioned +-- createi exception); the single-state `reduce_polynomial_eq_ok` does not. +set_option maxRecDepth 4000 in +set_option maxHeartbeats 1000000 in +/-- `sub_polynomials a c` reduces to the per-lane `sub_pure` array, given `a` and + `c` canonical (the closure body is byte-identical to `FieldElement.sub`, so + `sub_eq_ok` applies). Mirrors `matrix_add_polynomials_eq_ok` (FCTargets). -/ +private theorem sub_polynomials_eq_ok + (a c : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (ha : ∀ k : Nat, k < 256 → Canonical (a.val[k]!)) + (hc : ∀ k : Nat, k < 256 → Canonical (c.val[k]!)) : + hacspec_ml_kem.matrix.sub_polynomials a c + = .ok ⟨(List.range 256).map (fun k => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (a.val[k]!) (c.val[k]!)), + by simp [List.length_map, List.length_range]⟩ := by + set f : Nat → hacspec_ml_kem.parameters.FieldElement := + fun k => libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (a.val[k]!) (c.val[k]!) with hf_def + have hpure : ∀ k : Nat, k < (256#usize : Std.Usize).val → + (hacspec_ml_kem.matrix.sub_polynomials.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement + : CoreModels.core.ops.function.Fn _ _ _).FnMutInst.call_mut + (a, c) ⟨BitVec.ofNat _ k⟩ + = .ok (f k, (a, c)) := by + intro k hk + have hk' : k < 256 := hk + show hacspec_ml_kem.matrix.sub_polynomials.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + (a, c) ⟨BitVec.ofNat _ k⟩ = .ok (f k, (a, c)) + unfold hacspec_ml_kem.matrix.sub_polynomials.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + unfold hacspec_ml_kem.matrix.sub_polynomials.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement.call + have hk_us : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := by + show (BitVec.ofNat _ k).toNat = k + apply Nat.mod_eq_of_lt + have : k < 2^System.Platform.numBits := by + have hbits : 2^16 ≤ 2^System.Platform.numBits := + Nat.pow_le_pow_right (by decide) (by + cases System.Platform.numBits_eq with + | inl h => rw [h]; decide + | inr h => rw [h]; decide) + omega + exact this + have ha_len : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val < a.length := by + rw [hk_us]; show k < a.val.length + rw [a.property]; exact hk + have hc_len : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val < c.length := by + rw [hk_us]; show k < c.val.length + rw [c.property]; exact hk + have h_a_idx : + Std.Array.index_usize a (⟨BitVec.ofNat _ k⟩ : Std.Usize) + = .ok (a.val[k]!) := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a + (⟨BitVec.ofNat _ k⟩ : Std.Usize) ha_len + rw [hk_us] at this; exact this + have h_c_idx : + Std.Array.index_usize c (⟨BitVec.ofNat _ k⟩ : Std.Usize) + = .ok (c.val[k]!) := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq c + (⟨BitVec.ofNat _ k⟩ : Std.Usize) hc_len + rw [hk_us] at this; exact this + have h_sub := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_eq_ok + (a.val[k]!) (c.val[k]!) (ha k hk') (hc k hk') + change (do + let fe ← (do + let fe ← Std.Array.index_usize a ⟨BitVec.ofNat _ k⟩ + let i ← lift (Std.UScalar.cast .U32 fe.val) + let i1 ← lift (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS) + let i2 ← i + i1 + let fe1 ← Std.Array.index_usize c ⟨BitVec.ofNat _ k⟩ + let i3 ← lift (Std.UScalar.cast .U32 fe1.val) + let i4 ← i2 - i3 + let i5 ← lift (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS) + let i6 ← i4 % i5 + let i7 ← lift (Std.UScalar.cast .U16 i6) + hacspec_ml_kem.parameters.FieldElement.new i7) + Result.ok (fe, a, c)) = Result.ok (f k, a, c) + rw [h_a_idx]; simp only [bind_tc_ok] + rw [h_c_idx]; simp only [bind_tc_ok] + unfold hacspec_ml_kem.parameters.FieldElement.sub at h_sub + rw [h_sub] + simp only [bind_tc_ok, hf_def] + unfold hacspec_ml_kem.matrix.sub_polynomials + exact libcrux_iot_ml_kem.Util.CreateI.createi_pure_eq 256#usize + hacspec_ml_kem.matrix.sub_polynomials.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement + (a, c) f hpure + +set_option maxHeartbeats 1000000 in +theorem sub_polynomials_scaleZ_eq + (a b : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (ha : ∀ j : Nat, j < 256 → + libcrux_iot_ml_kem.Spec.Pure.Canonical (a.val[j]!)) : + hacspec_ml_kem.matrix.sub_polynomials a (scaleZ 512 b) + = .ok (Spec.subtract_reduce_pure a b) := by + have hc : ∀ k : Nat, k < 256 → Canonical ((scaleZ 512 b).val[k]!) := + fun k hk => canonArr_scaleZ' 512 b k hk + rw [sub_polynomials_eq_ok a (scaleZ 512 b) ha hc] + -- The reduced LHS array (set L); show it equals `subtract_reduce_pure a b`. + set L : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + ⟨(List.range 256).map (fun k => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (a.val[k]!) ((scaleZ 512 b).val[k]!)), + by simp [List.length_map, List.length_range]⟩ with hL_def + have hL_lane : ∀ j : Nat, j < 256 → + L.val[j]! = libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (a.val[j]!) ((scaleZ 512 b).val[j]!) := by + intro j hj + show ((List.range 256).map (fun k => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (a.val[k]!) ((scaleZ 512 b).val[k]!)))[j]! = _ + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj])] + rw [List.getElem_map, List.getElem_range] + congr 1 + apply eq_of_zmod_lane_canon + · -- L lanes canonical + intro j hj + rw [hL_lane j hj] + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_sub_pure _ _ (ha j hj) (hc j hj) + · -- subtract_reduce_pure lanes canonical + intro j hj + have hℓ : j % 16 < 16 := Nat.mod_lt _ (by decide) + have hjeq : 16 * (j / 16) + j % 16 = j := by omega + unfold Spec.subtract_reduce_pure + rw [flatten_chunk_map_lane (fun k => Spec.chunk_subtract_reduce_pure + (Spec.chunk_at a k) (Spec.chunk_at b k)) j hj (by simp)] + unfold Spec.chunk_subtract_reduce_pure + rw [mkN_map_lane' _ (j % 16) hℓ] + rw [chunk_at_lane' a (j / 16) (j % 16) hℓ, hjeq] + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_sub_pure _ _ (ha j hj) + (libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure _ _) + · -- per-lane zmodOfFE equality + intro j hj + rw [hL_lane j hj] + rw [zmodOfFE_sub_pure _ _ (ha j hj) (hc j hj)] + rw [scaleZ_lane 512 b j hj] + rw [zmodOfFE_subtract_reduce_pure_lane a b j hj (ha j hj)] + +end SubPolyScaleZ + +end libcrux_iot_ml_kem.Matrix.ComputeMessage.Hacspec \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeMessage/Impl.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeMessage/Impl.lean new file mode 100644 index 00000000..64ab8069 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeMessage/Impl.lean @@ -0,0 +1,1329 @@ +/- + # `Matrix/ComputeMessage/Impl.lean` — L7.4 S1 loop FC. + + Houses the S1 loop FC for `matrix.compute_message_loop`: the cache-free + analog of `compute_As_plus_e_loop0_fc`. The loop folds over + `i ∈ [0, K)`, each iteration applying + `accumulating_ntt_multiply secret_as_ntt[i] u_as_ntt[i] acc` into a + single I32[256] accumulator (no cache, no matrix). The POST is the + K-fold of the per-step POST of `accumulating_ntt_multiply_poly_fc`: a + `mont_reduce_pure`-per-lane `List.range`-foldl characterization plus + the running bound `≤ accumulator[n] + k·2^25`. + + Mirrors `Stage2UseCacheFC.row_i_inv` / `compute_As_plus_e_loop1_loop0_fc` — the + 2-conjunct, accumulator-only (no cache) precedent — but drops the + matrix-row index (two source arrays `secret_as_ntt[c]`, `u_as_ntt[c]` + indexed directly by `c`), and uses the plain + `accumulating_ntt_multiply_poly_fc` per-step lemma instead of the + `_use_cache_poly_fc` variant. +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeAsPlusE +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeMessage.Bridges + +namespace libcrux_iot_ml_kem.Matrix.ComputeMessage.Impl +open libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeMessage.Bridges +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeAsPlusE libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt + +/-- Local copy of FCTargets' `private triple_exists_ok_fc`: a `True`-pre + Triple yielding `.ok` with the post is equivalent to an existential + `.ok` witness. The L7 files cannot see the private original, so this is + re-derived from the public `Std.Do.Triple`/`WP.wp` unfolding. -/ +private theorem triple_exists_ok_fc {α : Type} {x : Result α} {P : α → Prop} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) : + ∃ v, x = .ok v ∧ P v := by + match hx : x with + | .ok v => exact ⟨v, rfl, (by subst hx; simpa [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h)⟩ + | .fail _ => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +/-- Local copy of FCTargets' `private triple_of_ok_fc`: a `True`-pre Triple + follows from an `.ok` reduction plus the post on the witness. -/ +private theorem triple_of_ok_fc {α : Type} {x : Result α} {v : α} + {P : α → Prop} (hx : x = .ok v) (hp : P v) : + ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄ := by + subst hx; simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +/-! ## S1 — the corrected loop FC for `matrix.compute_message_loop`. + + Namespace `S1LoopFC` provides the loop invariant + step-post predicates + used to characterize `matrix.compute_message_loop` via + `loop_range_spec_usize`. The accumulator state is a single I32[256] + array (no cache) — `Acc := Std.Array Std.I32 256#usize`. -/ + +namespace S1LoopFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +abbrev Acc := UseCacheFC.Acc + +/-- 2-conjunct invariant for the message-accumulation loop. Tracks: + (1) accumulator characterization: for each (chunk j, lane ℓ) in + `[0, 16)²`, `Spec.mont_reduce_pure (lift_fe_int acc[16j+ℓ].val)` + equals init plus the canonical-form sum of column contributions + `secret_as_ntt[c] ⊛ u_as_ntt[c]` from columns `[0, k)`. + (2) accumulator bound: `|acc.val[n]| ≤ |acc_init.val[n]| + k · 2^25`. + + Cache-free analog of `Stage2UseCacheFC.row_i_inv` with the + matrix-row index dropped: both source arrays are indexed directly by + the column `c`. -/ +def loop_inv {K : Std.Usize} + (secret_as_ntt u_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : Acc) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + -- (1) Per-(chunk j, lane ℓ) accumulator: canonical-form k-column sum. + (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = (List.range k.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (secret_as_ntt.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (u_as_ntt.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val))) + -- (2) Accumulator bound grows by 2^25 per column iteration. + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + k.val * 2^25)) + +/-- Step-post for `loop_range_spec_usize` over the accumulator only. -/ +def step_post {K : Std.Usize} + (secret_as_ntt u_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : Acc) (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : + Prop := + match r with + | .cont (iter', acc') => + k.val < K.val ∧ iter'.«end» = K + ∧ iter'.start.val = k.val + 1 + ∧ (loop_inv secret_as_ntt u_as_ntt acc_init iter'.start acc').holds + | .done y => (loop_inv secret_as_ntt u_as_ntt acc_init K y).holds + +end S1LoopFC + +-- Memory hygiene (rule 1 / SKILL §5.7 Idiom 2). Mirrors `L7_1b_irreducible` +-- — heavy POST predicates are made locally irreducible +-- across the step lemma + outer Triple so that elaboration does not +-- whnf-explode through the 2-conjunct `loop_inv` body or the nested +-- `∀ j, ∀ ℓ` accumulator characterization. -- we do NOT mark +-- `S1LoopFC.loop_inv` / `step_post` irreducible. +section L7_4_irreducible +attribute [local irreducible] accumulating_ntt_multiply_poly_post +attribute [local irreducible] Spec.ntt_multiply_pure_no_acc +attribute [local irreducible] Spec.mont_reduce_pure + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for the message-accumulation loop. Given the + `loop_inv` invariant at step k and the strengthened PRE bounds, executing + one body iteration of `matrix.compute_message_loop.body` produces the + `step_post` (either `.cont` advancing the invariant to k+1 or `.done` + capping at K). + + Mirrors `compute_As_plus_e_loop1_loop0_step_lemma_fc` + but cache-free: no `matrix.entry` (read `secret_as_ntt[k]`, `u_as_ntt[k]` + directly), no cache read, and the per-column forward dep is + `accumulating_ntt_multiply_poly_fc` which returns a + single accumulator (no pair). -/ +private theorem compute_message_loop_step_lemma_fc + {K : Std.Usize} + (secret_as_ntt u_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : S1LoopFC.Acc) + (h_secret_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((secret_as_ntt.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_u_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((u_as_ntt.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, + (acc_init.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30) + (acc : S1LoopFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ K.val) + (h_inv : (S1LoopFC.loop_inv secret_as_ntt u_as_ntt acc_init k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_message_loop.body + (vectortraitsOperationsInst := portable_ops_inst) secret_as_ntt u_as_ntt + { start := k, «end» := K } acc + ⦃ ⇓ r => ⌜ S1LoopFC.step_post secret_as_ntt u_as_ntt acc_init k r ⌝ ⦄ := by + have h_secret_len : secret_as_ntt.length = K.val := Std.Array.length_eq secret_as_ntt + have h_u_len : u_as_ntt.length = K.val := Std.Array.length_eq u_as_ntt + have h_acc_len : acc.length = 256 := Std.Array.length_eq acc + have h_acc_init_len : acc_init.length = 256 := Std.Array.length_eq acc_init + -- Destructure the 2-conjunct invariant. + obtain ⟨h_inv_acc, h_inv_acc_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.matrix.compute_message_loop.body + by_cases h_lt : k.val < K.val + · -- `Some k` branch. + -- (1) IteratorRange.next reduces to .ok (some k, { start := s_iter, end := K }). + have h_iter_step : + ⦃ ⌜ True ⌝ ⦄ + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ ∃ s : Std.Usize, s.val = k.val + 1 ∧ + r = (some k, + ({ start := s, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_step + obtain ⟨s_iter, hs_iter_val, hv_iter_pair⟩ := hv_iter_post + -- (2) Array.index_usize secret_as_ntt k reduces to .ok secret_as_ntt[k.val]!. + set t_secret : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + secret_as_ntt.val[k.val]! with ht_secret_def + have h_idx_secret : Aeneas.Std.Array.index_usize secret_as_ntt k = .ok t_secret := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq secret_as_ntt k + (by rw [h_secret_len]; exact h_lt) + -- (3) Array.index_usize u_as_ntt k reduces to .ok u_as_ntt[k.val]!. + set t_u : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + u_as_ntt.val[k.val]! with ht_u_def + have h_idx_u : Aeneas.Std.Array.index_usize u_as_ntt k = .ok t_u := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq u_as_ntt k + (by rw [h_u_len]; exact h_lt) + -- (4) Apply L6.3 per-column forward dep at column k. + have h_t_secret_bnd : ∀ a : Fin 16, ∀ b : Fin 16, + ((t_secret.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328 := + fun a b => h_secret_bnd ⟨k.val, h_lt⟩ a b + have h_t_u_bnd : ∀ a : Fin 16, ∀ b : Fin 16, + ((t_u.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328 := + fun a b => h_u_bnd ⟨k.val, h_lt⟩ a b + -- Current acc bound ≤ 2^30: combine inv conjunct (2) with budget PRE. + have h_acc_cur_bnd : ∀ n : Fin 256, (acc.val[n.val]!).val.natAbs ≤ 2^30 := by + intro n + have hb := h_inv_acc_bnd n.val n.isLt + have hp := h_acc_bnd n + have hk_le : k.val * 2^25 ≤ K.val * 2^25 := Nat.mul_le_mul_right _ h_le + omega + obtain ⟨acc1, h_acc1_eq, h_acc1_bnd_rel, h_acc1_post⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_poly_fc t_secret t_u acc + h_t_secret_bnd h_t_u_bnd h_acc_cur_bnd) + -- (5) Body equation. + have h_body : + libcrux_iot_ml_kem.matrix.compute_message_loop.body + (vectortraitsOperationsInst := portable_ops_inst) secret_as_ntt u_as_ntt + { start := k, «end» := K } acc + = .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), acc1)) := by + unfold libcrux_iot_ml_kem.matrix.compute_message_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_pair] at hv_iter_eq + rw [hv_iter_eq] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let pre ← Aeneas.Std.Array.index_usize secret_as_ntt k + let pre1 ← Aeneas.Std.Array.index_usize u_as_ntt k + let accumulator1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply + portable_ops_inst pre pre1 acc + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), accumulator1))) + : Result _) = _ + rw [h_idx_secret] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_u] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_acc1_eq] + rfl + apply triple_of_ok_fc h_body + -- (6) Discharge the step_post. + show S1LoopFC.step_post secret_as_ntt u_as_ntt acc_init k + (.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), acc1)) + refine ⟨h_lt, rfl, hs_iter_val, ?_⟩ + -- (7) Re-establish `loop_inv` at s_iter (= k+1). + show (S1LoopFC.loop_inv secret_as_ntt u_as_ntt acc_init s_iter acc1).holds + unfold S1LoopFC.loop_inv + have h_inv_pure : + (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = (List.range s_iter.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (secret_as_ntt.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (u_as_ntt.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val))) + ∧ (∀ n : Nat, n < 256 → + (acc1.val[n]!).val.natAbs + ≤ (acc_init.val[n]!).val.natAbs + s_iter.val * 2^25) := by + refine ⟨?_, ?_⟩ + · -- (a) Accumulator characterization at s_iter = k+1. + intro j hj ℓ hℓ + have h_step_acc : + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (t_secret.coefficients.val[j]!)) + (lift_chunk_mont (t_u.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!) := by + have := h_acc1_post + unfold accumulating_ntt_multiply_poly_post at this + exact this j hj ℓ hℓ + have h_ih := h_inv_acc j hj ℓ hℓ + rw [h_step_acc, h_ih] + have hs_iter_eq : s_iter.val = k.val + 1 := hs_iter_val + rw [hs_iter_eq] + rw [List.range_succ, List.foldl_append] + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((List.range k.val).foldl _ _) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (t_secret.coefficients.val[j]!)) + (lift_chunk_mont (t_u.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!) + = (List.foldl _ ((List.range k.val).foldl _ _) [k.val]) + show _ = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((List.range k.val).foldl _ _) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (secret_as_ntt.val[k.val]!.coefficients.val[j]!)) + (lift_chunk_mont (u_as_ntt.val[k.val]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!) + rfl + · -- (b) Bound: ≤ acc_init[n] + s_iter.val * 2^25. + intro n hn + have h_acc1_bnd_n := h_acc1_bnd_rel ⟨n, hn⟩ + have h_acc1_bnd_n' : (acc1.val[n]!).val.natAbs ≤ (acc.val[n]!).val.natAbs + 2^25 := + h_acc1_bnd_n + have h_inv_n := h_inv_acc_bnd n hn + have hs_iter_eq : s_iter.val = k.val + 1 := hs_iter_val + rw [hs_iter_eq] + have h_arith : (k.val + 1) * 2^25 = k.val * 2^25 + 2^25 := by ring + rw [h_arith] + linarith [h_acc1_bnd_n', h_inv_n] + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ K, done. + have hk_ge : k.val ≥ K.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = K.val := by omega + have h_iter_none : + ⦃ ⌜ True ⌝ ⦄ + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ r = ((none : Option Std.Usize), + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun hlt => absurd hlt (Nat.not_lt.mpr hk_ge)) + (fun _ => by dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_none + have h_body : + libcrux_iot_ml_kem.matrix.compute_message_loop.body + (vectortraitsOperationsInst := portable_ops_inst) secret_as_ntt u_as_ntt + { start := k, «end» := K } acc + = .ok (ControlFlow.done acc) := by + unfold libcrux_iot_ml_kem.matrix.compute_message_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_post] at hv_iter_eq + rw [hv_iter_eq] + rfl + apply triple_of_ok_fc h_body + show S1LoopFC.step_post secret_as_ntt u_as_ntt acc_init k (.done acc) + show (S1LoopFC.loop_inv secret_as_ntt u_as_ntt acc_init K acc).holds + unfold S1LoopFC.loop_inv + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = (List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (secret_as_ntt.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (u_as_ntt.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val))) + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs + ≤ (acc_init.val[n]!).val.natAbs + K.val * 2^25) := by + refine ⟨?_, ?_⟩ + · intro j hj ℓ hℓ + have h_eq := h_inv_acc j hj ℓ hℓ + have h_rng : (List.range k.val) = (List.range K.val) := by rw [hk_eq] + rw [h_rng] at h_eq + exact h_eq + · intro n hn + have h_b := h_inv_acc_bnd n hn + have h_arith : k.val * 2^25 = K.val * 2^25 := by rw [hk_eq] + rw [h_arith] at h_b + exact h_b + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +/-- L7.4 S1 — `matrix.compute_message_loop`: the message-accumulation loop. + Iterates over `i ∈ [0, K)`, accumulating column-i's contribution + `secret_as_ntt[i] ⊛ u_as_ntt[i]` to the single I32[256] accumulator via + `accumulating_ntt_multiply` (NO cache, NO matrix). The cache-free analog + of `compute_As_plus_e_loop0`. + + POST: `loop_inv` holds at k = K, i.e. for all (j, ℓ) ∈ [0, 16)²: + `mont_reduce_pure (lift_fe_int acc2[16j+ℓ].val)` equals the K-column + canonical-form sum of `ntt_multiply_pure_no_acc` outputs over + `(secret_as_ntt[c], u_as_ntt[c])`, starting from the initial + accumulator's `mont_reduce_pure` lift, AND the running bound + `≤ accumulator[n] + K·2^25`. + + PRE: the standard 16×16 bound (3328) on `secret_as_ntt`/`u_as_ntt` + entries plus the additive accumulator BUDGET + `(accumulator[n]).val.natAbs + K·2^25 ≤ 2^30`. The budget is consumed by + the per-column forward dep (`accumulating_ntt_multiply_poly_fc`, PRE + `≤ 2^30`) at every iteration: the running accumulator satisfies + `acc[n] ≤ accumulator[n] + k·2^25 ≤ accumulator[n] + K·2^25 ≤ 2^30`. + + Mirrors `compute_As_plus_e_loop1_loop0_fc` cache-free, + with two source arrays indexed directly by the column. -/ +@[spec] +theorem compute_message_loop_fc + {K : Std.Usize} + (secret_as_ntt u_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (accumulator : Std.Array Std.I32 256#usize) + (h_secret_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((secret_as_ntt.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_u_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((u_as_ntt.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, + (accumulator.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_message_loop + (vectortraitsOperationsInst := portable_ops_inst) + { start := 0#usize, «end» := K } secret_as_ntt u_as_ntt accumulator + ⦃ ⇓ acc2 => ⌜ (S1LoopFC.loop_inv secret_as_ntt u_as_ntt accumulator K acc2).holds ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.matrix.compute_message_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.matrix.compute_message_loop.body + (vectortraitsOperationsInst := portable_ops_inst) secret_as_ntt u_as_ntt + iter1 acc1) + (β := S1LoopFC.Acc) + accumulator + 0#usize K + (fun k acc => S1LoopFC.loop_inv secret_as_ntt u_as_ntt accumulator k acc) + (by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + rw [h0]; exact Nat.zero_le _) + (by + -- Base case at k = 0. + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_⟩ + · intro j hj ℓ hℓ + show Spec.mont_reduce_pure _ + = (List.range (0#usize : Std.Usize).val).foldl _ _ + have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0'] + show Spec.mont_reduce_pure _ = (List.range 0).foldl _ _ + simp [List.range_zero, List.foldl_nil] + · intro n _; have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0']; omega) + ?_) + · -- Post entailment: the final invariant holds at K. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (S1LoopFC.loop_inv secret_as_ntt u_as_ntt accumulator K r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + show (S1LoopFC.loop_inv secret_as_ntt u_as_ntt accumulator K r).holds + exact h_inv_holds + · -- Step entailment. + intro acc k _h_ge h_le hinv + have h_step := compute_message_loop_step_lemma_fc + secret_as_ntt u_as_ntt accumulator h_secret_bnd h_u_bnd h_acc_bnd acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : S1LoopFC.step_post secret_as_ntt u_as_ntt accumulator k + (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [S1LoopFC.step_post] using hP + · have hP : S1LoopFC.step_post secret_as_ntt u_as_ntt accumulator k + (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [S1LoopFC.step_post] using hP + +end L7_4_irreducible + +/-! ## A — the acc-bridge ("crux"): R-factor reconciliation. + + Relates the hacspec `multiply_vectors` (on `lift_vec`-lifted inputs) to + the loop accumulator's reduced value, scaled by `R = 2285` + (`multiply_vectors = 2285 · result1`, where `result1 = reducing(acc2)`). + The RHS `mont_strip_pure (poly_reducing(to_slice acc2))` equals + `lift_poly result1` via the proven `reducing_from_i32_array_fc` POST + (in the `lift_poly_mont` domain) composed with + `Impl.mont_strip_lift_poly_mont_eq_lift_poly`. -/ + +/-- Per-lane R-factor reconciliation. + + For any 256-FE array `p` and lane `j < 256`, + `scaleZ 2285 (mont_strip_pure p)` and + `mul_pure (p[j]) (mul_pure 1353 1353)` agree in `ZMod 3329`. The + factor identity unpacks as: the `mont_strip` + factor `zmodOfFE (lift_fe_mont 1353) = 1353·169 = 2285 = R`, the + `multiply_ntts`-canonical factor `mul_pure 1353 1353 = 2285² = 1353`, + and `2285 · 2285 = 1353` in `ZMod 3329`). + + This bridges the hacspec `multiply_vectors` lane + (`= mul_pure (loop_inv-foldl-sum) (mul_pure 1353 1353)`, via + `Spec.multiply_ntts_pure_eq_chunked_no_acc` + `foldl_add_mul_distrib`) + to `scaleZ 2285 (mont_strip_pure (poly_reducing(acc2)))` lane + (`= mul_pure (loop_inv-foldl-sum) (lift_fe_mont 1353)` then `·2285`), + where the shared foldl-sum is exactly S1's `S1LoopFC.loop_inv` + characterization. -/ +theorem compute_message_recon_lane + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (j : Nat) (hj : j < 256) : + zmodOfFE ((scaleZ 2285 (Impl.mont_strip_pure p)).val[j]!) + = zmodOfFE (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (p.val[j]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (1353#i16 : Std.I16)) (lift_fe_mont (1353#i16 : Std.I16)))) := by + rw [scaleZ_lane 2285 _ j hj] + unfold Impl.mont_strip_pure + have hms : ((Std.Array.make 256#usize + ((List.range 256).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (p.val[i]!) (lift_fe_mont (1353#i16 : Std.I16)))) (by simp)).val[j]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (p.val[j]!) (lift_fe_mont (1353#i16 : Std.I16)) := by + show ((List.range 256).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (p.val[i]!) (lift_fe_mont (1353#i16 : Std.I16))))[j]! = _ + rw [getElem!_pos _ j (by simp [hj])] + rw [List.getElem_map, List.getElem_range] + rw [hms] + rw [zmodOfFE_mul_pure, zmodOfFE_mul_pure, zmodOfFE_mul_pure, zmodOfFE_lift_fe_mont] + have h1353 : (((1353#i16 : Std.I16).val : Int) : ZMod 3329) = 1353 := by decide + rw [h1353] + have hc : (2285 : ZMod 3329) * (1353 * 169) = 1353 * 169 * (1353 * 169) := by decide + rw [show (2285 : ZMod 3329) * (zmodOfFE (p.val[j]!) * (1353 * 169)) + = zmodOfFE (p.val[j]!) * ((2285 : ZMod 3329) * (1353 * 169)) from by ring] + rw [hc] + +/-! ### `multiply_vectors` loop reduction (mirror of + `multiply_matrix_by_column_at_eq`,, but cache/matrix-free: + the two source vectors are indexed directly by the column `j`). -/ + +/-- Per-lane partial sum produced by the `multiply_vectors` loop at step `k`: + the `add_pure` foldl of the per-column `multiply_ntts_pure` lane values, + seeded at the zero FE. Mirrors `col_loop_lane_at_step` + but folds the raw `multiply_ntts_pure` lane (the loop body adds the + `multiply_ntts` product directly — no pre-applied canonical factor). -/ +private noncomputable def vec_loop_lane_at_step {K : Std.Usize} + (secret_as_ntt u_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (k : Nat) (ℓ : Nat) : hacspec_ml_kem.parameters.FieldElement := + (List.range k).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.multiply_ntts_pure + (lift_poly secret_as_ntt.val[c]!) (lift_poly u_as_ntt.val[c]!)).val[ℓ]!)) + ({ val := 0#u16 } : hacspec_ml_kem.parameters.FieldElement) + +/-- The per-step `multiply_vectors` accumulator array: lane ℓ is + `vec_loop_lane_at_step ... k ℓ`. -/ +private noncomputable def vec_loop_result_at_step {K : Std.Usize} + (secret_as_ntt u_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (k : Nat) : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + ⟨(List.range 256).map (fun ℓ => vec_loop_lane_at_step secret_as_ntt u_as_ntt k ℓ), + by simp [List.length_map, List.length_range]⟩ + +private theorem vec_loop_result_at_step_val_lane {K : Std.Usize} + (secret_as_ntt u_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (k : Nat) (ℓ : Nat) (hℓ : ℓ < 256) : + (vec_loop_result_at_step secret_as_ntt u_as_ntt k).val[ℓ]! + = vec_loop_lane_at_step secret_as_ntt u_as_ntt k ℓ := by + unfold vec_loop_result_at_step + show ((List.range 256).map + (fun ℓ' => vec_loop_lane_at_step secret_as_ntt u_as_ntt k ℓ'))[ℓ]! = _ + rw [getElem!_pos _ ℓ (by simp [List.length_map, List.length_range, hℓ])] + rw [List.getElem_map, List.getElem_range] + +/-- Base case: at step 0, every lane is `⟨0#u16⟩`. -/ +private theorem vec_loop_lane_at_step_zero {K : Std.Usize} + (secret_as_ntt u_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (ℓ : Nat) : + vec_loop_lane_at_step secret_as_ntt u_as_ntt 0 ℓ + = ({ val := 0#u16 } : hacspec_ml_kem.parameters.FieldElement) := by + unfold vec_loop_lane_at_step + rw [List.range_zero, List.foldl_nil] + +/-- Step lemma: one column iteration `add_polynomials result (multiply_ntts …)` + advances `vec_loop_result_at_step ... k` to `... (k+1)`, lane-wise. -/ +private theorem vec_loop_lane_at_step_succ {K : Std.Usize} + (secret_as_ntt u_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (k : Nat) (ℓ : Nat) : + vec_loop_lane_at_step secret_as_ntt u_as_ntt (k + 1) ℓ + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (vec_loop_lane_at_step secret_as_ntt u_as_ntt k ℓ) + ((Spec.multiply_ntts_pure + (lift_poly secret_as_ntt.val[k]!) (lift_poly u_as_ntt.val[k]!)).val[ℓ]!) := by + unfold vec_loop_lane_at_step + rw [List.range_succ, List.foldl_append, List.foldl_cons, List.foldl_nil] + +set_option maxHeartbeats 16000000 in +set_option maxRecDepth 1000 in +/-- **.** `multiply_vectors (lift_vec s) (lift_vec u)` reduces to the + pure per-lane `add_pure`-foldl array `vec_loop_result_at_step ... K.val`. + Composes `multiply_ntts_eq_pure_array` + `matrix_add_polynomials_eq_ok` + through `loop_range_spec_usize`; mirrors `multiply_matrix_by_column_at_eq`. -/ +private theorem multiply_vectors_eq {K : Std.Usize} + (secret_as_ntt u_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) : + hacspec_ml_kem.matrix.multiply_vectors (lift_vec secret_as_ntt) (lift_vec u_as_ntt) + = .ok (vec_loop_result_at_step secret_as_ntt u_as_ntt K.val) := by + unfold hacspec_ml_kem.matrix.multiply_vectors + unfold hacspec_ml_kem.parameters.FieldElement.new + simp only [bind_tc_ok] + -- Reduce the loop via `loop_range_spec_usize` with the invariant + -- `result = vec_loop_result_at_step ... k.val`. + have h_triple : ⦃ ⌜ True ⌝ ⦄ + hacspec_ml_kem.matrix.multiply_vectors_loop + ({ start := 0#usize, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize) + (lift_vec secret_as_ntt) (lift_vec u_as_ntt) + (Std.Array.repeat (256#usize : Std.Usize) + ({ val := 0#u16 } : hacspec_ml_kem.parameters.FieldElement)) + ⦃ ⇓ r => ⌜ r = vec_loop_result_at_step secret_as_ntt u_as_ntt K.val ⌝ ⦄ := by + unfold hacspec_ml_kem.matrix.multiply_vectors_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun p : CoreModels.core.ops.range.Range Std.Usize × + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize => + hacspec_ml_kem.matrix.multiply_vectors_loop.body + (lift_vec secret_as_ntt) (lift_vec u_as_ntt) p.1 p.2) + (β := Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (Std.Array.repeat (256#usize : Std.Usize) + ({ val := 0#u16 } : hacspec_ml_kem.parameters.FieldElement)) + 0#usize K + (fun k result => pure (result = vec_loop_result_at_step secret_as_ntt u_as_ntt k.val)) + (Nat.zero_le _) + (by + -- Base: init = vec_loop_result_at_step ... 0. + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + apply Subtype.ext + rw [Std.Array.repeat_val] + unfold vec_loop_result_at_step + show List.replicate 256 _ = (List.range 256).map _ + apply List.ext_getElem + · rw [List.length_replicate, List.length_map, List.length_range] + intro n h_n_lhs _ + have h_n_lt : n < 256 := by + rw [List.length_replicate] at h_n_lhs; exact h_n_lhs + rw [List.getElem_replicate, List.getElem_map, List.getElem_range] + show _ = vec_loop_lane_at_step secret_as_ntt u_as_ntt 0 n + rw [vec_loop_lane_at_step_zero]) + ?_) + · -- Post entailment. + rw [PostCond.entails_noThrow] + intro r hh + have h_eq : (pure (r = vec_loop_result_at_step secret_as_ntt u_as_ntt K.val) + : Result Prop).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_eq + · -- Step. + intro acc k _h_ge h_le hinv + have h_acc_eq : acc = vec_loop_result_at_step secret_as_ntt u_as_ntt k.val := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using hinv + subst h_acc_eq + unfold hacspec_ml_kem.matrix.multiply_vectors_loop.body + by_cases h_lt : k.val < K.val + · -- `Some k` branch. + have h_iter_step : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ ∃ s : Std.Usize, s.val = k.val + 1 ∧ + r = (some k, + ({ start := s, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_step + obtain ⟨s_iter, hs_iter_val, hv_iter_pair⟩ := hv_iter_post + -- index_usize v1 k = lift_poly secret[k]; index_usize v2 k = lift_poly u[k]. + have h_lift_S_len : (lift_vec secret_as_ntt).length = K.val := Std.Array.length_eq _ + have h_lift_U_len : (lift_vec u_as_ntt).length = K.val := Std.Array.length_eq _ + set a1 : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + lift_poly secret_as_ntt.val[k.val]! with h_a1_def + set a2 : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + lift_poly u_as_ntt.val[k.val]! with h_a2_def + have h_lift_S_val : (lift_vec secret_as_ntt).val[k.val]! = a1 := by + rw [h_a1_def]; unfold lift_vec + show (secret_as_ntt.val.map lift_poly)[k.val]! = _ + have h_len_s : secret_as_ntt.val.length = K.val := Std.Array.length_eq _ + rw [getElem!_pos _ k.val (by rw [List.length_map, h_len_s]; exact h_lt)] + rw [List.getElem_map] + rw [show secret_as_ntt.val[k.val] = secret_as_ntt.val[k.val]! from + (getElem!_pos _ k.val (by rw [h_len_s]; exact h_lt)).symm] + have h_lift_U_val : (lift_vec u_as_ntt).val[k.val]! = a2 := by + rw [h_a2_def]; unfold lift_vec + show (u_as_ntt.val.map lift_poly)[k.val]! = _ + have h_len_u : u_as_ntt.val.length = K.val := Std.Array.length_eq _ + rw [getElem!_pos _ k.val (by rw [List.length_map, h_len_u]; exact h_lt)] + rw [List.getElem_map] + rw [show u_as_ntt.val[k.val] = u_as_ntt.val[k.val]! from + (getElem!_pos _ k.val (by rw [h_len_u]; exact h_lt)).symm] + have h_idx_a1 : Aeneas.Std.Array.index_usize (lift_vec secret_as_ntt) k = .ok a1 := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq (lift_vec secret_as_ntt) k + (by rw [h_lift_S_len]; exact h_lt) + rw [h_lift_S_val] at this; exact this + have h_idx_a2 : Aeneas.Std.Array.index_usize (lift_vec u_as_ntt) k = .ok a2 := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq (lift_vec u_as_ntt) k + (by rw [h_lift_U_len]; exact h_lt) + rw [h_lift_U_val] at this; exact this + have h_mult_eq : hacspec_ml_kem.ntt.multiply_ntts a1 a2 + = .ok (Spec.multiply_ntts_pure a1 a2) := by + unfold Spec.multiply_ntts_pure + rw [HelpersFC.multiply_ntts_eq_pure_array] + have h_add_eq := Stage4MatrixAddFC.matrix_add_polynomials_eq_ok + (vec_loop_result_at_step secret_as_ntt u_as_ntt k.val) + (Spec.multiply_ntts_pure a1 a2) + set new_acc : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + ⟨(List.range 256).map (fun n => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (vec_loop_result_at_step secret_as_ntt u_as_ntt k.val).val[n]! + (Spec.multiply_ntts_pure a1 a2).val[n]!), + by simp [List.length_map, List.length_range]⟩ with h_new_acc_def + have h_new_acc_eq : new_acc + = vec_loop_result_at_step secret_as_ntt u_as_ntt (k.val + 1) := by + unfold vec_loop_result_at_step + apply Subtype.ext + rw [h_new_acc_def] + apply List.map_congr_left + intro n hn_mem + have hn_lt : n < 256 := List.mem_range.mp hn_mem + rw [vec_loop_result_at_step_val_lane _ _ _ _ hn_lt] + rw [vec_loop_lane_at_step_succ] + have h_body : + (fun p : CoreModels.core.ops.range.Range Std.Usize × + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize => + hacspec_ml_kem.matrix.multiply_vectors_loop.body + (lift_vec secret_as_ntt) (lift_vec u_as_ntt) p.1 p.2) + ({ start := k, «end» := K }, + vec_loop_result_at_step secret_as_ntt u_as_ntt k.val) + = .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + vec_loop_result_at_step secret_as_ntt u_as_ntt (k.val + 1))) := by + show hacspec_ml_kem.matrix.multiply_vectors_loop.body + (lift_vec secret_as_ntt) (lift_vec u_as_ntt) + { start := k, «end» := K } + (vec_loop_result_at_step secret_as_ntt u_as_ntt k.val) = _ + unfold hacspec_ml_kem.matrix.multiply_vectors_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_pair] at hv_iter_eq + rw [hv_iter_eq] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let a ← Aeneas.Std.Array.index_usize (lift_vec secret_as_ntt) k + let a1' ← Aeneas.Std.Array.index_usize (lift_vec u_as_ntt) k + let product ← hacspec_ml_kem.ntt.multiply_ntts a a1' + let result1 ← hacspec_ml_kem.matrix.add_polynomials + (vec_loop_result_at_step secret_as_ntt u_as_ntt k.val) product + Aeneas.Std.Result.ok (ControlFlow.cont + (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), result1))) + : Result _) = _ + rw [h_idx_a1] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_a2] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_mult_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_add_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [← h_new_acc_eq] + apply triple_of_ok_fc h_body + refine ⟨h_lt, rfl, hs_iter_val, ?_⟩ + show (pure (vec_loop_result_at_step secret_as_ntt u_as_ntt (k.val + 1) + = vec_loop_result_at_step secret_as_ntt u_as_ntt s_iter.val) + : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + rw [hs_iter_val] + rfl + · -- `None` branch: k ≥ K, done. + have hk_ge : k.val ≥ K.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = K.val := by omega + have h_iter_none : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ r = ((none : Option Std.Usize), + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun hlt => absurd hlt (Nat.not_lt.mpr hk_ge)) + (fun _ => by dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_none + have h_body : + (fun p : CoreModels.core.ops.range.Range Std.Usize × + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize => + hacspec_ml_kem.matrix.multiply_vectors_loop.body + (lift_vec secret_as_ntt) (lift_vec u_as_ntt) p.1 p.2) + ({ start := k, «end» := K }, + vec_loop_result_at_step secret_as_ntt u_as_ntt k.val) + = .ok (ControlFlow.done + (vec_loop_result_at_step secret_as_ntt u_as_ntt k.val)) := by + show hacspec_ml_kem.matrix.multiply_vectors_loop.body + (lift_vec secret_as_ntt) (lift_vec u_as_ntt) + { start := k, «end» := K } + (vec_loop_result_at_step secret_as_ntt u_as_ntt k.val) = _ + unfold hacspec_ml_kem.matrix.multiply_vectors_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_post] at hv_iter_eq + rw [hv_iter_eq] + rfl + apply triple_of_ok_fc h_body + show (pure (vec_loop_result_at_step secret_as_ntt u_as_ntt k.val + = vec_loop_result_at_step secret_as_ntt u_as_ntt K.val) + : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + rw [hk_eq] + rfl + obtain ⟨v, hv_eq, hv_post⟩ := triple_exists_ok_fc h_triple + rw [hv_eq, hv_post] + +/-! ### re-derived Helper 1 (`multiply_ntts_lane_eq_canonical_factor`). + + The chunk-lift bilinearity bridge `multiply_ntts_lane_eq_canonical_factor` and its two dependencies (`chunk_at_lift_poly_lane`, + `ntt_multiply_pure_no_acc_lane_scale`) are `private` to FCTargets. We + re-derive a public-callable copy here from public primitives + (`Spec.multiply_ntts_pure_eq_chunked_no_acc`, `zmodOfFE_{add,mul}_pure`, + `lift_fe_mont_mul_1353_eq_lift_fe`, `PortableVector_elements_length`). -/ + +/-- `Canonical x → x.val.val < 3329` and the canonical round-trip closer + (re-derived from public `Canonical`/`FIELD_MODULUS`). -/ +private theorem L7_4_Hlp.canon_lt + (x : hacspec_ml_kem.parameters.FieldElement) + (hx : libcrux_iot_ml_kem.Spec.Pure.Canonical x) : x.val.val < 3329 := by + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at hx + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at hx; exact hx + +private theorem L7_4_Hlp.feOfZMod_zmodOfFE_of_lt + (x : hacspec_ml_kem.parameters.FieldElement) (hx : x.val.val < 3329) : + feOfZMod (zmodOfFE x) = x := by + unfold feOfZMod zmodOfFE + have hzval : ((x.val.val : ZMod 3329)).val = x.val.val := ZMod.val_natCast_of_lt hx + rw [hzval] + have hsval : x.val.val < 2 ^ 16 := by + have h_p : (3329 : Nat) ≤ 2 ^ 16 := by decide + omega + have hsbv : BitVec.ofNat 16 x.val.val = x.val.bv := by + apply BitVec.eq_of_toNat_eq + rw [BitVec.toNat_ofNat] + show x.val.val % 2 ^ 16 = x.val.bv.toNat + rw [Nat.mod_eq_of_lt hsval]; rfl + show ({ val := ⟨BitVec.ofNat 16 x.val.val⟩ } : hacspec_ml_kem.parameters.FieldElement) = x + rw [hsbv] + +/-- Canonical equality closer: two FEs `< 3329` with equal `zmodOfFE` are equal. -/ +private theorem L7_4_Hlp.eq_of_zmod_lt + (s t : hacspec_ml_kem.parameters.FieldElement) + (hs : s.val.val < 3329) (ht : t.val.val < 3329) (heq : zmodOfFE s = zmodOfFE t) : + s = t := by + rw [← L7_4_Hlp.feOfZMod_zmodOfFE_of_lt s hs, + ← L7_4_Hlp.feOfZMod_zmodOfFE_of_lt t ht, heq] + +/-- Re-derived `ntt_multiply_pure_no_acc_val_q` (the projection is `rfl`-ish). -/ +private theorem L7_4_Hlp.no_acc_val_q + (a b : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (zeta0 zeta1 zeta2 zeta3 : hacspec_ml_kem.parameters.FieldElement) + (q : Nat) (h_q : q < 16) : + (Spec.ntt_multiply_pure_no_acc a b zeta0 zeta1 zeta2 zeta3).val[q]! + = (let neg := libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure + let zeta_q : hacspec_ml_kem.parameters.FieldElement := + [zeta0, neg zeta0, zeta1, neg zeta1, + zeta2, neg zeta2, zeta3, neg zeta3][q / 2]! + if q % 2 = 0 then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + a.val[2 * (q / 2)]! b.val[2 * (q / 2)]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + a.val[2 * (q / 2) + 1]! b.val[2 * (q / 2) + 1]!) + zeta_q) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + a.val[2 * (q / 2)]! b.val[2 * (q / 2) + 1]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + a.val[2 * (q / 2) + 1]! b.val[2 * (q / 2)]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rw [show ∀ (l : List hacspec_ml_kem.parameters.FieldElement) + (h : l.length = (16#usize : Std.Usize).val), + (Std.Array.make 16#usize l h).val[q]! = l[q]! from fun _ _ => rfl, + List.getElem!_eq_getElem?_getD, List.getElem?_map, List.getElem?_range h_q, + Option.map_some, Option.getD_some] + +/-- Re-derived `ntt_multiply_pure_no_acc_lane_scale`: per-lane `c²` bilinearity. -/ +private theorem L7_4_Hlp.no_acc_lane_scale + (a am b bm : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (c : hacspec_ml_kem.parameters.FieldElement) + (h_a : ∀ k : Nat, k < 16 → a.val[k]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure am.val[k]! c) + (h_b : ∀ k : Nat, k < 16 → b.val[k]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure bm.val[k]! c) + (zeta0 zeta1 zeta2 zeta3 : hacspec_ml_kem.parameters.FieldElement) + (q : Nat) (h_q : q < 16) : + (Spec.ntt_multiply_pure_no_acc a b zeta0 zeta1 zeta2 zeta3).val[q]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((Spec.ntt_multiply_pure_no_acc am bm zeta0 zeta1 zeta2 zeta3).val[q]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure c c) := by + have h_2pi : 2 * (q / 2) < 16 := by omega + have h_2pi1 : 2 * (q / 2) + 1 < 16 := by omega + rw [L7_4_Hlp.no_acc_val_q a b _ _ _ _ q h_q, + L7_4_Hlp.no_acc_val_q am bm _ _ _ _ q h_q] + rw [h_a (2 * (q / 2)) h_2pi, h_a (2 * (q / 2) + 1) h_2pi1, + h_b (2 * (q / 2)) h_2pi, h_b (2 * (q / 2) + 1) h_2pi1] + rcases (show q % 2 = 0 ∨ q % 2 = 1 from by omega) with h_par | h_par + · rw [if_pos h_par, if_pos h_par] + apply L7_4_Hlp.eq_of_zmod_lt + · exact L7_4_Hlp.canon_lt _ (libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure _ _) + · exact L7_4_Hlp.canon_lt _ (libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure _ _) + · simp only [zmodOfFE_add_pure, zmodOfFE_mul_pure]; ring + · have h_par_ne : q % 2 ≠ 0 := by omega + rw [if_neg h_par_ne, if_neg h_par_ne] + apply L7_4_Hlp.eq_of_zmod_lt + · exact L7_4_Hlp.canon_lt _ (libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure _ _) + · exact L7_4_Hlp.canon_lt _ (libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure _ _) + · simp only [zmodOfFE_add_pure, zmodOfFE_mul_pure]; ring + +set_option maxHeartbeats 1000000 in +/-- Re-derived `chunk_at_lift_poly_lane`. -/ +private theorem L7_4_Hlp.chunk_at_lift_poly_lane + (p : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (j : Nat) (h_j : j < 16) (q : Nat) (h_q : q < 16) : + (Spec.chunk_at (lift_poly p) j).val[q]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont p.coefficients.val[j]!).val[q]!) + (lift_fe_mont (1353#i16 : Std.I16)) := by + set x : Std.I16 := (p.coefficients.val[j]!).elements.val[q]! with hx_def + have h_elem_len : ((p.coefficients.val[j]!).elements.val).length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + have h_mont : (lift_chunk_mont p.coefficients.val[j]!).val[q]! = lift_fe_mont x := by + unfold lift_chunk_mont + show (((p.coefficients.val[j]!).elements.val).map lift_fe_mont)[q]! = lift_fe_mont x + have h_len : (((p.coefficients.val[j]!).elements.val).map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_elem_len + rw [getElem!_pos _ q (by rw [h_len]; exact h_q)] + rw [List.getElem_map] + rw [show ((p.coefficients.val[j]!).elements.val)[q] + = ((p.coefficients.val[j]!).elements.val)[q]! from + (getElem!_pos _ q (by rw [h_elem_len]; exact h_q)).symm] + have h_plain : (Spec.chunk_at (lift_poly p) j).val[q]! = lift_fe x := by + unfold Spec.chunk_at + show ((List.range 16).map (fun j' => (lift_poly p).val[16 * j + j']!))[q]! = lift_fe x + have h_len_outer : ((List.range 16).map + (fun j' => (lift_poly p).val[16 * j + j']!)).length = 16 := by simp + rw [getElem!_pos _ q (by rw [h_len_outer]; exact h_q)] + rw [List.getElem_map, List.getElem_range] + have h_lane : 16 * j + q < 256 := by omega + unfold lift_poly + show ((List.range 256).map (fun n => + lift_fe (p.coefficients.val[n / 16]!).elements.val[n % 16]!))[16 * j + q]! = lift_fe x + have h_len_inner : ((List.range 256).map (fun n => + lift_fe (p.coefficients.val[n / 16]!).elements.val[n % 16]!)).length = 256 := by simp + rw [getElem!_pos _ (16 * j + q) (by rw [h_len_inner]; exact h_lane)] + rw [List.getElem_map, List.getElem_range] + have h_div : (16 * j + q) / 16 = j := by omega + have h_mod : (16 * j + q) % 16 = q := by omega + rw [h_div, h_mod] + rw [h_mont, h_plain] + rw [Impl.lift_fe_mont_mul_1353_eq_lift_fe] + +/-- Re-derived Helper 1: `multiply_ntts_pure (lift_poly a)(lift_poly b)` lane ℓ + equals the chunk-lift `no_acc` lane scaled by `(lift_fe_mont 1353)²`. -/ +private theorem L7_4_Hlp.multiply_ntts_lane_eq_canonical_factor + (a b : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (ℓ : Nat) (hℓ : ℓ < 256) : + (Spec.multiply_ntts_pure (lift_poly a) (lift_poly b)).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont a.coefficients.val[ℓ / 16]!) + (lift_chunk_mont b.coefficients.val[ℓ / 16]!) + (Spec.zeta_at (64 + 4 * (ℓ / 16))) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 1)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 2)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 3))).val[ℓ % 16]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (1353#i16 : Std.I16)) + (lift_fe_mont (1353#i16 : Std.I16))) := by + have h_div_lt : ℓ / 16 < 16 := by omega + have h_mod_lt : ℓ % 16 < 16 := Nat.mod_lt _ (by decide) + rw [Spec.multiply_ntts_pure_eq_chunked_no_acc] + unfold Spec.flatten_chunks + show ((List.range 256).map (fun j => + (((List.range 16).map (fun j' => + Spec.ntt_multiply_pure_no_acc + (Spec.chunk_at (lift_poly a) j') (Spec.chunk_at (lift_poly b) j') + (Spec.zeta_at (64 + 4 * j')) (Spec.zeta_at (64 + 4 * j' + 1)) + (Spec.zeta_at (64 + 4 * j' + 2)) (Spec.zeta_at (64 + 4 * j' + 3))) + )[j / 16]!).val[j % 16]!))[ℓ]! = _ + rw [getElem!_pos _ ℓ (by simp [List.length_map, List.length_range, hℓ])] + rw [List.getElem_map, List.getElem_range] + rw [getElem!_pos _ (ℓ / 16) (by simp [List.length_map, List.length_range, h_div_lt])] + rw [List.getElem_map, List.getElem_range] + exact L7_4_Hlp.no_acc_lane_scale + (Spec.chunk_at (lift_poly a) (ℓ / 16)) (lift_chunk_mont a.coefficients.val[ℓ / 16]!) + (Spec.chunk_at (lift_poly b) (ℓ / 16)) (lift_chunk_mont b.coefficients.val[ℓ / 16]!) + (lift_fe_mont (1353#i16 : Std.I16)) + (fun k h_k => L7_4_Hlp.chunk_at_lift_poly_lane a (ℓ / 16) h_div_lt k h_k) + (fun k h_k => L7_4_Hlp.chunk_at_lift_poly_lane b (ℓ / 16) h_div_lt k h_k) + (Spec.zeta_at (64 + 4 * (ℓ / 16))) (Spec.zeta_at (64 + 4 * (ℓ / 16) + 1)) + (Spec.zeta_at (64 + 4 * (ℓ / 16) + 2)) (Spec.zeta_at (64 + 4 * (ℓ / 16) + 3)) + (ℓ % 16) h_mod_lt + +/-! ### local canonical-determination helper + per-lane bridge. -/ + +/-- Local copy of `eq_of_zmod_lane_canon` (private in ComputeMessage.Hacspec): + two canonical 256-FE arrays that agree on every `zmodOfFE` lane are equal. -/ +private theorem eq_of_zmod_lane_canon_local + (u v : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hcu : ∀ j : Nat, j < 256 → libcrux_iot_ml_kem.Spec.Pure.Canonical (u.val[j]!)) + (hcv : ∀ j : Nat, j < 256 → libcrux_iot_ml_kem.Spec.Pure.Canonical (v.val[j]!)) + (hz : ∀ j : Nat, j < 256 → zmodOfFE (u.val[j]!) = zmodOfFE (v.val[j]!)) : + u = v := by + have h_canon_round : ∀ (x : hacspec_ml_kem.parameters.FieldElement), + libcrux_iot_ml_kem.Spec.Pure.Canonical x → feOfZMod (zmodOfFE x) = x := + fun x hx => L7_4_Hlp.feOfZMod_zmodOfFE_of_lt x (L7_4_Hlp.canon_lt x hx) + apply Subtype.ext + apply List.ext_getElem + · rw [Aeneas.Std.Array.length_eq u, Aeneas.Std.Array.length_eq v] + · intro j hj1 _hj2 + have hj : j < 256 := by rw [Aeneas.Std.Array.length_eq u] at hj1; simpa using hj1 + have heq : u.val[j]! = v.val[j]! := by + rw [← h_canon_round (u.val[j]!) (hcu j hj), + ← h_canon_round (v.val[j]!) (hcv j hj), hz j hj] + have huj : u.val[j]! = u.val[j] := + getElem!_pos u.val j (by rw [Aeneas.Std.Array.length_eq u]; exact hj) + have hvj : v.val[j]! = v.val[j] := + getElem!_pos v.val j (by rw [Aeneas.Std.Array.length_eq v]; exact hj) + rw [← huj, ← hvj]; exact heq + +/-- Foldl congruence in `ZMod 3329` across a common multiplicative `factor`: + if the seeds and every per-step summand of two `add_pure`-foldls relate by + `zmodOfFE (aₓ) = factor * zmodOfFE (bₓ)`, then so do the foldl results. -/ +private theorem zmodOfFE_foldl_add_pure_factor {α : Type} (L : List α) + (fa fb : α → hacspec_ml_kem.parameters.FieldElement) + (seedA seedB : hacspec_ml_kem.parameters.FieldElement) (factor : ZMod 3329) + (h_seed : zmodOfFE seedA = factor * zmodOfFE seedB) + (h_step : ∀ c ∈ L, zmodOfFE (fa c) = factor * zmodOfFE (fb c)) : + zmodOfFE (L.foldl + (fun s c => libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s (fa c)) seedA) + = factor * zmodOfFE (L.foldl + (fun s c => libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s (fb c)) seedB) := by + induction L generalizing seedA seedB with + | nil => simpa using h_seed + | cons h t ih => + simp only [List.foldl_cons] + apply ih + · rw [zmodOfFE_add_pure, zmodOfFE_add_pure, h_seed, h_step h (by simp)] + ring + · intro c hc; exact h_step c (by simp [hc]) + +/-- Any `add_pure`-foldl over a canonical seed is canonical (the `nil` case is + the seed; every `cons` step is an `add_pure`, hence canonical). -/ +private theorem foldl_add_pure_canonical {α : Type} (L : List α) + (f : α → hacspec_ml_kem.parameters.FieldElement) + (seed : hacspec_ml_kem.parameters.FieldElement) + (h_seed : libcrux_iot_ml_kem.Spec.Pure.Canonical seed) : + libcrux_iot_ml_kem.Spec.Pure.Canonical + (L.foldl + (fun s c => libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s (f c)) seed) := by + induction L generalizing seed with + | nil => simpa using h_seed + | cons h t ih => + simp only [List.foldl_cons] + exact ih _ (libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure _ _) + +private theorem zero_fe_canonical : + libcrux_iot_ml_kem.Spec.Pure.Canonical + ({ val := 0#u16 } : hacspec_ml_kem.parameters.FieldElement) := by + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq]; decide + +set_option maxHeartbeats 1000000 in +/-- **A — acc-bridge / "crux" (CLOSED).** + + Relates the hacspec `multiply_vectors` (on `lift_vec`-lifts) to the loop + accumulator's reduced value, scaled by `R = 2285`. Statement numerically + Factor 2285 = R. RHS kept in the + `scaleZ 2285 (mont_strip ∘ poly_reducing ∘ to_slice acc2)` form per the + FC-glue requirement. + + Proof: reduces `multiply_vectors` to the per-lane `add_pure`-foldl + array `vec_loop_result_at_step`; both sides are canonical, so equality + follows lane-wise in `ZMod 3329`. The per-lane match composes: + `Spec.multiply_ntts_pure_eq_chunked_no_acc`/`multiply_ntts_lane_eq_canonical_factor` + (LHS lane = `mul_pure (no_acc-lane) (1353·1353)`), `h_char` (the S1 loop + invariant: `poly_reducing(acc2)` lane = the same `no_acc` foldl-sum), and + the `2285 = R` factor identity (`zmodOfFE (lift_fe_mont 1353) = 1353·169`, + `2285·(1353·169·169·169) ≡ 1353·169·169` collapse), all in `ZMod 3329`. -/ +theorem compute_message_acc_bridge {K : Std.Usize} + (secret_as_ntt u_as_ntt : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : Std.Array Std.I32 256#usize) + (acc2 : Std.Array Std.I32 256#usize) + (h_acc_init_zero : ∀ n : Nat, n < 256 → (acc_init.val[n]!).val = 0) + (_h_secret_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((secret_as_ntt.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (_h_u_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((u_as_ntt.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_char : (S1LoopFC.loop_inv secret_as_ntt u_as_ntt acc_init K acc2).holds) : + hacspec_ml_kem.matrix.multiply_vectors (lift_vec secret_as_ntt) (lift_vec u_as_ntt) + = .ok (scaleZ 2285 (Impl.mont_strip_pure + (Spec.poly_reducing_from_i32_array_pure (Aeneas.Std.Array.to_slice acc2)))) := by + rw [multiply_vectors_eq secret_as_ntt u_as_ntt] + -- Destructure `h_char`'s conjunct (1): the per-lane `no_acc` foldl characterization. + obtain ⟨h_inv_acc, _⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_char + -- Abbreviations. + set P : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Impl.mont_strip_pure + (Spec.poly_reducing_from_i32_array_pure (Aeneas.Std.Array.to_slice acc2)) with hP_def + congr 1 + -- Equality of two canonical 256-arrays by per-lane `zmodOfFE`. + apply eq_of_zmod_lane_canon_local + · -- LHS lanes canonical (`add_pure`-foldl ⇒ canonical via the cons step; + -- the empty foldl is the seed `⟨0#u16⟩`, also canonical). + intro n hn + rw [vec_loop_result_at_step_val_lane _ _ _ _ hn] + unfold vec_loop_lane_at_step + exact foldl_add_pure_canonical _ _ _ zero_fe_canonical + · -- RHS lanes canonical (`scaleZ` lane = `feOfZMod _`, always `< 3329`). + intro n hn + unfold scaleZ + show libcrux_iot_ml_kem.Spec.Pure.Canonical + (((List.range 256).map (fun j => feOfZMod ((2285 : ZMod 3329) * zmodOfFE (P.val[j]!))))[n]!) + rw [getElem!_pos _ n (by simp [List.length_map, List.length_range, hn])] + rw [List.getElem_map, List.getElem_range] + unfold feOfZMod libcrux_iot_ml_kem.Spec.Pure.Canonical + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] + show (⟨BitVec.ofNat 16 (((2285 : ZMod 3329) * zmodOfFE (P.val[n]!)).val)⟩ + : Std.U16).val < 3329 + have h_lt16 : (((2285 : ZMod 3329) * zmodOfFE (P.val[n]!)).val) < 2 ^ 16 := by + have := ZMod.val_lt ((2285 : ZMod 3329) * zmodOfFE (P.val[n]!)) + omega + show (BitVec.ofNat 16 (((2285 : ZMod 3329) * zmodOfFE (P.val[n]!)).val)).toNat < 3329 + rw [BitVec.toNat_ofNat, Nat.mod_eq_of_lt h_lt16] + exact ZMod.val_lt _ + · -- Per-lane `zmodOfFE` equality. + intro n hn + -- RHS lane zmod = 2285 * zmodOfFE (mont_strip lane). + rw [hP_def, scaleZ_lane 2285 _ n hn] + -- mont_strip lane = mul_pure (poly_reducing lane) (lift_fe_mont 1353). + have h_ms : (Impl.mont_strip_pure + (Spec.poly_reducing_from_i32_array_pure (Aeneas.Std.Array.to_slice acc2))).val[n]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((Spec.poly_reducing_from_i32_array_pure (Aeneas.Std.Array.to_slice acc2)).val[n]!) + (lift_fe_mont (1353#i16 : Std.I16)) := by + unfold Impl.mont_strip_pure + show ((List.range 256).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((Spec.poly_reducing_from_i32_array_pure (Aeneas.Std.Array.to_slice acc2)).val[i]!) + (lift_fe_mont (1353#i16 : Std.I16))))[n]! = _ + rw [getElem!_pos _ n (by simp [List.length_map, List.length_range, hn])] + rw [List.getElem_map, List.getElem_range] + rw [h_ms, zmodOfFE_mul_pure, zmodOfFE_lift_fe_mont] + have h1353 : (((1353#i16 : Std.I16).val : ZMod 3329)) = 1353 := by decide + rw [h1353] + -- poly_reducing lane = mont_reduce_pure (lift_fe_int acc2[n].val) (via to_slice .val). + have h_pr : (Spec.poly_reducing_from_i32_array_pure (Aeneas.Std.Array.to_slice acc2)).val[n]! + = Spec.mont_reduce_pure (lift_fe_int (acc2.val[n]!).val) := by + unfold Spec.poly_reducing_from_i32_array_pure + show ((List.range 256).map (fun i => + Spec.mont_reduce_pure (lift_fe_int ((Aeneas.Std.Array.to_slice acc2).val[i]!).val)))[n]! + = _ + rw [getElem!_pos _ n (by simp [List.length_map, List.length_range, hn])] + rw [List.getElem_map, List.getElem_range] + rw [Aeneas.Std.Array.val_to_slice] + rw [h_pr] + -- LHS lane = vec_loop foldl over `multiply_ntts_pure` lanes. + rw [vec_loop_result_at_step_val_lane _ _ _ _ hn] + unfold vec_loop_lane_at_step + -- Per-lane (n = 16*(n/16) + n%16) shorthands. + set j := n / 16 with hj_def + set ℓ := n % 16 with hℓ_def + have hj_lt : j < 16 := by omega + have hℓ_lt : ℓ < 16 := Nat.mod_lt _ (by decide) + have hn_eq : n = 16 * j + ℓ := by rw [hj_def, hℓ_def]; omega + -- `h_char` at (j, ℓ): the `no_acc` foldl equals `mont_reduce_pure (lift_fe_int acc2[n])`. + have h_char_jℓ := h_inv_acc j hj_lt ℓ hℓ_lt + rw [show 16 * j + ℓ = n from hn_eq.symm] at h_char_jℓ + -- Rewrite RHS foldl-FE using h_char. + rw [h_char_jℓ] + -- Now: zmodOfFE (LHS foldl over multiply_ntts_pure lanes) + -- = 2285 * (zmodOfFE (RHS foldl over no_acc lanes) * (1353 * 169)). + -- Apply the factor-foldl congruence with factor = (1353*169)^2 = 1353. + set fb : Nat → hacspec_ml_kem.parameters.FieldElement := fun c => + (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (secret_as_ntt.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (u_as_ntt.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]! with hfb_def + set fa : Nat → hacspec_ml_kem.parameters.FieldElement := fun c => + (Spec.multiply_ntts_pure + (lift_poly secret_as_ntt.val[c]!) (lift_poly u_as_ntt.val[c]!)).val[n]! with hfa_def + set seedB : hacspec_ml_kem.parameters.FieldElement := + Spec.mont_reduce_pure (lift_fe_int (acc_init.val[n]!).val) with hseedB_def + have h_factor_congr : + zmodOfFE ((List.range K.val).foldl + (fun s c => libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s (fa c)) + ({ val := 0#u16 } : hacspec_ml_kem.parameters.FieldElement)) + = (1353 : ZMod 3329) + * zmodOfFE ((List.range K.val).foldl + (fun s c => libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s (fb c)) + seedB) := by + apply zmodOfFE_foldl_add_pure_factor + · -- seeds: both have zmod 0. + rw [hseedB_def] + have h0 : (acc_init.val[n]!).val = 0 := h_acc_init_zero n hn + rw [h0] + -- zmodOfFE ⟨0#u16⟩ = 0 and 1353 * zmodOfFE (mont_reduce_pure (lift_fe_int 0)) = 0. + have hL : zmodOfFE ({ val := 0#u16 } : hacspec_ml_kem.parameters.FieldElement) = 0 := by + unfold zmodOfFE; decide + have hR : zmodOfFE (Spec.mont_reduce_pure (lift_fe_int (0 : Int))) = 0 := by + unfold Spec.mont_reduce_pure lift_fe_int + rw [zmodOfFE_feOfZMod, zmodOfFE_feOfZMod] + push_cast; ring + rw [hL, hR]; ring + · -- per-step factor: zmodOfFE (fa c) = 1353 * zmodOfFE (fb c). + intro c _ + simp only [hfa_def, hfb_def] + -- Helper 1: multiply_ntts_pure lane n = mul_pure (no_acc-lane at ℓ) (1353²). + rw [show n = 16 * j + ℓ from hn_eq] + have h_lane := L7_4_Hlp.multiply_ntts_lane_eq_canonical_factor + secret_as_ntt.val[c]! u_as_ntt.val[c]! (16 * j + ℓ) (by omega) + have hdiv : (16 * j + ℓ) / 16 = j := by omega + have hmod : (16 * j + ℓ) % 16 = ℓ := by omega + rw [hdiv, hmod] at h_lane + rw [h_lane, zmodOfFE_mul_pure, zmodOfFE_mul_pure, zmodOfFE_lift_fe_mont] + have h1353' : (((1353#i16 : Std.I16).val : ZMod 3329)) = 1353 := by decide + rw [h1353'] + -- (no_acc-lane) * (1353*169 * (1353*169)) = 1353 * (no_acc-lane). + have hfac : ((1353 : ZMod 3329) * 169) * ((1353 : ZMod 3329) * 169) = 1353 := by decide + rw [show (zmodOfFE _ * ((1353 : ZMod 3329) * 169 * (1353 * 169))) + = (1353 : ZMod 3329) * 169 * (1353 * 169) * zmodOfFE _ from by ring] + rw [hfac] + rw [h_factor_congr] + -- Goal: 1353 * zmodOfFE (foldl_fb) = 2285 * (zmodOfFE (foldl_fb) * (1353 * 169)). + -- Both sides equal 1353 * zmodOfFE(foldl_fb) since 2285 * (1353*169) = 2285*2285 = 1353. + have hRfac : (2285 : ZMod 3329) * ((1353 : ZMod 3329) * 169) = 1353 := by decide + rw [show (2285 : ZMod 3329) * (zmodOfFE _ * (1353 * 169)) + = (2285 : ZMod 3329) * (1353 * 169) * zmodOfFE _ from by ring] + rw [hRfac] + +end libcrux_iot_ml_kem.Matrix.ComputeMessage.Impl \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeRingElementV/FC.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeRingElementV/FC.lean new file mode 100644 index 00000000..c47bba87 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeRingElementV/FC.lean @@ -0,0 +1,330 @@ +/- + # `Matrix/ComputeRingElementV/FC.lean` — L7.3 FC theorem glue. + + Houses the L7.3 FC theorem `compute_ring_element_v_fc`, gluing the + direct decomposition (impl walk via `triple_*_ok_fc` + the A/C/B/ + compose/glue/D′ chain). Mirrors L7.4 `compute_message_fc` + near-verbatim, swapping the loop (chunks-exact deserialize loop + instead of the sampled use-cache loop) and the tail (two + `add_polynomials` instead of a single `subtract`). + + POST: + `hacspec_ml_kem.matrix.compute_ring_element_v + (lift_t_as_ntt_from_public_key public_key K) (lift_vec_slice r_as_ntt K) + (lift_poly error_2) (lift_poly message) = .ok (lift_poly p.2.1)`. +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeAsPlusE +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Sampling +import LibcruxIotMlKem.Serialize +import LibcruxIotMlKem.Matrix.ComputeRingElementV.Impl +import LibcruxIotMlKem.Matrix.ComputeMessage.Hacspec +import LibcruxIotMlKem.Matrix.ComputeVectorU.Hacspec +import LibcruxIotMlKem.Matrix.ComputeRingElementV.Hacspec + +namespace libcrux_iot_ml_kem.Matrix.ComputeRingElementV.FC +open libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeMessage.Bridges libcrux_iot_ml_kem.Matrix.ComputeMessage.Hacspec libcrux_iot_ml_kem.Matrix.ComputeMessage.Impl libcrux_iot_ml_kem.Matrix.ComputeRingElementV.Hacspec libcrux_iot_ml_kem.Matrix.ComputeRingElementV.Impl libcrux_iot_ml_kem.Matrix.ComputeVectorU.Hacspec +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeAsPlusE libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Sampling libcrux_iot_ml_kem.Serialize libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt + +set_option mvcgen.warning false +set_option linter.unusedVariables false + +/-- Local copy of the `private triple_exists_ok_fc` helper. -/ +private theorem triple_exists_ok_fc {α : Type} {x : Result α} {P : α → Prop} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) : + ∃ v, x = .ok v ∧ P v := by + match hx : x with + | .ok v => exact ⟨v, rfl, (by subst hx; simpa [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h)⟩ + | .fail _ => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +/-- Local copy of the `private triple_of_ok_fc` helper. -/ +private theorem triple_of_ok_fc {α : Type} {x : Result α} {v : α} + {P : α → Prop} (hx : x = .ok v) (hp : P v) : + ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄ := by + subst hx; simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +/-- `scaleZ c p` lanes are canonical. -/ +private theorem scaleZ_canon (c : ZMod 3329) + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (j : Nat) (hj : j < 256) : + libcrux_iot_ml_kem.Spec.Pure.Canonical ((scaleZ c p).val[j]!) := by + unfold scaleZ + show libcrux_iot_ml_kem.Spec.Pure.Canonical + (((List.range 256).map (fun k => feOfZMod (c * zmodOfFE (p.val[k]!))))[j]!) + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj])] + rw [List.getElem_map, List.getElem_range] + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical feOfZMod + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] + show (BitVec.ofNat 16 ((c * zmodOfFE (p.val[j]!)).val)).toNat < 3329 + set z := c * zmodOfFE (p.val[j]!) + have h_lt16 : z.val < 2 ^ 16 := by have := ZMod.val_lt z; omega + rw [BitVec.toNat_ofNat, Nat.mod_eq_of_lt h_lt16] + exact ZMod.val_lt _ + +/-- `lift_poly x` lanes are canonical. -/ +private theorem lift_poly_canon + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (j : Nat) (hj : j < 256) : + libcrux_iot_ml_kem.Spec.Pure.Canonical ((lift_poly re).val[j]!) := by + unfold lift_poly + show libcrux_iot_ml_kem.Spec.Pure.Canonical + (((List.range 256).map (fun i => + lift_fe (re.coefficients.val[i / 16]!).elements.val[i % 16]!))[j]!) + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj])] + rw [List.getElem_map, List.getElem_range] + unfold lift_fe libcrux_iot_ml_kem.Spec.Pure.Canonical feOfZMod + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] + show (⟨BitVec.ofNat 16 ((i16_to_spec_fe_plain + (re.coefficients.val[j / 16]!).elements.val[j % 16]!).val)⟩ : Std.U16).val < 3329 + show (BitVec.ofNat 16 ((i16_to_spec_fe_plain + (re.coefficients.val[j / 16]!).elements.val[j % 16]!).val)).toNat < 3329 + set z := i16_to_spec_fe_plain (re.coefficients.val[j / 16]!).elements.val[j % 16]! + have h_lt16 : z.val < 2 ^ 16 := by + have := ZMod.val_lt z; omega + rw [BitVec.toNat_ofNat, Nat.mod_eq_of_lt h_lt16] + exact ZMod.val_lt _ + +/-! ## L7.3 FC theorem (capstone). -/ +set_option maxHeartbeats 4000000 in +@[spec] +theorem compute_ring_element_v_fc + (K : Std.Usize) + (public_key : Slice Std.U8) + (t_as_ntt_entry : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (r_as_ntt : Slice + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (error_2 message result : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (cache : Slice + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (accumulator : Std.Array Std.I32 256#usize) + (hK : K.val ≤ 4) + (h_pk_len : public_key.length = K.val * 384) + (h_r_len : r_as_ntt.length = K.val) + (h_cache_len : cache.length = K.val) + (h_r_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((r_as_ntt.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_cache_char : ∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post (r_as_ntt.val[c]!) (cache.val[c]!)) + (h_acc_zero : ∀ n : Nat, n < 256 → (accumulator.val[n]!).val = 0) + (h_error_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((error_2.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 3328) + (h_message_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((message.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_ring_element_v + K (vectortraitsOperationsInst := portable_ops_inst) + public_key t_as_ntt_entry r_as_ntt error_2 message result scratch + cache accumulator + ⦃ ⇓ p => ⌜ hacspec_ml_kem.matrix.compute_ring_element_v + (lift_t_as_ntt_from_public_key public_key K) + (lift_vec_slice r_as_ntt K) + (lift_poly error_2) (lift_poly message) + = .ok (lift_poly p.2.1) ⌝ ⦄ := by + -- r_arr : Array Poly K from r_as_ntt. + set r_arr : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K := + ⟨r_as_ntt.val, by rw [← h_r_len]⟩ with h_r_arr_def + have h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]! := by + intro c hc; rfl + -- Step 0: classify 0#i32 = .ok 0#i32; acc1 = repeat 256 (0#i32) (all-zero). + set acc1 : Std.Array Std.I32 256#usize := + Std.Array.repeat (256#usize : Std.Usize) (0#i32 : Std.I32) with h_acc1_def + have h_acc1_zero : ∀ n : Nat, n < 256 → (acc1.val[n]!).val = 0 := by + intro n hn + rw [h_acc1_def, Std.Array.repeat_val] + rw [getElem!_pos _ n (by rw [List.length_replicate]; exact hn)] + rw [List.getElem_replicate]; rfl + -- iter0 reduction: BYTES_PER_RING_ELEMENT = 384; chunks_exact + enumerate. + set iter0 : EnumCE := + { iter := { cs := 384#usize, elements := public_key }, count := 0#usize } with h_iter0_def + -- S1: run the chunks-exact deserialize loop; get acc2 with the loop invariant. + obtain ⟨⟨t_ent1, acc2⟩, h_loop_eq, h_char⟩ := triple_exists_ok_fc + (compute_ring_element_v_loop_fc K hK public_key h_pk_len t_as_ntt_entry + r_as_ntt cache r_arr h_r_len h_cache_len h_r_arr h_r_bnd h_cache_char + acc1 h_acc1_zero iter0 h_iter0_def) + -- Accumulator bound: acc2[n].natAbs ≤ K·2^25 ≤ 2^27 (from loop_inv conjunct 2). + have h_acc2_bnd : ∀ n : Nat, n < 256 → (acc2.val[n]!).val.natAbs ≤ 2^27 := by + intro n hn + obtain ⟨_, h_inv_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_char + have hb : (acc2.val[n]!).val.natAbs ≤ (acc1.val[n]!).val.natAbs + K.val * 2^25 := + h_inv_bnd n hn + have h0 : (acc1.val[n]!).val.natAbs = 0 := by rw [h_acc1_zero n hn]; rfl + rw [h0] at hb + have hK4 : K.val * 2^25 ≤ 4 * 2^25 := Nat.mul_le_mul_right _ hK + have : (2 ^ 27 : Nat) = 4 * 2^25 := by norm_num + omega + -- reducing step: result1. + set s := Aeneas.Std.Array.to_slice acc2 with h_s_def + have h_s_len : s.length = 256 := by + rw [h_s_def, Aeneas.Std.Array.length_to_slice]; rfl + have h_s_bnd : ∀ i : Nat, i < 256 → (s.val[i]!).val.natAbs ≤ 2^16 * 3328 := by + intro i hi + rw [h_s_def, Aeneas.Std.Array.val_to_slice] + have := h_acc2_bnd i hi + have h27 : (2 ^ 27 : Nat) ≤ 2^16 * 3328 := by norm_num + omega + obtain ⟨result1, h_result1_eq, h_result1_mont, h_result1_lane_bnd⟩ := + triple_exists_ok_fc + (poly_reducing_from_i32_array_fc s result h_s_len h_s_bnd) + -- lift_poly result1 = mont_strip (poly_reducing s). + have h_result1_lift : lift_poly result1 + = Impl.mont_strip_pure (Spec.poly_reducing_from_i32_array_pure s) := by + rw [← h_result1_mont, Impl.mont_strip_lift_poly_mont_eq_lift_poly] + -- invert step. PRE ≤13312 from result1 ≤4993. + have h_result1_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((result1.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 13312 := by + intro chunk hchunk k hk + have := h_result1_lane_bnd chunk hchunk k hk + omega + obtain ⟨⟨result2, scratch1⟩, h_inv_eq, h_result2_lift, h_result2_bnd⟩ := + triple_exists_ok_fc + (invert_ntt_montgomery_fc (K := K) result1 scratch h_result1_bnd) + dsimp only at h_inv_eq h_result2_lift h_result2_bnd + -- add-message-error step. PRE: result2 ≤32767, error_2+message ≤29439. + have h_result2_b_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((result2.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro chunk hchunk ℓ hℓ + have := h_result2_bnd chunk hchunk ℓ hℓ; omega + have h_sum_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + (((error_2.coefficients.val[chunk]!).elements.val[ℓ]!).val + + ((message.coefficients.val[chunk]!).elements.val[ℓ]!).val + : Int).natAbs ≤ 29439 := by + intro chunk hchunk ℓ hℓ + have he := h_error_bnd chunk hchunk ℓ hℓ + have hm := h_message_bnd chunk hchunk ℓ hℓ + omega + obtain ⟨⟨result3, scratch2⟩, h_add_eq, h_result3_lift⟩ := + triple_exists_ok_fc + (add_message_error_reduce_fc error_2 message result2 scratch1 + h_result2_b_bnd h_sum_bnd) + dsimp only at h_add_eq h_result3_lift + -- BYTES_PER_RING_ELEMENT reduces to 384 (both constants are `irreducible`). + have h_bpr : (libcrux_iot_ml_kem.constants.BYTES_PER_RING_ELEMENT : Result Std.Usize) + = .ok (384#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.constants.BYTES_PER_RING_ELEMENT + unfold libcrux_iot_ml_kem.constants.BITS_PER_RING_ELEMENT + unfold libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + -- 256 * 12 = 3072. + have hm_max : (256#usize : Std.Usize).val * (12#usize : Std.Usize).val ≤ Std.Usize.max := by + scalar_tac + obtain ⟨m, hm_eq, hm_v⟩ := Std.WP.spec_imp_exists (Std.Usize.mul_spec hm_max) + have hm : m = (3072#usize : Std.Usize) := by + apply Aeneas.Std.UScalar.eq_of_val_eq + show m.val = (3072#usize : Std.Usize).val; rw [hm_v]; decide + rw [hm_eq, hm] + -- 3072 / 8 = 384. + obtain ⟨d, hd_eq, hd_v⟩ := Aeneas.Std.UScalar.div_spec (3072#usize : Std.Usize) + (by decide : ((8#usize : Std.Usize).val : Nat) ≠ 0) + have hd : d = (384#usize : Std.Usize) := by + apply Aeneas.Std.UScalar.eq_of_val_eq + show d.val = (384#usize : Std.Usize).val; rw [hd_v]; decide + simp only [Aeneas.Std.bind_tc_ok, hd_eq, hd] + -- Reduce the impl do-block to `.ok (t_ent1, result3, scratch2, acc2)`. + apply triple_of_ok_fc + (v := (t_ent1, result3, scratch2, acc2)) + · unfold libcrux_iot_ml_kem.matrix.compute_ring_element_v + simp only [libcrux_secrets.traits.Classify.Blanket.classify, Aeneas.Std.lift, + Aeneas.Std.bind_tc_ok] + rw [show (Std.Array.repeat (256#usize : Std.Usize) (0#i32 : Std.I32)) = acc1 from rfl] + -- BYTES_PER_RING_ELEMENT = 384, chunks_exact + enumerate = iter0. + rw [h_bpr] + simp only [Aeneas.Std.bind_tc_ok] + rw [show (CoreModels.core.slice.Slice.chunks_exact public_key (384#usize : Std.Usize)) + = .ok { cs := 384#usize, elements := public_key } from rfl] + simp only [Aeneas.Std.bind_tc_ok] + rw [show (CoreModels.core.slice.iter.ChunksExact.Insts.CoreIterTraitsIteratorIteratorSharedASlice.enumerate + { cs := (384#usize : Std.Usize), elements := public_key }) + = .ok iter0 from rfl] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_loop_eq]; simp only [Aeneas.Std.bind_tc_ok] + show (do + let result1 ← polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst (Aeneas.Std.Array.to_slice acc2) result + let (result2, scratch1) ← + invert_ntt.invert_ntt_montgomery K portable_ops_inst result1 scratch + let (result3, scratch2) ← + polynomial.PolynomialRingElement.add_message_error_reduce + portable_ops_inst error_2 message result2 scratch1 + Result.ok (t_ent1, result3, scratch2, acc2)) + = Result.ok (t_ent1, result3, scratch2, acc2) + rw [← h_s_def, h_result1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_inv_eq]; simp only [Aeneas.Std.bind_tc_ok] + show (do + let (result3, scratch2) ← + polynomial.PolynomialRingElement.add_message_error_reduce + portable_ops_inst error_2 message result2 scratch1 + Result.ok (t_ent1, result3, scratch2, acc2)) + = Result.ok (t_ent1, result3, scratch2, acc2) + rw [h_add_eq]; rfl + · -- Chain A/C/B/compose/glue/D′: prove hacspec spec = .ok (lift_poly result3). + show hacspec_ml_kem.matrix.compute_ring_element_v + (lift_t_as_ntt_from_public_key public_key K) (lift_vec_slice r_as_ntt K) + (lift_poly error_2) (lift_poly message) = .ok (lift_poly result3) + unfold hacspec_ml_kem.matrix.compute_ring_element_v + -- A: multiply_vectors = .ok (scaleZ 2285 (lift_poly result1)). + have hA := compute_ring_element_v_acc_bridge hK public_key r_as_ntt r_arr + acc1 acc2 h_acc1_zero h_r_arr h_r_bnd t_ent1 h_char + rw [← h_result1_lift] at hA + rw [hA]; simp only [Aeneas.Std.bind_tc_ok] + -- C: ntt_inverse (scaleZ 2285 (lift_poly result1)) + -- = .ok (scaleZ 3303 (invert_pure (scaleZ 2285 (lift_poly result1)))). + have hCanon_s : ∀ j : Nat, j < 256 → + libcrux_iot_ml_kem.Spec.Pure.Canonical + ((scaleZ 2285 (lift_poly result1)).val[j]!) := + fun j hj => scaleZ_canon 2285 (lift_poly result1) j hj + rw [ntt_inverse_eq_scaleZ_invert_pure (scaleZ 2285 (lift_poly result1)) hCanon_s] + simp only [Aeneas.Std.bind_tc_ok] + -- B: invert_pure (scaleZ 2285 x) = scaleZ 2285 (invert_pure x). + rw [invert_ntt_montgomery_pure_scaleZ 2285 (lift_poly result1) + (fun j hj => lift_poly_canon result1 j hj)] + -- scaleZ 3303 (scaleZ 2285 y) = scaleZ 512 y. + rw [scaleZ_compose 3303 2285 (Spec.invert_ntt_montgomery_pure (lift_poly result1)), + glue_3303_2285] + -- invert_pure (lift_poly result1) = lift_poly result2. + rw [← h_result2_lift] + -- D′: (add_polynomials (scaleZ 512 (lift_poly result2)) (lift_poly error_2) + -- >>= add_polynomials · (lift_poly message)) + -- = .ok (add_message_error_reduce_pure (lift_poly error_2) (lift_poly message) + -- (lift_poly result2)). + rw [add_message_error_scaleZ_eq (lift_poly result2) (lift_poly error_2) (lift_poly message) + (fun j hj => lift_poly_canon result2 j hj)] + -- add_message_error_reduce_pure (lift_poly error_2) (lift_poly message) (lift_poly result2) + -- = lift_poly result3. + rw [← h_result3_lift] + +/-- +info: 'libcrux_iot_ml_kem.Matrix.ComputeRingElementV.FC.compute_ring_element_v_fc' depends on axioms: [propext, + Classical.choice, + Quot.sound, + deserialize_to_reduced_ring_element_fc]-/ +#guard_msgs in +#print axioms compute_ring_element_v_fc + +end libcrux_iot_ml_kem.Matrix.ComputeRingElementV.FC \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeRingElementV/Hacspec.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeRingElementV/Hacspec.lean new file mode 100644 index 00000000..aadc5d65 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeRingElementV/Hacspec.lean @@ -0,0 +1,244 @@ +/- + # `Matrix/ComputeRingElementV/Hacspec.lean` — L7.3 `D′` tail bridge. + + The hacspec↔pure equational bridge for the *add-message-error* tail of the + `compute_ring_element_v` decomposition. The hacspec `compute_ring_element_v` + tail unfolds to + `let a ← add_polynomials inner_product_inv error_2; + add_polynomials a message`, + where the L7.3 bridge supplies `inner_product_inv = scaleZ 512 result2`. + + This file proves the *outer* `add_polynomials … message` step on top of the + already-proven *inner* bridge `ComputeVectorU.add_polynomials_scaleZ_eq` + (the `add_polynomials (scaleZ 512 b) e = add_error_reduce_pure b e` lemma): + + add_polynomials (add_error_reduce_pure result2 e2) msg + = add_message_error_reduce_pure e2 msg result2. + + Per-lane both sides reduce, in `ZMod 3329`, to `512·result2 + e2 + msg`. + The proof mirrors `add_polynomials_scaleZ_eq` exactly: + `matrix_add_polynomials_eq_ok` reduction, `eq_of_zmod_lane_canon''` + 3-part split, per-lane `add_pure` associativity/commutativity closed by + `ring`. + + Local copies of the `private` lane-access / canonicity helpers from + `ComputeVectorU` are re-derived here (the originals are `private`). +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeAsPlusE +import LibcruxIotMlKem.Matrix.ComputeMessage.Hacspec +import LibcruxIotMlKem.Matrix.ComputeVectorU.Hacspec + +set_option mvcgen.warning false +set_option linter.unusedVariables false + +namespace libcrux_iot_ml_kem.Matrix.ComputeRingElementV.Hacspec +open libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeMessage.Bridges libcrux_iot_ml_kem.Matrix.ComputeMessage.Hacspec libcrux_iot_ml_kem.Matrix.ComputeVectorU.Hacspec +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeAsPlusE libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt +open libcrux_iot_ml_kem.Spec.Pure (Canonical) +section AddMessageErrorScaleZ + +/-! ### Local lane-access / canonicity helpers (copies of the `private` + originals in `ComputeVectorU`). -/ + +/-- Generic `Std.Array.make … (range m).map f` lane access (local copy). -/ +private theorem mkN_map_lane'' {α : Type} [Inhabited α] {n : Std.Usize} {m : Nat} + (f : Nat → α) (k : Nat) (hk : k < m) + (hlen : ((List.range m).map f).length = n.val) : + (Std.Array.make n ((List.range m).map f) hlen).val[k]! = f k := by + show ((List.range m).map f)[k]! = f k + have h_len : ((List.range m).map f).length = m := by simp + rw [getElem!_pos _ k (by rw [h_len]; exact hk)] + simp + +/-- `chunk_at` lane access (local copy). -/ +private theorem chunk_at_lane'' + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (k ℓ : Nat) (hℓ : ℓ < 16) : + (Spec.chunk_at p k).val[ℓ]! = p.val[16 * k + ℓ]! := by + unfold Spec.chunk_at + show ((List.range 16).map (fun j => p.val[16 * k + j]!))[ℓ]! = p.val[16 * k + ℓ]! + have h_len : ((List.range 16).map (fun j => p.val[16 * k + j]!)).length = 16 := by simp + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map, List.getElem_range] + +/-- Lane access for a 16-chunk flatten shape (local copy). -/ +private theorem flatten_chunk_map_lane'' + (H : Nat → Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (j : Nat) (hj : j < 256) + (h : ((List.range 16).map H).length = (16#usize).val) : + (Spec.flatten_chunks (Std.Array.make 16#usize ((List.range 16).map H) h)).val[j]! + = (H (j / 16)).val[j % 16]! := by + have hk : j / 16 < 16 := by omega + unfold Spec.flatten_chunks + rw [mkN_map_lane'' _ j hj] + rw [mkN_map_lane'' H (j / 16) hk] + +/-- Canonical round-trip (local copy). -/ +private theorem canon_feOfZMod'' (z : ZMod 3329) : Canonical (feOfZMod z) := by + unfold Canonical feOfZMod hacspec_ml_kem.parameters.FIELD_MODULUS + show (BitVec.ofNat 16 z.val).toNat < _ + rw [BitVec.toNat_ofNat] + have hz : z.val < 3329 := ZMod.val_lt z + have : z.val % 2 ^ 16 = z.val := Nat.mod_eq_of_lt (by omega) + simp only [this]; simpa using hz + +/-- Canonical round-trip (local copy). -/ +private theorem feOfZMod_zmodOfFE_of_canon'' + (fe : hacspec_ml_kem.parameters.FieldElement) (h : Canonical fe) : + feOfZMod (zmodOfFE fe) = fe := by + have h' : fe.val.val < 3329 := by + unfold Canonical hacspec_ml_kem.parameters.FIELD_MODULUS at h; simpa using h + unfold feOfZMod zmodOfFE + have hzval : ((fe.val.val : ZMod 3329)).val = fe.val.val := ZMod.val_natCast_of_lt h' + rw [hzval] + have hfeval : fe.val.val < 2 ^ 16 := by + have h_p : (3329 : Nat) ≤ 2 ^ 16 := by decide + omega + have hfebv : BitVec.ofNat 16 fe.val.val = fe.val.bv := by + apply BitVec.eq_of_toNat_eq + rw [BitVec.toNat_ofNat] + show fe.val.val % 2 ^ 16 = fe.val.bv.toNat + rw [Nat.mod_eq_of_lt hfeval]; rfl + show ({ val := ⟨BitVec.ofNat 16 fe.val.val⟩ } : + hacspec_ml_kem.parameters.FieldElement) = fe + rw [hfebv] + +/-- Two canonical 256-arrays with equal `zmodOfFE` lanes are equal (local copy). -/ +private theorem eq_of_zmod_lane_canon'' + (u v : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hcu : ∀ j : Nat, j < 256 → Canonical (u.val[j]!)) + (hcv : ∀ j : Nat, j < 256 → Canonical (v.val[j]!)) + (hz : ∀ j : Nat, j < 256 → zmodOfFE (u.val[j]!) = zmodOfFE (v.val[j]!)) : + u = v := by + apply Subtype.ext + apply List.ext_getElem + · rw [Aeneas.Std.Array.length_eq u, Aeneas.Std.Array.length_eq v] + · intro j hj1 _hj2 + have hj : j < 256 := by rw [Aeneas.Std.Array.length_eq u] at hj1; simpa using hj1 + have heq : u.val[j]! = v.val[j]! := by + rw [← feOfZMod_zmodOfFE_of_canon'' (u.val[j]!) (hcu j hj), + ← feOfZMod_zmodOfFE_of_canon'' (v.val[j]!) (hcv j hj), hz j hj] + have huj : u.val[j]! = u.val[j] := + getElem!_pos u.val j (by rw [Aeneas.Std.Array.length_eq u]; exact hj) + have hvj : v.val[j]! = v.val[j] := + getElem!_pos v.val j (by rw [Aeneas.Std.Array.length_eq v]; exact hj) + rw [← huj, ← hvj]; exact heq + +/-- Per-lane characterization of `Spec.add_error_reduce_pure` (local copy of + `ComputeVectorU.zmodOfFE_add_error_reduce_pure_lane`): for `j < 256` and + canonical `b[j]`, + `zmodOfFE ((add_error_reduce_pure b e)[j]) = 512·zmodOfFE (b[j]) + zmodOfFE (e[j])`. -/ +private theorem zmodOfFE_add_error_reduce_pure_lane'' + (b e : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (j : Nat) (hj : j < 256) + (hb : libcrux_iot_ml_kem.Spec.Pure.Canonical (b.val[j]!)) : + zmodOfFE ((Spec.add_error_reduce_pure b e).val[j]!) + = 512 * zmodOfFE (b.val[j]!) + zmodOfFE (e.val[j]!) := by + have hℓ : j % 16 < 16 := Nat.mod_lt _ (by decide) + have hjeq : 16 * (j / 16) + j % 16 = j := by omega + unfold Spec.add_error_reduce_pure + rw [flatten_chunk_map_lane'' (fun k => Spec.chunk_add_error_reduce_pure + (Spec.chunk_at b k) (Spec.chunk_at e k)) j hj (by simp)] + unfold Spec.chunk_add_error_reduce_pure + rw [mkN_map_lane'' _ (j % 16) hℓ] + rw [chunk_at_lane'' b (j / 16) (j % 16) hℓ, chunk_at_lane'' e (j / 16) (j % 16) hℓ] + rw [hjeq] + rw [zmodOfFE_add_pure] + rw [zmodOfFE_mul_pure] + rw [zmodOfFE_lift_fe_mont] + have h1441 : (((1441#i16 : Std.I16).val : ZMod 3329)) = 1441 := by decide + rw [h1441] + have h512 : (1441 : ZMod 3329) * 169 = 512 := glue_1441_169 + rw [show (zmodOfFE (b.val[j]!) * (1441 * 169) : ZMod 3329) + = 512 * zmodOfFE (b.val[j]!) by rw [h512]; ring] + +/-! ### `D′` — the `add_polynomials ∘ add_polynomials ∘ scaleZ` tail bridge. -/ + +set_option maxHeartbeats 1000000 in +/-- **L7.3 `D′` tail bridge.** The hacspec `compute_ring_element_v` tail + `add_polynomials (add_polynomials (scaleZ 512 result2) e2) msg` + equals `Spec.add_message_error_reduce_pure e2 msg result2` for canonical + `result2`. The inner add reuses `add_polynomials_scaleZ_eq`; the outer add + mirrors its body (createi reduction + lane split + per-lane `ring`). -/ +theorem add_message_error_scaleZ_eq + (result2 e2 msg : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (h_r2_canon : ∀ j : Nat, j < 256 → + libcrux_iot_ml_kem.Spec.Pure.Canonical (result2.val[j]!)) : + (do let a ← hacspec_ml_kem.matrix.add_polynomials (scaleZ 512 result2) e2 + hacspec_ml_kem.matrix.add_polynomials a msg) + = .ok (Spec.add_message_error_reduce_pure e2 msg result2) := by + -- Inner add: reuse the proven bridge. + rw [add_polynomials_scaleZ_eq result2 e2 h_r2_canon] + simp only [Aeneas.Std.bind_tc_ok] + -- Outer add: reduce via createi. + rw [Stage4MatrixAddFC.matrix_add_polynomials_eq_ok (Spec.add_error_reduce_pure result2 e2) msg] + set L : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + ⟨(List.range 256).map (fun k => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Spec.add_error_reduce_pure result2 e2).val[k]!) (msg.val[k]!)), + by simp [List.length_map, List.length_range]⟩ with hL_def + have hL_lane : ∀ j : Nat, j < 256 → + L.val[j]! = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Spec.add_error_reduce_pure result2 e2).val[j]!) (msg.val[j]!) := by + intro j hj + show ((List.range 256).map (fun k => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Spec.add_error_reduce_pure result2 e2).val[k]!) (msg.val[k]!)))[j]! = _ + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj])] + rw [List.getElem_map, List.getElem_range] + apply congrArg Result.ok + apply eq_of_zmod_lane_canon'' + · -- L lanes canonical + intro j hj + rw [hL_lane j hj] + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure _ _ + · -- add_message_error_reduce_pure lanes canonical + intro j hj + have hℓ : j % 16 < 16 := Nat.mod_lt _ (by decide) + unfold Spec.add_message_error_reduce_pure + rw [flatten_chunk_map_lane'' (fun k => Spec.chunk_add_message_error_reduce_pure + (Spec.chunk_at e2 k) (Spec.chunk_at msg k) (Spec.chunk_at result2 k)) j hj (by simp)] + unfold Spec.chunk_add_message_error_reduce_pure + rw [mkN_map_lane'' _ (j % 16) hℓ] + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure _ _ + · -- per-lane zmodOfFE equality + intro j hj + have hℓ : j % 16 < 16 := Nat.mod_lt _ (by decide) + have hjeq : 16 * (j / 16) + j % 16 = j := by omega + rw [hL_lane j hj] + rw [zmodOfFE_add_pure] + rw [zmodOfFE_add_error_reduce_pure_lane'' result2 e2 j hj (h_r2_canon j hj)] + -- RHS: unfold add_message_error_reduce_pure lane. + unfold Spec.add_message_error_reduce_pure + rw [flatten_chunk_map_lane'' (fun k => Spec.chunk_add_message_error_reduce_pure + (Spec.chunk_at e2 k) (Spec.chunk_at msg k) (Spec.chunk_at result2 k)) j hj (by simp)] + unfold Spec.chunk_add_message_error_reduce_pure + rw [mkN_map_lane'' _ (j % 16) hℓ] + rw [chunk_at_lane'' e2 (j / 16) (j % 16) hℓ, chunk_at_lane'' msg (j / 16) (j % 16) hℓ, + chunk_at_lane'' result2 (j / 16) (j % 16) hℓ] + rw [hjeq] + rw [zmodOfFE_add_pure, zmodOfFE_mul_pure, zmodOfFE_add_pure, zmodOfFE_lift_fe_mont] + have h1441 : (((1441#i16 : Std.I16).val : ZMod 3329)) = 1441 := by decide + rw [h1441] + have h512 : (1441 : ZMod 3329) * 169 = 512 := glue_1441_169 + rw [show (zmodOfFE (result2.val[j]!) * (1441 * 169) : ZMod 3329) + = 512 * zmodOfFE (result2.val[j]!) by rw [h512]; ring] + ring + +end AddMessageErrorScaleZ + +end libcrux_iot_ml_kem.Matrix.ComputeRingElementV.Hacspec \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeRingElementV/Impl.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeRingElementV/Impl.lean new file mode 100644 index 00000000..301b11de --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeRingElementV/Impl.lean @@ -0,0 +1,1067 @@ +/- + # `Matrix/ComputeRingElementV/Impl.lean` — L7.3 chunks-exact loop keystone. + + Domain-free iterator infrastructure for the `loop` in + `matrix.compute_ring_element_v_loop`, which iterates over a + `Enumerate (ChunksExact U8)` via the `loop` combinator. There is no Mathlib, + no ZMod, no mont reasoning here — purely Slice / Usize / iterator facts. + + Mirrors the EXISTING range-loop keystone `loop_range_spec_usize`: same induction skeleton, with `e.val - start.val` + replaced by `numChunks - k` (k = enumerate count). + + Deliverables: + * `enumerate_chunks_next_cont` / `enumerate_chunks_next_done` — equational + characterization of `Enumerate.next` over `ChunksExact`. + * `loop_chunks_exact_enumerate_spec` — the loop Hoare spec (keystone). +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeAsPlusE +import LibcruxIotMlKem.Sampling +import LibcruxIotMlKem.Serialize +import LibcruxIotMlKem.Matrix.ComputeMessage.Impl +import LibcruxIotMlKem.Matrix.ComputeMessage.Hacspec + +namespace libcrux_iot_ml_kem.Matrix.ComputeRingElementV.Impl +open libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeMessage.Bridges libcrux_iot_ml_kem.Matrix.ComputeMessage.Hacspec libcrux_iot_ml_kem.Matrix.ComputeMessage.Impl +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec +open Result ControlFlow + +set_option mvcgen.warning false +set_option linter.unusedVariables false + +/-- Local copy of FCTargets' `private triple_exists_ok_fc`. -/ +private theorem triple_exists_ok_fc {α : Type} {x : Result α} {P : α → Prop} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) : + ∃ v, x = .ok v ∧ P v := by + match hx : x with + | .ok v => exact ⟨v, rfl, (by subst hx; simpa [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h)⟩ + | .fail _ => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +/-- The concrete `Enumerate (ChunksExact U8)` iterator state type. -/ +abbrev EnumCE := + CoreModels.core.iter.adapters.enumerate.Enumerate + (CoreModels.core.slice.iter.ChunksExact Std.U8) + +/-- The fully-applied `Enumerate.next` function for our iterator. -/ +noncomputable abbrev enumCENext (it : EnumCE) : + Result ((CoreModels.core.option.Option (Std.Usize × Slice Std.U8)) × EnumCE) := + CoreModels.core.iter.adapters.enumerate.Enumerate.Insts.CoreIterTraitsIteratorIteratorPairUsizeClause0_Item.next + (CoreModels.core.slice.iter.ChunksExact.Insts.CoreIterTraitsIteratorIteratorSharedASlice + Std.U8) it + +/-! ## Deliverable 1 — `Enumerate.next` characterization -/ + +/-- Enumerate-over-ChunksExact `next` when at least `cs` bytes remain: + yields the current count paired with the first `cs` bytes, advances. -/ +theorem enumerate_chunks_next_cont + (rest : Slice Std.U8) (cs cnt : Std.Usize) + (h_le : cs.val ≤ rest.length) (h_cnt : cnt.val + 1 ≤ Std.Usize.max) : + ∃ (chunk drop : Slice Std.U8) (cnt' : Std.Usize), + enumCENext { iter := { cs := cs, elements := rest }, count := cnt } + = .ok (CoreModels.core.option.Option.Some (cnt, chunk), + { iter := { cs := cs, elements := drop }, count := cnt' }) + ∧ cnt'.val = cnt.val + 1 + ∧ chunk.length = cs.val + ∧ drop.length = rest.length - cs.val + ∧ (∀ ℓ : Nat, ℓ < cs.val → chunk.val[ℓ]! = rest.val[ℓ]!) := by + -- `split_at` succeeds since `cs ≤ rest.length`. + have hsa := core.slice.Slice.split_at.spec rest cs (by simpa using h_le) + obtain ⟨⟨s0, s1⟩, hsa_eq, hs0len, hs1len, hs0val, hs1val⟩ := WP.spec_imp_exists hsa + -- `cnt + 1#usize` succeeds. + have hadd := Std.Usize.add_spec (x := cnt) (y := 1#usize) + (by have : (1#usize : Std.Usize).val = 1 := rfl; omega) + obtain ⟨cnt', hcnt'_eq, hcnt'_post⟩ := WP.spec_imp_exists hadd + have hcnt'_val : cnt'.val = cnt.val + 1 := by + have h1 : (1#usize : Std.Usize).val = 1 := rfl + simp only [h1] at hcnt'_post; omega + refine ⟨s0, s1, cnt', ?_, hcnt'_val, hs0len, hs1len, ?_⟩ + · -- compute the `next` + simp only [enumCENext, + CoreModels.core.iter.adapters.enumerate.Enumerate.Insts.CoreIterTraitsIteratorIteratorPairUsizeClause0_Item.next, + CoreModels.core.slice.iter.ChunksExact.Insts.CoreIterTraitsIteratorIteratorSharedASlice.next] + have h_le' : cs.val ≤ rest.val.length := by simpa [Slice.length] using h_le + rw [hsa_eq] + simp only [h_le', ↓reduceIte, bind_assoc] + rw [hcnt'_eq] + rfl + · -- the element relation + intro ℓ hℓ + rw [hs0val, List.getElem!_take_of_lt _ _ _ hℓ] + +/-- Enumerate-over-ChunksExact `next` when fewer than `cs` bytes remain: + terminates. -/ +theorem enumerate_chunks_next_done + (rest : Slice Std.U8) (cs cnt : Std.Usize) (h_lt : rest.length < cs.val) : + enumCENext { iter := { cs := cs, elements := rest }, count := cnt } + = .ok (CoreModels.core.option.Option.None, + { iter := { cs := cs, elements := rest }, count := cnt }) := by + simp only [enumCENext, + CoreModels.core.iter.adapters.enumerate.Enumerate.Insts.CoreIterTraitsIteratorIteratorPairUsizeClause0_Item.next, + CoreModels.core.slice.iter.ChunksExact.Insts.CoreIterTraitsIteratorIteratorSharedASlice.next] + have hng : ¬ (cs.val ≤ rest.val.length) := by + simp only [Slice.length] at h_lt; omega + simp only [hng, ↓reduceIte] + rfl + +/-! ## Deliverable 2 — the loop Hoare spec keystone + +Mirrors `loop_range_spec_usize`: same induction +skeleton with `e.val - start.val` replaced by `numChunks - k`. The three +`triple_noThrow_*_chunks` helpers below are verbatim copies of the +`triple_noThrow_*_usize` machinery, renamed to avoid clashes. -/ + +section loop_chunks_helpers + +private abbrev ResultPSU := PostShape.except Error (PostShape.except PUnit PostShape.pure) + +private theorem triple_noThrow_elim_chunks {α : Type} {x : Result α} + {Q : α → Assertion ResultPSU} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ PostCond.noThrow Q ⦄) {v : α} (hv : x = ok v) : + (Q v).down := by + subst hv; simpa [Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h + +private theorem triple_noThrow_exists_ok_chunks {α : Type} {x : Result α} + {Q : α → Assertion ResultPSU} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ PostCond.noThrow Q ⦄) : ∃ v, x = ok v := by + match x, h with + | .ok v, _ => exact ⟨v, rfl⟩ + | .fail _, h => exact absurd h (by simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div, h => exact absurd h (by simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +private theorem triple_of_ok_chunks {α : Type} {x : Result α} {v : α} {P : α → Prop} + (hx : x = ok v) (hp : P v) : + (⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) := by + subst hx; simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +end loop_chunks_helpers + +set_option maxHeartbeats 2000000 in +/-- Loop-over-`Enumerate (ChunksExact U8)` spec. An invariant `inv : Nat → β → + Result Prop`, indexed by the enumerate count `k`, is preserved by each step. + Induction on `numChunks - k`. -/ +theorem loop_chunks_exact_enumerate_spec {β : Type} + (body : (EnumCE × β) → Result (ControlFlow (EnumCE × β) β)) + (init : β) (fullSlice : Slice Std.U8) (cs : Std.Usize) (numChunks : Nat) + (inv : Nat → β → Result Prop) + (h_cs_pos : 0 < cs.val) + (h_len : fullSlice.length = numChunks * cs.val) + (h_init : (inv 0 init).holds) + (h_step : ∀ (acc : β) (k : Nat) (rest : Slice Std.U8) (cnt : Std.Usize), + k ≤ numChunks → cnt.val = k → rest.length = (numChunks - k) * cs.val → + (inv k acc).holds → + ⦃ ⌜ True ⌝ ⦄ + body ({ iter := { cs := cs, elements := rest }, count := cnt }, acc) + ⦃ ⇓ r => match r with + | .cont (iter', acc') => + ⌜ k < numChunks ∧ iter'.iter.cs = cs ∧ iter'.count.val = k + 1 + ∧ iter'.iter.elements.length = (numChunks - (k + 1)) * cs.val + ∧ (inv (k + 1) acc').holds ⌝ + | .done y => ⌜ (inv numChunks y).holds ⌝ ⦄) : + ⦃ ⌜ True ⌝ ⦄ + loop body ({ iter := { cs := cs, elements := fullSlice }, count := 0#usize }, init) + ⦃ ⇓ r => ⌜ (inv numChunks r).holds ⌝ ⦄ := by + suffices gen : ∀ (n : Nat) (acc : β) (rest : Slice Std.U8) (cnt : Std.Usize), + numChunks - cnt.val = n → cnt.val ≤ numChunks → + rest.length = (numChunks - cnt.val) * cs.val → + (inv cnt.val acc).holds → + ⦃ ⌜ True ⌝ ⦄ + loop body ({ iter := { cs := cs, elements := rest }, count := cnt }, acc) + ⦃ ⇓ r => ⌜ (inv numChunks r).holds ⌝ ⦄ by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + refine gen _ init fullSlice 0#usize rfl (by rw [h0]; exact Nat.zero_le _) ?_ ?_ + · rw [h0]; simpa using h_len + · rw [h0]; exact h_init + intro n + induction n with + | zero => + intro acc rest cnt hn hcnt_le hlen hinv + -- numChunks - cnt.val = 0 with cnt.val ≤ numChunks ⟹ cnt.val = numChunks + have hcnt_eq : cnt.val = numChunks := by omega + have hs := h_step acc cnt.val rest cnt hcnt_le rfl hlen hinv + obtain ⟨r, hbody⟩ := triple_noThrow_exists_ok_chunks hs + have hpost := triple_noThrow_elim_chunks hs hbody + rw [loop.eq_def, hbody] + match r with + | .cont (iter', acc') => + simp only at hpost + exact absurd hpost.1 (by rw [hcnt_eq]; exact Nat.lt_irrefl _) + | .done y => + simp only at hpost + exact triple_of_ok_chunks rfl hpost + | succ n ih => + intro acc rest cnt hn hcnt_le hlen hinv + have hcnt_lt : cnt.val < numChunks := by omega + have hs := h_step acc cnt.val rest cnt hcnt_le rfl hlen hinv + obtain ⟨r, hbody⟩ := triple_noThrow_exists_ok_chunks hs + have hpost := triple_noThrow_elim_chunks hs hbody + rw [loop.eq_def, hbody] + match r with + | .done y => + simp only at hpost + exact triple_of_ok_chunks rfl hpost + | .cont (iter', acc') => + simp only at hpost + obtain ⟨hlt, hcs, hcnt', hlen', hinv'⟩ := hpost + -- reconstruct iter' from its fields + have hiter : iter' + = { iter := { cs := cs, elements := iter'.iter.elements }, + count := iter'.count } := by + cases iter' with + | mk it ct => cases it with | mk csv el => cases hcs; rfl + rw [hiter] + refine ih acc' iter'.iter.elements iter'.count ?_ ?_ ?_ ?_ + · rw [hcnt']; omega + · rw [hcnt']; omega + · rw [hcnt']; exact hlen' + · rw [hcnt']; exact hinv' + +/-! ## Deliverable 1b — public-key-suffix-aware loop spec (L7.3 keystone) + +The generic `loop_chunks_exact_enumerate_spec` exposes `rest` in `h_step` with +only a length hypothesis, NOT the byte-content relation needed to apply the A2 +axiom `deserialize_to_reduced_ring_element_fc` (which requires +`chunk_bytes.val[ℓ]! = public_key.val[i*cs+ℓ]!`). This specialized keystone +threads the suffix relation `rest.val[ℓ]! = fullSlice.val[k*cs+ℓ]!` through the +induction, so `h_step` receives it at each `k`. -/ + +/-- Enumerate-over-ChunksExact `next` `drop`-suffix companion: in the `.cont` + case the advanced slice `drop` is the `cs`-tail of `rest`, i.e. + `drop.val[ℓ]! = rest.val[cs + ℓ]!`. Proven from `split_at.spec`'s + `s1.val = s.val.drop cs`. -/ +theorem enumerate_chunks_next_cont_drop + (rest : Slice Std.U8) (cs cnt : Std.Usize) + (h_le : cs.val ≤ rest.length) (h_cnt : cnt.val + 1 ≤ Std.Usize.max) : + ∃ (chunk drop : Slice Std.U8) (cnt' : Std.Usize), + enumCENext { iter := { cs := cs, elements := rest }, count := cnt } + = .ok (CoreModels.core.option.Option.Some (cnt, chunk), + { iter := { cs := cs, elements := drop }, count := cnt' }) + ∧ cnt'.val = cnt.val + 1 + ∧ chunk.length = cs.val + ∧ drop.length = rest.length - cs.val + ∧ (∀ ℓ : Nat, ℓ < cs.val → chunk.val[ℓ]! = rest.val[ℓ]!) + ∧ (∀ ℓ : Nat, drop.val[ℓ]! = rest.val[cs.val + ℓ]!) := by + have hsa := core.slice.Slice.split_at.spec rest cs (by simpa using h_le) + obtain ⟨⟨s0, s1⟩, hsa_eq, hs0len, hs1len, hs0val, hs1val⟩ := WP.spec_imp_exists hsa + have hadd := Std.Usize.add_spec (x := cnt) (y := 1#usize) + (by have : (1#usize : Std.Usize).val = 1 := rfl; omega) + obtain ⟨cnt', hcnt'_eq, hcnt'_post⟩ := WP.spec_imp_exists hadd + have hcnt'_val : cnt'.val = cnt.val + 1 := by + have h1 : (1#usize : Std.Usize).val = 1 := rfl + simp only [h1] at hcnt'_post; omega + refine ⟨s0, s1, cnt', ?_, hcnt'_val, hs0len, hs1len, ?_, ?_⟩ + · simp only [enumCENext, + CoreModels.core.iter.adapters.enumerate.Enumerate.Insts.CoreIterTraitsIteratorIteratorPairUsizeClause0_Item.next, + CoreModels.core.slice.iter.ChunksExact.Insts.CoreIterTraitsIteratorIteratorSharedASlice.next] + have h_le' : cs.val ≤ rest.val.length := by simpa [Slice.length] using h_le + rw [hsa_eq] + simp only [h_le', ↓reduceIte, bind_assoc] + rw [hcnt'_eq] + rfl + · intro ℓ hℓ + rw [hs0val, List.getElem!_take_of_lt _ _ _ hℓ] + · intro ℓ + rw [hs1val] + by_cases hℓ : ℓ < (List.drop cs.val rest.val).length + · rw [getElem!_pos (List.drop cs.val rest.val) ℓ hℓ] + have hℓ' : cs.val + ℓ < rest.val.length := by + rw [List.length_drop] at hℓ; omega + rw [getElem!_pos rest.val (cs.val + ℓ) hℓ'] + rw [List.getElem_drop] + · -- out-of-range: both default. + have hℓr : ¬ (cs.val + ℓ < rest.val.length) := by + rw [List.length_drop] at hℓ; omega + rw [getElem!_neg (List.drop cs.val rest.val) ℓ hℓ, + getElem!_neg rest.val (cs.val + ℓ) hℓr] + +set_option maxHeartbeats 2000000 in +/-- Public-key-suffix-aware loop spec: like `loop_chunks_exact_enumerate_spec`, + but additionally threads the byte-content suffix relation + `rest.val[ℓ]! = fullSlice.val[k*cs+ℓ]!` to `h_step` at each `k`. The body + receives, at count `k`, the slice `rest` positioned at byte-offset `k*cs` + in `fullSlice`. Induction on `numChunks - k`, carrying the suffix relation + `∀ ℓ, rest.val[ℓ]! = fullSlice.val[cnt.val*cs.val + ℓ]!`. -/ +theorem loop_chunks_exact_pk_spec {β : Type} + (body : (EnumCE × β) → Result (ControlFlow (EnumCE × β) β)) + (init : β) (fullSlice : Slice Std.U8) (cs : Std.Usize) (numChunks : Nat) + (inv : Nat → β → Result Prop) + (h_cs_pos : 0 < cs.val) + (h_len : fullSlice.length = numChunks * cs.val) + (h_init : (inv 0 init).holds) + (h_step : ∀ (acc : β) (k : Nat) (rest : Slice Std.U8) (cnt : Std.Usize), + k ≤ numChunks → cnt.val = k → rest.length = (numChunks - k) * cs.val → + (∀ ℓ : Nat, rest.val[ℓ]! = fullSlice.val[k * cs.val + ℓ]!) → + (inv k acc).holds → + ⦃ ⌜ True ⌝ ⦄ + body ({ iter := { cs := cs, elements := rest }, count := cnt }, acc) + ⦃ ⇓ r => match r with + | .cont (iter', acc') => + ⌜ k < numChunks ∧ iter'.iter.cs = cs ∧ iter'.count.val = k + 1 + ∧ iter'.iter.elements.length = (numChunks - (k + 1)) * cs.val + ∧ (∀ ℓ : Nat, iter'.iter.elements.val[ℓ]! + = fullSlice.val[(k + 1) * cs.val + ℓ]!) + ∧ (inv (k + 1) acc').holds ⌝ + | .done y => ⌜ (inv numChunks y).holds ⌝ ⦄) : + ⦃ ⌜ True ⌝ ⦄ + loop body ({ iter := { cs := cs, elements := fullSlice }, count := 0#usize }, init) + ⦃ ⇓ r => ⌜ (inv numChunks r).holds ⌝ ⦄ := by + suffices gen : ∀ (n : Nat) (acc : β) (rest : Slice Std.U8) (cnt : Std.Usize), + numChunks - cnt.val = n → cnt.val ≤ numChunks → + rest.length = (numChunks - cnt.val) * cs.val → + (∀ ℓ : Nat, rest.val[ℓ]! = fullSlice.val[cnt.val * cs.val + ℓ]!) → + (inv cnt.val acc).holds → + ⦃ ⌜ True ⌝ ⦄ + loop body ({ iter := { cs := cs, elements := rest }, count := cnt }, acc) + ⦃ ⇓ r => ⌜ (inv numChunks r).holds ⌝ ⦄ by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + refine gen _ init fullSlice 0#usize rfl (by rw [h0]; exact Nat.zero_le _) ?_ ?_ ?_ + · rw [h0]; simpa using h_len + · rw [h0]; intro ℓ; simp + · rw [h0]; exact h_init + intro n + induction n with + | zero => + intro acc rest cnt hn hcnt_le hlen hsuf hinv + have hcnt_eq : cnt.val = numChunks := by omega + have hs := h_step acc cnt.val rest cnt hcnt_le rfl hlen + (by rw [hcnt_eq] at hsuf ⊢; exact hsuf) hinv + obtain ⟨r, hbody⟩ := triple_noThrow_exists_ok_chunks hs + have hpost := triple_noThrow_elim_chunks hs hbody + rw [loop.eq_def, hbody] + match r with + | .cont (iter', acc') => + simp only at hpost + exact absurd hpost.1 (by rw [hcnt_eq]; exact Nat.lt_irrefl _) + | .done y => + simp only at hpost + exact triple_of_ok_chunks rfl hpost + | succ n ih => + intro acc rest cnt hn hcnt_le hlen hsuf hinv + have hcnt_lt : cnt.val < numChunks := by omega + have hs := h_step acc cnt.val rest cnt hcnt_le rfl hlen hsuf hinv + obtain ⟨r, hbody⟩ := triple_noThrow_exists_ok_chunks hs + have hpost := triple_noThrow_elim_chunks hs hbody + rw [loop.eq_def, hbody] + match r with + | .done y => + simp only at hpost + exact triple_of_ok_chunks rfl hpost + | .cont (iter', acc') => + simp only at hpost + obtain ⟨hlt, hcs, hcnt', hlen', hsuf', hinv'⟩ := hpost + have hiter : iter' + = { iter := { cs := cs, elements := iter'.iter.elements }, + count := iter'.count } := by + cases iter' with + | mk it ct => cases it with | mk csv el => cases hcs; rfl + rw [hiter] + refine ih acc' iter'.iter.elements iter'.count ?_ ?_ ?_ ?_ ?_ + · rw [hcnt']; omega + · rw [hcnt']; omega + · rw [hcnt']; exact hlen' + · rw [hcnt']; exact hsuf' + · rw [hcnt']; exact hinv' + +/-! ## §L7.3-loop — chunks-exact USE-CACHE column loop (namespace `ChunkLoopFC`). + + The `matrix.compute_ring_element_v_loop` body, per enumerate count `i`: + * classify_ref the chunk (identity cast), + * `deserialize_to_reduced_ring_element` → `t̂[i]` (A2 axiom, the discarded + poly is captured in the existential witness `mp`), + * read `r_as_ntt[i]`, `cache[i]` (read-only), + * `accumulating_ntt_multiply_use_cache` → acc += t̂[i]·r[i]. + + Structurally `RowIFillFC.row_i_inv` (USE-CACHE, 2-conjunct existential-mp) with + the matrix factor pinned to `(lift_t_as_ntt_from_public_key public_key K)` + (the deserialize axiom) instead of `(lift_matrix_from_seed seed K).val[i]`. + The loop carries `(t_as_ntt_entry, acc)`. -/ + +namespace ChunkLoopFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeAsPlusE libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Sampling libcrux_iot_ml_kem.Serialize libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt + +abbrev Acc := UseCacheFC.Acc +abbrev Poly := UseCacheFC.Poly + +/-- 2-conjunct invariant for the chunks-exact USE-CACHE column loop, in the + RESOLVED all-mont/existential form. `t̂rows` is the canonical deserialized + t-as-ntt rows `lift_t_as_ntt_from_public_key public_key K`. The impl + DISCARDS each deserialized poly (only `t_as_ntt_entry` = the last one and + the I32 accumulator survive), so we existentially quantify over the ACTUAL + deserialized polys `mp : Array Poly K`, tie them to the canonical rows via + the axiom (`lift_poly (mp[c]) = t̂rows[c]`), and characterize the + accumulator in the all-mont form. β = `Poly × Acc` (the carried + `t_as_ntt_entry` is ignored by the invariant). -/ +def loop_inv {K : Std.Usize} + (trows : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K) + (r_arr : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : Acc) : + Nat → (Poly × Acc) → Result Prop := + fun k p => pure ( + (∃ mp : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K, + (∀ c : Nat, c < k → + lift_poly (mp.val[c]!) = trows.val[c]! + ∧ (∀ a : Fin 16, ∀ b : Fin 16, + ((mp.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328)) + ∧ (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (p.2.val[16 * j + ℓ]!).val) + = (List.range k).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)))) + ∧ (∀ n : Nat, n < 256 → + (p.2.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + k * 2^25)) + +end ChunkLoopFC + +-- Memory hygiene (rule 1). Mirrors `L7_2b_irreducible`. We do NOT mark +-- `ChunkLoopFC.loop_inv` irreducible (preserve `simpa`-based destructure). +section L7_3_irreducible +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeAsPlusE libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Sampling libcrux_iot_ml_kem.Serialize libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt + +attribute [local irreducible] accumulating_ntt_multiply_poly_post +attribute [local irreducible] accumulating_ntt_multiply_poly_cache_post +attribute [local irreducible] Spec.ntt_multiply_pure_no_acc +attribute [local irreducible] Spec.mont_reduce_pure + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for the chunks-exact USE-CACHE column loop of + `compute_ring_element_v`. Dispatches: `Enumerate.next` (cont) → + `classify_ref` (identity) → `deserialize_to_reduced_ring_element` (A2 axiom, + fed the public-key-suffix relation) → `accumulating_ntt_multiply_use_cache`. + Re-establishes `loop_inv` at `k+1`. + + Mirrors `compute_vector_u_loop1_loop0_step_lemma_fc` (the SAMPLED USE-CACHE + row-i loop) with the range-iterator `.next` replaced by the enumerate-chunks + `.next` (via `enumerate_chunks_next_cont`) and `sample_matrix_entry_fc` + replaced by `deserialize_to_reduced_ring_element_fc`. -/ +private theorem compute_ring_element_v_loop_step_lemma_fc + {K : Std.Usize} + (public_key : Slice Std.U8) (h_pk_len : public_key.length = K.val * 384) + (r_as_ntt cache : Slice + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (r_arr : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : ChunkLoopFC.Acc) + (h_r_len : r_as_ntt.length = K.val) + (h_cache_len : cache.length = K.val) + (h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]!) + (h_r_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((r_as_ntt.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, + (acc_init.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30) + (h_cache_char : ∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post (r_as_ntt.val[c]!) (cache.val[c]!)) + (t_as_ntt_entry : ChunkLoopFC.Poly) + (acc : ChunkLoopFC.Acc) + (k : Nat) (h_le : k ≤ K.val) + (rest : Slice Std.U8) (cnt : Std.Usize) + (h_cnt : cnt.val = k) + (h_rest_len : rest.length = (K.val - k) * 384) + (h_suf : ∀ ℓ : Nat, rest.val[ℓ]! = public_key.val[k * 384 + ℓ]!) + (h_inv : (ChunkLoopFC.loop_inv (lift_t_as_ntt_from_public_key public_key K) r_arr acc_init + k (t_as_ntt_entry, acc)).holds) : + ⦃ ⌜ True ⌝ ⦄ + matrix.compute_ring_element_v_loop.body + (vectortraitsOperationsInst := portable_ops_inst) r_as_ntt cache + { iter := { cs := 384#usize, elements := rest }, count := cnt } t_as_ntt_entry acc + ⦃ ⇓ r => match r with + | .cont (iter', acc') => + ⌜ k < K.val ∧ iter'.iter.cs = 384#usize ∧ iter'.count.val = k + 1 + ∧ iter'.iter.elements.length = (K.val - (k + 1)) * 384 + ∧ (∀ ℓ : Nat, iter'.iter.elements.val[ℓ]! + = public_key.val[(k + 1) * 384 + ℓ]!) + ∧ (ChunkLoopFC.loop_inv (lift_t_as_ntt_from_public_key public_key K) r_arr acc_init + (k + 1) acc').holds ⌝ + | .done y => ⌜ (ChunkLoopFC.loop_inv (lift_t_as_ntt_from_public_key public_key K) r_arr + acc_init K.val y).holds ⌝ ⦄ := by + set trows : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K := + lift_t_as_ntt_from_public_key public_key K with htrows_def + have h_acc_len : acc.length = 256 := Std.Array.length_eq acc + have h_acc_init_len : acc_init.length = 256 := Std.Array.length_eq acc_init + -- Destructure the 2-conjunct invariant (`.2` of the carried pair reduces to `acc`). + obtain ⟨⟨mp, h_mp_agree, h_inv_acc⟩, h_inv_acc_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + dsimp only at h_inv_acc h_inv_acc_bnd + unfold matrix.compute_ring_element_v_loop.body + by_cases h_lt : k < K.val + · -- `Some (cnt, chunk)` branch. + -- (1) Enumerate.next reduces to .ok (Some (cnt, chunk), advanced). + have h_cs_le : (384#usize : Std.Usize).val ≤ rest.length := by + rw [h_rest_len] + have : 1 ≤ K.val - k := by omega + have : (384#usize : Std.Usize).val = 384 := rfl + calc (384#usize : Std.Usize).val = 1 * 384 := by simp + _ ≤ (K.val - k) * 384 := by + apply Nat.mul_le_mul_right; omega + have hcnt_lt : cnt.val < K.val := by rw [h_cnt]; exact h_lt + have h_cnt_max : cnt.val + 1 ≤ Std.Usize.max := by + have hKmax : K.val ≤ Std.Usize.max := by + have h2 : public_key.length ≤ Std.Usize.max := public_key.property + rw [h_pk_len] at h2 + have : K.val ≤ K.val * 384 := Nat.le_mul_of_pos_right _ (by omega) + omega + omega + obtain ⟨chunk, drop, cnt', h_next_eq, h_cnt'_val, h_chunk_len, h_drop_len, + h_chunk_eq, h_drop_eq⟩ := + enumerate_chunks_next_cont_drop rest 384#usize cnt h_cs_le h_cnt_max + -- (3) deserialize via A2 axiom at index cnt (cnt.val = k). + have h_chunk_pk : ∀ ℓ : Nat, ℓ < 384 → + chunk.val[ℓ]! = public_key.val[cnt.val * 384 + ℓ]! := by + intro ℓ hℓ + have := h_chunk_eq ℓ (by have : (384#usize : Std.Usize).val = 384 := rfl; omega) + rw [this, h_suf ℓ, h_cnt] + obtain ⟨te1, h_te_eq, h_te_lift, h_te_bnd⟩ := + triple_exists_ok_fc + (deserialize_to_reduced_ring_element_fc public_key K t_as_ntt_entry + cnt h_pk_len hcnt_lt chunk h_chunk_len h_chunk_pk) + have h_te_lift' : lift_poly te1 = trows.val[k]! := by + rw [htrows_def, ← h_cnt]; exact h_te_lift + have h_te_bnd' : ∀ a : Fin 16, ∀ b : Fin 16, + ((te1.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328 := + fun a b => h_te_bnd a.val a.isLt b.val b.isLt + -- (4) Slice.index_usize r_as_ntt cnt → r_as_ntt[k]!. + set t_r : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + r_as_ntt.val[k]! with ht_r_def + have h_idx_r : Aeneas.Std.Slice.index_usize r_as_ntt cnt = .ok t_r := by + rw [ht_r_def, ← h_cnt] + exact libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq r_as_ntt cnt + (by show cnt.val < r_as_ntt.length; rw [h_r_len, h_cnt]; exact h_lt) + -- (5) Slice.index_usize cache cnt → cache[k]!. + set t_cache : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + cache.val[k]! with ht_cache_def + have h_idx_cache : Aeneas.Std.Slice.index_usize cache cnt = .ok t_cache := by + rw [ht_cache_def, ← h_cnt] + exact libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq cache cnt + (by show cnt.val < cache.length; rw [h_cache_len, h_cnt]; exact h_lt) + -- (6) per-column use-cache forward dep at column k. + have h_t_r_bnd : ∀ a : Fin 16, ∀ b : Fin 16, + ((t_r.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328 := + fun a b => h_r_bnd k h_lt a b + have h_cache_at_k : accumulating_ntt_multiply_poly_cache_post t_r t_cache := by + rw [ht_r_def, ht_cache_def]; exact h_cache_char k h_lt + have h_acc_cur_bnd : ∀ n : Fin 256, (acc.val[n.val]!).val.natAbs ≤ 2^30 := by + intro n + have hb := h_inv_acc_bnd n.val n.isLt + have hp := h_acc_bnd n + have hk_le : k * 2^25 ≤ K.val * 2^25 := Nat.mul_le_mul_right _ h_le + omega + obtain ⟨acc1, h_acc1_eq, h_acc1_bnd_rel, h_acc1_post⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_use_cache_poly_fc te1 t_r t_cache acc + h_te_bnd' h_t_r_bnd h_acc_cur_bnd h_cache_at_k) + -- (7) Body equation. + have h_body : + matrix.compute_ring_element_v_loop.body + (vectortraitsOperationsInst := portable_ops_inst) r_as_ntt cache + { iter := { cs := 384#usize, elements := rest }, count := cnt } t_as_ntt_entry acc + = .ok (ControlFlow.cont + ({ iter := { cs := 384#usize, elements := drop }, count := cnt' }, te1, acc1)) := by + unfold matrix.compute_ring_element_v_loop.body + rw [show + (CoreModels.core.iter.adapters.enumerate.Enumerate.Insts.CoreIterTraitsIteratorIteratorPairUsizeClause0_Item.next + (CoreModels.core.slice.iter.ChunksExact.Insts.CoreIterTraitsIteratorIteratorSharedASlice + Std.U8) + { iter := { cs := 384#usize, elements := rest }, count := cnt }) + = enumCENext { iter := { cs := 384#usize, elements := rest }, count := cnt } + from rfl] + rw [h_next_eq] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let t_as_ntt_entry1 ← + libcrux_iot_ml_kem.serialize.deserialize_to_reduced_ring_element portable_ops_inst + chunk t_as_ntt_entry + let pre ← Aeneas.Std.Slice.index_usize r_as_ntt cnt + let pre1 ← Aeneas.Std.Slice.index_usize cache cnt + let accumulator1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_use_cache + portable_ops_inst t_as_ntt_entry1 pre acc pre1 + .ok (ControlFlow.cont + (({ iter := { cs := 384#usize, elements := drop }, count := cnt' } : EnumCE), + t_as_ntt_entry1, accumulator1))) + : Result (ControlFlow (EnumCE × (ChunkLoopFC.Poly × ChunkLoopFC.Acc)) + (ChunkLoopFC.Poly × ChunkLoopFC.Acc))) = _ + rw [h_te_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_r] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_cache] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_acc1_eq] + rfl + apply triple_of_ok_chunks h_body + -- (8) Discharge the step_post `.cont`. + show k < K.val ∧ (384#usize : Std.Usize) = 384#usize ∧ cnt'.val = k + 1 + ∧ drop.length = (K.val - (k + 1)) * 384 + ∧ (∀ ℓ : Nat, drop.val[ℓ]! = public_key.val[(k + 1) * 384 + ℓ]!) + ∧ (ChunkLoopFC.loop_inv trows r_arr acc_init (k + 1) (te1, acc1)).holds + refine ⟨h_lt, rfl, by rw [h_cnt'_val, h_cnt], ?_, ?_, ?_⟩ + · -- drop length. + have h384 : (384#usize : Std.Usize).val = 384 := rfl + rw [h_drop_len, h_rest_len, h384] + have hexp : (K.val - (k + 1)) * 384 = (K.val - k) * 384 - 384 := by + rw [Nat.sub_mul]; ring_nf; omega + rw [hexp, Nat.sub_mul] + · -- drop suffix relation at offset (k+1)*384. + intro ℓ + rw [h_drop_eq ℓ, h_suf] + have h384 : (384#usize : Std.Usize).val = 384 := rfl + rw [h384] + have hidx : k * 384 + (384 + ℓ) = (k + 1) * 384 + ℓ := by ring + rw [hidx] + · -- re-establish loop_inv at k+1. + set mp1 : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K := + mp.set cnt te1 with hmp1_def + have h_mp_len : mp.length = K.val := Std.Array.length_eq mp + have h_mp1_at : mp1.val[k]! = te1 := by + rw [hmp1_def, ← h_cnt] + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq mp cnt cnt.val te1 + ⟨rfl, by rw [h_mp_len]; rw [h_cnt]; exact h_lt⟩ + have h_mp1_ne : ∀ j : Nat, j ≠ k → mp1.val[j]! = mp.val[j]! := by + intro j hj + rw [hmp1_def] + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne mp cnt j te1 (fun h => hj (by rw [← h_cnt]; exact h.symm)) + have h_r_arr_k : r_arr.val[k]! = t_r := by + rw [ht_r_def]; exact h_r_arr k h_lt + have h_inv_pure : + (∃ mp' : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K, + (∀ c : Nat, c < k + 1 → + lift_poly (mp'.val[c]!) = trows.val[c]! + ∧ (∀ a : Fin 16, ∀ b : Fin 16, + ((mp'.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328)) + ∧ (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = (List.range (k + 1)).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp'.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)))) + ∧ (∀ n : Nat, n < 256 → + (acc1.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + (k + 1) * 2^25) := by + refine ⟨⟨mp1, ?_, ?_⟩, ?_⟩ + · intro c hc + rcases Nat.lt_succ_iff_lt_or_eq.mp hc with hc_lt | hc_eq + · have hc_ne : c ≠ k := by omega + rw [h_mp1_ne c hc_ne]; exact h_mp_agree c hc_lt + · subst hc_eq + rw [h_mp1_at] + exact ⟨h_te_lift', h_te_bnd'⟩ + · intro j hj ℓ hℓ + have h_step_acc : + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (te1.coefficients.val[j]!)) + (lift_chunk_mont (t_r.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!) := by + have := h_acc1_post + unfold accumulating_ntt_multiply_poly_post at this + exact this j hj ℓ hℓ + have h_ih := h_inv_acc j hj ℓ hℓ + rw [h_step_acc, h_ih] + rw [List.range_succ, List.foldl_append] + have h_foldl_congr : ∀ (L : List Nat) (init : hacspec_ml_kem.parameters.FieldElement), + (∀ c ∈ L, c < k) → + L.foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp1.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + init + = L.foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + init := by + intro L + induction L with + | nil => intro init _; rfl + | cons hd tl ih => + intro init hmem + have hhd : hd < k := hmem hd (List.mem_cons_self) + have htl : ∀ c ∈ tl, c < k := fun c hc => hmem c (List.mem_cons_of_mem hd hc) + have hhd_ne : hd ≠ k := by omega + simp only [List.foldl_cons] + rw [h_mp1_ne hd hhd_ne] + exact ih _ htl + rw [h_foldl_congr (List.range k) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)) + (fun c hc => by simpa using hc)] + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((List.range k).foldl _ _) _ + = (List.foldl _ ((List.range k).foldl _ _) [k]) + rw [List.foldl_cons, List.foldl_nil] + rw [h_mp1_at, h_r_arr_k] + · intro n hn + have h_acc1_bnd_n := h_acc1_bnd_rel ⟨n, hn⟩ + have h_acc1_bnd_n' : (acc1.val[n]!).val.natAbs ≤ (acc.val[n]!).val.natAbs + 2^25 := + h_acc1_bnd_n + have h_inv_n := h_inv_acc_bnd n hn + have h_arith : (k + 1) * 2^25 = k * 2^25 + 2^25 := by ring + rw [h_arith] + linarith [h_acc1_bnd_n', h_inv_n] + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: rest.length < 384, k = K, done. + have hk_ge : ¬ k < K.val := h_lt + have hk_eq : k = K.val := by omega + have h_rest_lt : rest.length < (384#usize : Std.Usize).val := by + rw [h_rest_len, hk_eq] + have h384 : (384#usize : Std.Usize).val = 384 := rfl + simp [h384] + have h_next_eq := enumerate_chunks_next_done rest 384#usize cnt h_rest_lt + have h_body : + matrix.compute_ring_element_v_loop.body + (vectortraitsOperationsInst := portable_ops_inst) r_as_ntt cache + { iter := { cs := 384#usize, elements := rest }, count := cnt } t_as_ntt_entry acc + = .ok (ControlFlow.done (t_as_ntt_entry, acc)) := by + unfold matrix.compute_ring_element_v_loop.body + rw [show + (CoreModels.core.iter.adapters.enumerate.Enumerate.Insts.CoreIterTraitsIteratorIteratorPairUsizeClause0_Item.next + (CoreModels.core.slice.iter.ChunksExact.Insts.CoreIterTraitsIteratorIteratorSharedASlice + Std.U8) + { iter := { cs := 384#usize, elements := rest }, count := cnt }) + = enumCENext { iter := { cs := 384#usize, elements := rest }, count := cnt } + from rfl] + rw [h_next_eq] + rfl + apply triple_of_ok_chunks h_body + show (ChunkLoopFC.loop_inv trows r_arr acc_init K.val (t_as_ntt_entry, acc)).holds + have h_inv_pure : + (∃ mp' : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K, + (∀ c : Nat, c < K.val → + lift_poly (mp'.val[c]!) = trows.val[c]! + ∧ (∀ a : Fin 16, ∀ b : Fin 16, + ((mp'.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328)) + ∧ (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = (List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp'.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)))) + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + K.val * 2^25) := by + refine ⟨⟨mp, ?_, ?_⟩, ?_⟩ + · intro c hc; exact h_mp_agree c (by rw [hk_eq]; exact hc) + · intro j hj ℓ hℓ + have h_eq := h_inv_acc j hj ℓ hℓ + rw [show (List.range k) = (List.range K.val) by rw [hk_eq]] at h_eq + exact h_eq + · intro n hn + have h_b := h_inv_acc_bnd n hn + rw [show k * 2^25 = K.val * 2^25 by rw [hk_eq]] at h_b + exact h_b + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 4000000 in +/-- **L7.3 loop FC.** `matrix.compute_ring_element_v_loop`: the chunks-exact + USE-CACHE column loop. Iterates over `Enumerate (ChunksExact U8)` of the + public key; each chunk `i` deserializes `t̂[i]` (A2 axiom), reads + `r_as_ntt[i]` and `cache[i]` (read-only), and runs + `accumulating_ntt_multiply_use_cache` to add `t̂[i]·r[i]` to the I32 + accumulator. + + POST: the RESOLVED all-mont/existential `loop_inv` holds at k = K — there + exists a `K`-array `mp` of deserialized polys with `lift_poly mp[c] = + (lift_t_as_ntt_from_public_key public_key K).val[c]` (axiom-pinned) such + that for all (j,ℓ) ∈ [0,16)², `mont_reduce_pure (lift_fe_int acc[16j+ℓ])` + equals the K-column all-mont sum, plus the bound. + + Applies `loop_chunks_exact_pk_spec` (numChunks = K, cs = 384) with the + step lemma as `h_step`. -/ +theorem compute_ring_element_v_loop_fc (K : Std.Usize) (hK : K.val ≤ 4) + (public_key : Slice Std.U8) (h_pk_len : public_key.length = K.val * 384) + (t_as_ntt_entry : ChunkLoopFC.Poly) + (r_as_ntt cache : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (r_arr : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (h_r_len : r_as_ntt.length = K.val) (h_cache_len : cache.length = K.val) + (h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]!) + (h_r_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((r_as_ntt.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_cache_char : ∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post (r_as_ntt.val[c]!) (cache.val[c]!)) + (accumulator : Std.Array Std.I32 256#usize) + (h_acc_zero : ∀ n : Nat, n < 256 → (accumulator.val[n]!).val = 0) + (iter0 : EnumCE) + (h_iter0 : iter0 = { iter := { cs := 384#usize, elements := public_key }, count := 0#usize }) : + ⦃ ⌜ True ⌝ ⦄ + matrix.compute_ring_element_v_loop + (vectortraitsOperationsInst := portable_ops_inst) iter0 + t_as_ntt_entry r_as_ntt cache accumulator + ⦃ ⇓ p => ⌜ (ChunkLoopFC.loop_inv (lift_t_as_ntt_from_public_key public_key K) r_arr accumulator + K.val p).holds ⌝ ⦄ := by + set trows : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K := + lift_t_as_ntt_from_public_key public_key K with htrows_def + -- accumulator budget: zero init + K·2^25 ≤ 2^30 (K ≤ 4). + have h_acc_bnd : ∀ n : Fin 256, + (accumulator.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30 := by + intro n + have hz := h_acc_zero n.val n.isLt + have : (accumulator.val[n.val]!).val.natAbs = 0 := by rw [hz]; rfl + rw [this] + have hk4 : K.val * 2^25 ≤ 4 * 2^25 := Nat.mul_le_mul_right _ hK + omega + subst h_iter0 + unfold matrix.compute_ring_element_v_loop + apply Std.Do.Triple.of_entails_right _ + (loop_chunks_exact_pk_spec + (fun (iter1, p) => + matrix.compute_ring_element_v_loop.body + (vectortraitsOperationsInst := portable_ops_inst) r_as_ntt cache iter1 p.1 p.2) + (β := ChunkLoopFC.Poly × ChunkLoopFC.Acc) + (t_as_ntt_entry, accumulator) + public_key 384#usize K.val + (fun k p => ChunkLoopFC.loop_inv trows r_arr accumulator k p) + (by decide : 0 < (384#usize : Std.Usize).val) + (by rw [h_pk_len]; rfl) + (by + -- Base case at k = 0. + show (ChunkLoopFC.loop_inv trows r_arr accumulator 0 (t_as_ntt_entry, accumulator)).holds + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨⟨Std.Array.repeat K t_as_ntt_entry, ?_, ?_⟩, ?_⟩ + · intro c hc; exact absurd hc (Nat.not_lt_zero c) + · intro j hj ℓ hℓ + show Spec.mont_reduce_pure _ = (List.range 0).foldl _ _ + simp [List.range_zero, List.foldl_nil] + · intro n _; omega) + ?_) + · -- Post entailment at K. + rw [PostCond.entails_noThrow] + intro r hh + have h_holds : (ChunkLoopFC.loop_inv trows r_arr accumulator K.val r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + show (ChunkLoopFC.loop_inv trows r_arr accumulator K.val r).holds + exact h_holds + · -- Step entailment: apply the step lemma. + intro p k rest cnt hk_le hcnt hrest_len hsuf hinv + have h_step := compute_ring_element_v_loop_step_lemma_fc + public_key h_pk_len r_as_ntt cache r_arr accumulator + h_r_len h_cache_len h_r_arr h_r_bnd h_acc_bnd h_cache_char + p.1 p.2 k hk_le rest cnt hcnt hrest_len hsuf hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + have h384 : (384#usize : Std.Usize).val = 384 := rfl + rcases r with ⟨iter', acc'⟩ | y + · have hh' : k < K.val ∧ iter'.iter.cs = 384#usize ∧ iter'.count.val = k + 1 + ∧ iter'.iter.elements.length = (K.val - (k + 1)) * 384 + ∧ (∀ ℓ : Nat, iter'.iter.elements.val[ℓ]! + = public_key.val[(k + 1) * 384 + ℓ]!) + ∧ (ChunkLoopFC.loop_inv trows r_arr accumulator (k + 1) acc').holds := by + have := hh + simp only [Std.Do.SPred.down_pure] at this + exact this + obtain ⟨h_klt, h_cs, h_cnt', h_len', h_suf', h_inv'⟩ := hh' + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + refine ⟨h_klt, h_cs, h_cnt', ?_, ?_, h_inv'⟩ + · rw [h384]; exact h_len' + · intro ℓ; rw [h384]; exact h_suf' ℓ + · have h_done : (ChunkLoopFC.loop_inv trows r_arr accumulator K.val y).holds := by + have := hh + simp only [Std.Do.SPred.down_pure] at this + exact this + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact h_done + +end L7_3_irreducible + +/-! ## §L7.3 — acc-bridge (REUSES L7.4 `compute_message_acc_bridge`). -/ + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeAsPlusE libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Sampling libcrux_iot_ml_kem.Serialize libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt + +/-- Local single-256-lane field-element poly abbrev (keeps `256#usize` out of + statement signatures — SKILL §7.7). -/ +private abbrev FEPoly := Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize + +/-- Local re-derivation of ComputeVectorU's private `lift_vec_mp_eq`: the + `lift_vec` of the existential witness `mp` collapses to the canonical rows + `trows`, given per-column agreement. -/ +private theorem lift_vec_mp_eq {K : Std.Usize} + (mp : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (trows : Std.Array FEPoly K) + (h_agree : ∀ c : Nat, c < K.val → lift_poly (mp.val[c]!) = trows.val[c]!) : + lift_vec mp = trows := by + apply Subtype.ext + show mp.val.map lift_poly = trows.val + have h_mp_len : mp.val.length = K.val := Std.Array.length_eq mp + have h_tr_len : trows.val.length = K.val := Std.Array.length_eq trows + apply List.ext_getElem + · rw [List.length_map, h_mp_len, h_tr_len] + · intro i hi1 _hi2 + have hi : i < K.val := by + have : i < (mp.val.map lift_poly).length := hi1 + rw [List.length_map, h_mp_len] at this; exact this + rw [List.getElem_map] + have h_lhs : lift_poly (mp.val[i]) = lift_poly (mp.val[i]!) := by + rw [getElem!_pos mp.val i (by rw [h_mp_len]; exact hi)] + have h_rhs : trows.val[i] = trows.val[i]! := by + rw [getElem!_pos trows.val i (by rw [h_tr_len]; exact hi)] + rw [h_lhs, h_rhs]; exact h_agree i hi + +/-- Local re-derivation of ComputeVectorU's private `lift_vec_r_arr_eq`. -/ +private theorem lift_vec_r_arr_eq {K : Std.Usize} + (r_as_ntt : Slice + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (r_arr : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]!) : + lift_vec r_arr = lift_vec_slice r_as_ntt K := by + apply Subtype.ext + show r_arr.val.map lift_poly = (List.range K.val).map (fun i => lift_poly r_as_ntt.val[i]!) + have h_r_len : r_arr.val.length = K.val := Std.Array.length_eq r_arr + apply List.ext_getElem + · rw [List.length_map, h_r_len, List.length_map, List.length_range] + · intro i hi1 _hi2 + have hi : i < K.val := by + have : i < (r_arr.val.map lift_poly).length := hi1 + rw [List.length_map, h_r_len] at this; exact this + rw [List.getElem_map, List.getElem_map, List.getElem_range] + have h_lhs : lift_poly (r_arr.val[i]) = lift_poly (r_arr.val[i]!) := by + rw [getElem!_pos r_arr.val i (by rw [h_r_len]; exact hi)] + rw [h_lhs, h_r_arr i hi] + +set_option maxHeartbeats 1000000 in +/-- **L7.3 acc-bridge.** Reconciles the hacspec `multiply_vectors` of the + axiom-pinned deserialized t-as-ntt rows against the loop accumulator scaled + by `R = 2285`. A thin wrapper REUSING L7.4 `compute_message_acc_bridge`: the + existential witness `mp` of `loop_inv` supplies the t-as-ntt array, `r_arr` + the r-as-ntt array, and `loop_inv`'s two conjuncts are exactly + `S1LoopFC.loop_inv mp r_arr`'s two conjuncts. The two vector args are + rewritten via `lift_vec_mp_eq` / `lift_vec_r_arr_eq`. Mirrors + `compute_vector_u_rowi_acc_bridge`. -/ +theorem compute_ring_element_v_acc_bridge {K : Std.Usize} (hK : K.val ≤ 4) + (public_key : Slice Std.U8) + (r_as_ntt : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (r_arr : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init acc2 : Std.Array Std.I32 256#usize) + (h_acc_init_zero : ∀ n : Nat, n < 256 → (acc_init.val[n]!).val = 0) + (h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]!) + (h_r_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((r_as_ntt.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (t_ent : ChunkLoopFC.Poly) + (h_char : (ChunkLoopFC.loop_inv (lift_t_as_ntt_from_public_key public_key K) r_arr acc_init + K.val (t_ent, acc2)).holds) : + hacspec_ml_kem.matrix.multiply_vectors + (lift_t_as_ntt_from_public_key public_key K) (lift_vec_slice r_as_ntt K) + = .ok (scaleZ 2285 (Impl.mont_strip_pure + (Spec.poly_reducing_from_i32_array_pure (Aeneas.Std.Array.to_slice acc2)))) := by + set trows : Std.Array FEPoly K := lift_t_as_ntt_from_public_key public_key K with htrows_def + -- Destructure `loop_inv`'s 2 conjuncts; the first is the ∃-witness pack. + obtain ⟨⟨mp, h_mp_agree, h_inv_acc⟩, h_inv_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_char + dsimp only at h_inv_acc h_inv_bnd + -- `h_inv_acc` (mont foldl) and `h_inv_bnd` (bound) are exactly + -- `S1LoopFC.loop_inv mp r_arr acc_init K acc2`'s two conjuncts. + have h_char4 : (S1LoopFC.loop_inv mp r_arr acc_init K acc2).holds := by + show (pure + ((∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc2.val[16 * j + ℓ]!).val) + = (List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val))) + ∧ (∀ n : Nat, n < 256 → + (acc2.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + K.val * 2^25)) + : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using + (⟨h_inv_acc, h_inv_bnd⟩ : _ ∧ _) + -- t-side bounds from the ∃-witness `mp`'s per-lane bound (conjunct 1.2). + have h_secret_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((mp.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328 := by + intro k i j; exact (h_mp_agree k.val k.isLt).2 i j + -- r-side bounds from `h_r_bnd` rewritten through `h_r_arr`. + have h_u_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((r_arr.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328 := by + intro k i j; rw [h_r_arr k.val k.isLt]; exact h_r_bnd k.val k.isLt i j + -- Apply the L7.4 bridge on `(mp, r_arr)`. + have h_bridge := + compute_message_acc_bridge mp r_arr acc_init acc2 h_acc_init_zero h_secret_bnd h_u_bnd h_char4 + have h_mp_vec : lift_vec mp = trows := + lift_vec_mp_eq mp trows (fun c hc => (h_mp_agree c hc).1) + have h_r_vec : lift_vec r_arr = lift_vec_slice r_as_ntt K := + lift_vec_r_arr_eq r_as_ntt r_arr h_r_arr + rw [h_mp_vec, h_r_vec] at h_bridge + rw [htrows_def] + exact h_bridge + +end libcrux_iot_ml_kem.Matrix.ComputeRingElementV.Impl \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeVectorU/FC.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeVectorU/FC.lean new file mode 100644 index 00000000..6df883c8 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeVectorU/FC.lean @@ -0,0 +1,1288 @@ +/- + # `Matrix/ComputeVectorU/FC.lean` — L7.2 FC theorem glue. + + Houses the L7.2 FC theorem `compute_vector_u_fc` — the top-level + assembly gluing the impl walk (`compute_vector_u_loop0` row-0 column + loop + `compute_vector_u_loop1` outer rows loop, both proven in + `Impl/ComputeVectorU.lean`) to the hacspec `matrix.compute_vector_u` + (transpose → multiply_matrix_by_column → createi(ntt_inverse) → + add_vectors). + + Two added subtleties vs L7.4: + * a vector output (K rows, two accumulation loops: row 0 fills the cache; + rows 1..K consume it); + * a hacspec part-A reduction `compute_vector_u_hacspec_eq` relating the + hacspec do-block to the per-row `AllRowsFillFC.row_spec`. Because the + Hacspec `extractCol`/`mcol_*` helpers are `private`, the + matrix-column reduction is re-derived locally here. + + PRE strengthening: `hK ≤ 4`, lengths, per-lane bounds. +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeAsPlusE +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeMessage.Impl +import LibcruxIotMlKem.Matrix.ComputeVectorU.Impl +import LibcruxIotMlKem.Matrix.ComputeMessage.Hacspec +import LibcruxIotMlKem.Matrix.ComputeVectorU.Hacspec + +namespace libcrux_iot_ml_kem.Matrix.ComputeVectorU.FC +open libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeMessage.Bridges libcrux_iot_ml_kem.Matrix.ComputeMessage.Hacspec libcrux_iot_ml_kem.Matrix.ComputeMessage.Impl libcrux_iot_ml_kem.Matrix.ComputeVectorU.Hacspec libcrux_iot_ml_kem.Matrix.ComputeVectorU.Impl +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeAsPlusE libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Sampling libcrux_iot_ml_kem.Serialize libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt + +set_option mvcgen.warning false +set_option linter.unusedVariables false + +/-- Local copy of the `private triple_exists_ok_fc` helper. -/ +private theorem triple_exists_ok_fc {α : Type} {x : Result α} {P : α → Prop} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) : + ∃ v, x = .ok v ∧ P v := by + match hx : x with + | .ok v => exact ⟨v, rfl, (by subst hx; simpa [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h)⟩ + | .fail _ => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +/-- Local copy of the `private triple_of_ok_fc` helper. -/ +private theorem triple_of_ok_fc {α : Type} {x : Result α} {v : α} + {P : α → Prop} (hx : x = .ok v) (hp : P v) : + ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄ := by + subst hx; simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +/-! ## PART A — hacspec reduction. + + The Hacspec `extractCol`/`mcol_*` helpers + the public + `multiply_matrix_by_column_at_eq_multiply_vectors` are usable only modulo + the `private` `extractCol` they refer to. Since we cannot name/unfold the + private `extractCol`, we RE-DERIVE the matrix-column reduction locally with + a public-in-this-module copy `extractColL`, mirroring Hacspec + `multiply_matrix_by_column_at_eq_mcol` / `multiply_vectors_eq_mcol`. -/ + +section PartA + +open hacspec_ml_kem.parameters (FieldElement) +open Result ControlFlow + +/-- Polynomial as a 256-lane field-element array. -/ +private abbrev Poly256L := Std.Array FieldElement 256#usize + +/-- Local copy of Hacspec `extractCol` (private there). -/ +private noncomputable def extractColL {K : Std.Usize} + (m : Std.Array (Std.Array Poly256L K) K) + (i : Std.Usize) : Std.Array Poly256L K := + Std.Array.make K ((List.range K.val).map (fun j => (m.val[j]!).val[i.val]!)) + (by simp [List.length_map, List.length_range]) + +private theorem extractColL_lane {K : Std.Usize} + (m : Std.Array (Std.Array Poly256L K) K) + (i : Std.Usize) (j : Nat) (hj : j < K.val) : + (extractColL m i).val[j]! = (m.val[j]!).val[i.val]! := by + unfold extractColL + show ((List.range K.val).map (fun j => (m.val[j]!).val[i.val]!))[j]! = _ + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj])] + simp [List.getElem_map, List.getElem_range] + +/-- Local copy of Hacspec `mcol_lane_at_step`. -/ +private noncomputable def mcolLane {K : Std.Usize} + (col vec : Std.Array Poly256L K) (k : Nat) (ℓ : Nat) : FieldElement := + (List.range k).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.multiply_ntts_pure (col.val[c]!) (vec.val[c]!)).val[ℓ]!)) + ({ val := 0#u16 } : FieldElement) + +private noncomputable def mcolResult {K : Std.Usize} + (col vec : Std.Array Poly256L K) (k : Nat) : Poly256L := + ⟨(List.range 256).map (fun ℓ => mcolLane col vec k ℓ), + by simp [List.length_map, List.length_range]⟩ + +private theorem mcolResult_val_lane {K : Std.Usize} + (col vec : Std.Array Poly256L K) (k : Nat) (ℓ : Nat) (hℓ : ℓ < 256) : + (mcolResult col vec k).val[ℓ]! = mcolLane col vec k ℓ := by + unfold mcolResult + show ((List.range 256).map (fun ℓ' => mcolLane col vec k ℓ'))[ℓ]! = _ + rw [getElem!_pos _ ℓ (by simp [List.length_map, List.length_range, hℓ])] + rw [List.getElem_map, List.getElem_range] + +private theorem mcolLane_zero {K : Std.Usize} + (col vec : Std.Array Poly256L K) (ℓ : Nat) : + mcolLane col vec 0 ℓ = ({ val := 0#u16 } : FieldElement) := by + unfold mcolLane; rw [List.range_zero, List.foldl_nil] + +private theorem mcolLane_succ {K : Std.Usize} + (col vec : Std.Array Poly256L K) (k : Nat) (ℓ : Nat) : + mcolLane col vec (k + 1) ℓ + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (mcolLane col vec k ℓ) + ((Spec.multiply_ntts_pure (col.val[k]!) (vec.val[k]!)).val[ℓ]!) := by + unfold mcolLane + rw [List.range_succ, List.foldl_append, List.foldl_cons, List.foldl_nil] + +private theorem mcol_mult_eq (a1 a2 : Poly256L) : + hacspec_ml_kem.ntt.multiply_ntts a1 a2 = .ok (Spec.multiply_ntts_pure a1 a2) := by + unfold Spec.multiply_ntts_pure + rw [HelpersFC.multiply_ntts_eq_pure_array] + +private theorem mcol_step_add_eq {K : Std.Usize} + (col vec : Std.Array Poly256L K) (k : Nat) : + hacspec_ml_kem.matrix.add_polynomials (mcolResult col vec k) + (Spec.multiply_ntts_pure (col.val[k]!) (vec.val[k]!)) + = .ok (mcolResult col vec (k + 1)) := by + rw [Stage4MatrixAddFC.matrix_add_polynomials_eq_ok] + apply congrArg Result.ok + apply Subtype.ext + show (List.range 256).map (fun n => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (mcolResult col vec k).val[n]! + (Spec.multiply_ntts_pure (col.val[k]!) (vec.val[k]!)).val[n]!) + = (mcolResult col vec (k + 1)).val + unfold mcolResult + apply List.map_congr_left + intro n hn_mem + have hn_lt : n < 256 := List.mem_range.mp hn_mem + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (mcolResult col vec k).val[n]! + (Spec.multiply_ntts_pure (col.val[k]!) (vec.val[k]!)).val[n]! + = mcolLane col vec (k + 1) n + rw [mcolResult_val_lane _ _ _ _ hn_lt, mcolLane_succ] + +set_option maxHeartbeats 16000000 in +set_option maxRecDepth 1000 in +/-- Local copy of Hacspec `multiply_vectors_eq_mcol`. -/ +private theorem multiply_vectors_eq_mcolL {K : Std.Usize} + (col vec : Std.Array Poly256L K) : + hacspec_ml_kem.matrix.multiply_vectors col vec + = .ok (mcolResult col vec K.val) := by + unfold hacspec_ml_kem.matrix.multiply_vectors + unfold hacspec_ml_kem.parameters.FieldElement.new + simp only [bind_tc_ok] + have h_triple : ⦃ ⌜ True ⌝ ⦄ + hacspec_ml_kem.matrix.multiply_vectors_loop + ({ start := 0#usize, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + col vec + (Std.Array.repeat (256#usize : Std.Usize) ({ val := 0#u16 } : FieldElement)) + ⦃ ⇓ r => ⌜ r = mcolResult col vec K.val ⌝ ⦄ := by + unfold hacspec_ml_kem.matrix.multiply_vectors_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun p : CoreModels.core.ops.range.Range Std.Usize × Poly256L => + hacspec_ml_kem.matrix.multiply_vectors_loop.body col vec p.1 p.2) + (β := Poly256L) + (Std.Array.repeat (256#usize : Std.Usize) ({ val := 0#u16 } : FieldElement)) + 0#usize K + (fun k result => pure (result = mcolResult col vec k.val)) + (Nat.zero_le _) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + apply Subtype.ext + rw [Std.Array.repeat_val] + unfold mcolResult + show List.replicate 256 _ = (List.range 256).map _ + apply List.ext_getElem + · rw [List.length_replicate, List.length_map, List.length_range] + intro n h_n_lhs _ + have h_n_lt : n < 256 := by + rw [List.length_replicate] at h_n_lhs; exact h_n_lhs + rw [List.getElem_replicate, List.getElem_map, List.getElem_range] + show _ = mcolLane col vec 0 n + rw [mcolLane_zero]) + ?_) + · rw [PostCond.entails_noThrow] + intro r hh + have h_eq : (pure (r = mcolResult col vec K.val) : Result Prop).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_eq + · intro acc k _h_ge h_le hinv + have h_acc_eq : acc = mcolResult col vec k.val := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using hinv + subst h_acc_eq + unfold hacspec_ml_kem.matrix.multiply_vectors_loop.body + by_cases h_lt : k.val < K.val + · have h_iter_step : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ ∃ s : Std.Usize, s.val = k.val + 1 ∧ + r = (some k, ({ start := s, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_step + obtain ⟨s_iter, hs_iter_val, hv_iter_pair⟩ := hv_iter_post + have hlen_col : col.length = K.val := Std.Array.length_eq col + have hlen_vec : vec.length = K.val := Std.Array.length_eq vec + have h_idx_a1 : Aeneas.Std.Array.index_usize col k = .ok (col.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq col k + (by rw [hlen_col]; exact h_lt) + have h_idx_a2 : Aeneas.Std.Array.index_usize vec k = .ok (vec.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq vec k + (by rw [hlen_vec]; exact h_lt) + have h_body : + (fun p : CoreModels.core.ops.range.Range Std.Usize × Poly256L => + hacspec_ml_kem.matrix.multiply_vectors_loop.body col vec p.1 p.2) + ({ start := k, «end» := K }, mcolResult col vec k.val) + = .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + mcolResult col vec (k.val + 1))) := by + show hacspec_ml_kem.matrix.multiply_vectors_loop.body col vec + { start := k, «end» := K } (mcolResult col vec k.val) = _ + unfold hacspec_ml_kem.matrix.multiply_vectors_loop.body + conv_lhs => + rw [show + (CoreModels.core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_pair] at hv_iter_eq + rw [hv_iter_eq] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let a ← Aeneas.Std.Array.index_usize col k + let a1' ← Aeneas.Std.Array.index_usize vec k + let product ← hacspec_ml_kem.ntt.multiply_ntts a a1' + let result1 ← hacspec_ml_kem.matrix.add_polynomials + (mcolResult col vec k.val) product + Aeneas.Std.Result.ok (ControlFlow.cont + (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), result1))) + : Result _) = _ + rw [h_idx_a1] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_a2] + simp only [Aeneas.Std.bind_tc_ok] + rw [mcol_mult_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [mcol_step_add_eq] + simp only [Aeneas.Std.bind_tc_ok] + apply triple_of_ok_fc h_body + refine ⟨h_lt, rfl, hs_iter_val, ?_⟩ + show (pure (mcolResult col vec (k.val + 1) + = mcolResult col vec s_iter.val) : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + rw [hs_iter_val] + rfl + · have hk_ge : k.val ≥ K.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = K.val := by omega + have h_iter_none : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ r = ((none : Option Std.Usize), + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun hlt => absurd hlt (Nat.not_lt.mpr hk_ge)) + (fun _ => by dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_none + have h_body : + (fun p : CoreModels.core.ops.range.Range Std.Usize × Poly256L => + hacspec_ml_kem.matrix.multiply_vectors_loop.body col vec p.1 p.2) + ({ start := k, «end» := K }, mcolResult col vec k.val) + = .ok (ControlFlow.done (mcolResult col vec k.val)) := by + show hacspec_ml_kem.matrix.multiply_vectors_loop.body col vec + { start := k, «end» := K } (mcolResult col vec k.val) = _ + unfold hacspec_ml_kem.matrix.multiply_vectors_loop.body + conv_lhs => + rw [show + (CoreModels.core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_post] at hv_iter_eq + rw [hv_iter_eq] + rfl + apply triple_of_ok_fc h_body + show (pure (mcolResult col vec k.val + = mcolResult col vec K.val) : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + rw [hk_eq] + rfl + obtain ⟨v, hv_eq, hv_post⟩ := triple_exists_ok_fc h_triple + rw [hv_eq, hv_post] + +set_option maxHeartbeats 16000000 in +set_option maxRecDepth 1000 in +/-- Local copy of Hacspec `multiply_matrix_by_column_at_eq_mcol`. -/ +private theorem mmbc_at_eq_mcolL {K : Std.Usize} + (m : Std.Array (Std.Array Poly256L K) K) + (vec : Std.Array Poly256L K) (i : Std.Usize) + (hi : i.val < K.val) : + hacspec_ml_kem.matrix.multiply_matrix_by_column_at m vec i + = .ok (mcolResult (extractColL m i) vec K.val) := by + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at + unfold hacspec_ml_kem.parameters.FieldElement.new + simp only [bind_tc_ok] + have h_triple : ⦃ ⌜ True ⌝ ⦄ + hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop + ({ start := 0#usize, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + m vec i + (Std.Array.repeat (256#usize : Std.Usize) ({ val := 0#u16 } : FieldElement)) + ⦃ ⇓ r => ⌜ r = mcolResult (extractColL m i) vec K.val ⌝ ⦄ := by + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun p : CoreModels.core.ops.range.Range Std.Usize × Poly256L => + hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body m vec i p.1 p.2) + (β := Poly256L) + (Std.Array.repeat (256#usize : Std.Usize) ({ val := 0#u16 } : FieldElement)) + 0#usize K + (fun k result => pure (result = mcolResult (extractColL m i) vec k.val)) + (Nat.zero_le _) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + apply Subtype.ext + rw [Std.Array.repeat_val] + unfold mcolResult + show List.replicate 256 _ = (List.range 256).map _ + apply List.ext_getElem + · rw [List.length_replicate, List.length_map, List.length_range] + intro n h_n_lhs _ + have h_n_lt : n < 256 := by + rw [List.length_replicate] at h_n_lhs; exact h_n_lhs + rw [List.getElem_replicate, List.getElem_map, List.getElem_range] + show _ = mcolLane (extractColL m i) vec 0 n + rw [mcolLane_zero]) + ?_) + · rw [PostCond.entails_noThrow] + intro r hh + have h_eq : (pure (r = mcolResult (extractColL m i) vec K.val) + : Result Prop).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_eq + · intro acc k _h_ge h_le hinv + have h_acc_eq : acc = mcolResult (extractColL m i) vec k.val := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using hinv + subst h_acc_eq + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body + by_cases h_lt : k.val < K.val + · have h_iter_step : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ ∃ s : Std.Usize, s.val = k.val + 1 ∧ + r = (some k, ({ start := s, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_step + obtain ⟨s_iter, hs_iter_val, hv_iter_pair⟩ := hv_iter_post + have hlen_m : m.length = K.val := Std.Array.length_eq m + have hlen_vec : vec.length = K.val := Std.Array.length_eq vec + have hlen_mk : (m.val[k.val]!).length = K.val := Std.Array.length_eq (m.val[k.val]!) + have h_idx_mk : Aeneas.Std.Array.index_usize m k = .ok (m.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq m k + (by rw [hlen_m]; exact h_lt) + have h_col_eq : (extractColL m i).val[k.val]! = (m.val[k.val]!).val[i.val]! := + extractColL_lane m i k.val h_lt + have h_idx_a1 : + Aeneas.Std.Array.index_usize (m.val[k.val]!) i = .ok ((extractColL m i).val[k.val]!) := by + rw [h_col_eq] + exact libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq (m.val[k.val]!) i + (by rw [hlen_mk]; exact hi) + have h_idx_a2 : Aeneas.Std.Array.index_usize vec k = .ok (vec.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq vec k + (by rw [hlen_vec]; exact h_lt) + have h_body : + (fun p : CoreModels.core.ops.range.Range Std.Usize × Poly256L => + hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body m vec i p.1 p.2) + ({ start := k, «end» := K }, mcolResult (extractColL m i) vec k.val) + = .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + mcolResult (extractColL m i) vec (k.val + 1))) := by + show hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body m vec i + { start := k, «end» := K } + (mcolResult (extractColL m i) vec k.val) = _ + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body + conv_lhs => + rw [show + (CoreModels.core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_pair] at hv_iter_eq + rw [hv_iter_eq] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let a ← Aeneas.Std.Array.index_usize m k + let a1 ← Aeneas.Std.Array.index_usize a i + let a2 ← Aeneas.Std.Array.index_usize vec k + let product ← hacspec_ml_kem.ntt.multiply_ntts a1 a2 + let result1 ← hacspec_ml_kem.matrix.add_polynomials + (mcolResult (extractColL m i) vec k.val) product + Aeneas.Std.Result.ok (ControlFlow.cont + (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), result1))) + : Result _) = _ + rw [h_idx_mk] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_a1] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_a2] + simp only [Aeneas.Std.bind_tc_ok] + rw [mcol_mult_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [mcol_step_add_eq] + simp only [Aeneas.Std.bind_tc_ok] + apply triple_of_ok_fc h_body + refine ⟨h_lt, rfl, hs_iter_val, ?_⟩ + show (pure (mcolResult (extractColL m i) vec (k.val + 1) + = mcolResult (extractColL m i) vec s_iter.val) + : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + rw [hs_iter_val] + rfl + · have hk_ge : k.val ≥ K.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = K.val := by omega + have h_iter_none : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ r = ((none : Option Std.Usize), + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun hlt => absurd hlt (Nat.not_lt.mpr hk_ge)) + (fun _ => by dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_none + have h_body : + (fun p : CoreModels.core.ops.range.Range Std.Usize × Poly256L => + hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body m vec i p.1 p.2) + ({ start := k, «end» := K }, mcolResult (extractColL m i) vec k.val) + = .ok (ControlFlow.done (mcolResult (extractColL m i) vec k.val)) := by + show hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body m vec i + { start := k, «end» := K } + (mcolResult (extractColL m i) vec k.val) = _ + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body + conv_lhs => + rw [show + (CoreModels.core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_post] at hv_iter_eq + rw [hv_iter_eq] + rfl + apply triple_of_ok_fc h_body + show (pure (mcolResult (extractColL m i) vec k.val + = mcolResult (extractColL m i) vec K.val) + : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + rw [hk_eq] + rfl + obtain ⟨v, hv_eq, hv_post⟩ := triple_exists_ok_fc h_triple + rw [hv_eq, hv_post] + +/-- `multiply_matrix_by_column_at m vec i = multiply_vectors (extractColL m i) vec`. -/ +private theorem mmbc_at_eq_multiply_vectorsL {K : Std.Usize} + (m : Std.Array (Std.Array Poly256L K) K) + (vec : Std.Array Poly256L K) + (i : Std.Usize) (hi : i.val < K.val) : + hacspec_ml_kem.matrix.multiply_matrix_by_column_at m vec i + = hacspec_ml_kem.matrix.multiply_vectors (extractColL m i) vec := by + rw [mmbc_at_eq_mcolL m vec i hi, multiply_vectors_eq_mcolL] + +/-! ### transpose reduction. -/ + +/-- `BitVec.ofNat.toNat = k` for `k < K.val`. -/ +private theorem ofNat_toNat_eq {K : Std.Usize} (k : Nat) (hk : k < K.val) : + (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := by + show (BitVec.ofNat _ k).toNat = k + apply Nat.mod_eq_of_lt + have hK_lt : K.val < 2^System.Platform.numBits := by + have h := K.hBounds + simp at h + omega + exact Nat.lt_of_lt_of_le hk (Nat.le_of_lt hK_lt) + +/-- The inner transpose closure (`createi RANK closure (m,j)`) builds the + column-`j` vector: lane `i'` is `m[i'][j]`. -/ +private theorem transpose_inner_eq {K : Std.Usize} + (m : Std.Array (Std.Array Poly256L K) K) (j : Std.Usize) (hj : j.val < K.val) : + hacspec_ml_kem.matrix.transpose.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayArrayFieldElement256RANK.call + (RANK := K) (m) j + = .ok ⟨(List.range K.val).map (fun i' => (m.val[i']!).val[j.val]!), + by simp [List.length_map, List.length_range]⟩ := by + unfold hacspec_ml_kem.matrix.transpose.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayArrayFieldElement256RANK.call + show hacspec_ml_kem.parameters.createi K + (hacspec_ml_kem.matrix.transpose.closure.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256 K) + (m, j) = _ + unfold hacspec_ml_kem.parameters.createi + have hpure : ∀ i' : Nat, i' < K.val → + (hacspec_ml_kem.matrix.transpose.closure.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256 K).FnMutInst.call_mut + (m, j) ⟨BitVec.ofNat _ i'⟩ + = .ok ((m.val[i']!).val[j.val]!, (m, j)) := by + intro i' hi' + show hacspec_ml_kem.matrix.transpose.closure.closure.Insts.CoreOpsFunctionFnMutTupleUsizeArrayFieldElement256.call_mut + (m, j) ⟨BitVec.ofNat _ i'⟩ + = .ok ((m.val[i']!).val[j.val]!, (m, j)) + unfold hacspec_ml_kem.matrix.transpose.closure.closure.Insts.CoreOpsFunctionFnMutTupleUsizeArrayFieldElement256.call_mut + unfold hacspec_ml_kem.matrix.transpose.closure.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256.call + have hi'_val : (⟨BitVec.ofNat _ i'⟩ : Std.Usize).val = i' := ofNat_toNat_eq i' hi' + have hlen_m : m.length = K.val := Std.Array.length_eq m + have hlen_mi : (m.val[i']!).length = K.val := Std.Array.length_eq (m.val[i']!) + have h_idx_m : Std.Array.index_usize m (⟨BitVec.ofNat _ i'⟩ : Std.Usize) + = .ok (m.val[i']!) := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq m (⟨BitVec.ofNat _ i'⟩ : Std.Usize) + (by rw [hi'_val, hlen_m]; exact hi') + rw [hi'_val] at this; exact this + have h_idx_mi : Std.Array.index_usize (m.val[i']!) j + = .ok ((m.val[i']!).val[j.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq (m.val[i']!) j + (by rw [hlen_mi]; exact hj) + change (do + let a ← (do + let a1 ← Std.Array.index_usize m (⟨BitVec.ofNat _ i'⟩ : Std.Usize) + Std.Array.index_usize a1 j) + .ok (a, m, j)) = .ok ((m.val[i']!).val[j.val]!, (m, j)) + rw [h_idx_m]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_mi]; simp only [Aeneas.Std.bind_tc_ok] + have h := libcrux_iot_ml_kem.Util.CreateI.createi_pure_eq + (T := Poly256L) + (F := hacspec_ml_kem.matrix.transpose.closure.closure K) + (N := K) + (inst := hacspec_ml_kem.matrix.transpose.closure.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256 K) + (c := (m, j)) + (f := fun i' => (m.val[i']!).val[j.val]!) + hpure + exact h + +/-- `(transpose m).val[j]! = column j of m` (for `j < K`). -/ +private theorem transpose_row_eq {K : Std.Usize} + (m : Std.Array (Std.Array Poly256L K) K) : + hacspec_ml_kem.matrix.transpose m + = .ok ⟨(List.range K.val).map (fun j => + (⟨(List.range K.val).map (fun i' => (m.val[i']!).val[j]!), + by simp [List.length_map, List.length_range]⟩ : Std.Array Poly256L K)), + by simp [List.length_map, List.length_range]⟩ := by + unfold hacspec_ml_kem.matrix.transpose + unfold hacspec_ml_kem.parameters.createi + set f : Nat → Std.Array Poly256L K := + fun j => ⟨(List.range K.val).map (fun i' => (m.val[i']!).val[j]!), + by simp [List.length_map, List.length_range]⟩ with hf_def + have hpure : ∀ j : Nat, j < K.val → + (hacspec_ml_kem.matrix.transpose.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayArrayFieldElement256RANK K).FnMutInst.call_mut + (m) ⟨BitVec.ofNat _ j⟩ + = .ok (f j, m) := by + intro j hj + show hacspec_ml_kem.matrix.transpose.closure.Insts.CoreOpsFunctionFnMutTupleUsizeArrayArrayFieldElement256RANK.call_mut + (m) ⟨BitVec.ofNat _ j⟩ = .ok (f j, m) + unfold hacspec_ml_kem.matrix.transpose.closure.Insts.CoreOpsFunctionFnMutTupleUsizeArrayArrayFieldElement256RANK.call_mut + have hj_val : (⟨BitVec.ofNat _ j⟩ : Std.Usize).val = j := ofNat_toNat_eq j hj + have h_inner := transpose_inner_eq m (⟨BitVec.ofNat _ j⟩ : Std.Usize) (by rw [hj_val]; exact hj) + show (do let a ← hacspec_ml_kem.matrix.transpose.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayArrayFieldElement256RANK.call + (RANK := K) m (⟨BitVec.ofNat _ j⟩ : Std.Usize) + .ok (a, m)) = .ok (f j, m) + rw [h_inner]; simp only [Aeneas.Std.bind_tc_ok] + rw [hf_def, hj_val] + have h := libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq + (T := Std.Array Poly256L K) + (F := hacspec_ml_kem.matrix.transpose.closure K) + (N := K) + (inst := (hacspec_ml_kem.matrix.transpose.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayArrayFieldElement256RANK K).FnMutInst) + (c := m) + (f := f) + hpure + show core.array.from_fn K _ _ = _ + rw [h] + +/-- `extractColL (transpose m) i = m.val[i]!` for `i < K`. -/ +private theorem extractColL_transpose_eq {K : Std.Usize} + (m : Std.Array (Std.Array Poly256L K) K) (i : Std.Usize) (hi : i.val < K.val) : + ∃ T, hacspec_ml_kem.matrix.transpose m = .ok T ∧ extractColL T i = m.val[i.val]! := by + refine ⟨_, transpose_row_eq m, ?_⟩ + set T : Std.Array (Std.Array Poly256L K) K := + ⟨(List.range K.val).map (fun j => + (⟨(List.range K.val).map (fun i' => (m.val[i']!).val[j]!), + by simp [List.length_map, List.length_range]⟩ : Std.Array Poly256L K)), + by simp [List.length_map, List.length_range]⟩ with hT_def + -- T.val[j]!.val[i]! = m.val[i]!.val[j]! + have hT_at : ∀ j : Nat, j < K.val → + (T.val[j]!).val[i.val]! = (m.val[i.val]!).val[j]! := by + intro j hj + have h1 : T.val[j]! = (⟨(List.range K.val).map (fun i' => (m.val[i']!).val[j]!), + by simp [List.length_map, List.length_range]⟩ : Std.Array Poly256L K) := by + rw [hT_def] + show ((List.range K.val).map _)[j]! = _ + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj])] + rw [List.getElem_map, List.getElem_range] + rw [h1] + show ((List.range K.val).map (fun i' => (m.val[i']!).val[j]!))[i.val]! = _ + rw [getElem!_pos _ i.val (by simp [List.length_map, List.length_range, hi])] + rw [List.getElem_map, List.getElem_range] + -- extractColL T i = m[i] by Subtype.ext + List.ext_getElem. + apply Subtype.ext + show ((List.range K.val).map (fun j => (T.val[j]!).val[i.val]!)) = (m.val[i.val]!).val + have hmi_len : (m.val[i.val]!).val.length = K.val := by + have := Std.Array.length_eq (m.val[i.val]!) + exact this + apply List.ext_getElem + · rw [List.length_map, List.length_range, hmi_len] + · intro j hj1 _ + have hj : j < K.val := by + have : j < ((List.range K.val).map _).length := hj1 + simpa [List.length_map, List.length_range] using this + rw [List.getElem_map, List.getElem_range] + rw [hT_at j hj] + rw [getElem!_pos (m.val[i.val]!).val j (by rw [hmi_len]; exact hj)] + +/-! ### createi-stage reduction for the hacspec `compute_vector_u`. -/ + +/-- `Result.ok`-extraction (default on non-ok); used to give `createi` a total + index function from per-index `.ok`-ness. -/ +private def resGet {α : Type} [Inhabited α] : Result α → α + | .ok v => v + | _ => default + +private theorem resGet_ok {α : Type} [Inhabited α] {x : Result α} {v : α} + (h : x = .ok v) : x = .ok (resGet x) := by + rw [h]; rfl + +/-- A `createi` whose per-index closure op is `g i` (when `g i = .ok …`) yields + `.ok ⟨map (resGet ∘ g)⟩`. Specialized to closures of the shape used by the + hacspec `compute_vector_u` stages (mmbc / ntt_inverse / add_vectors), where + `call_mut c ⟨k⟩ = (do let a ← g k; ok (a, c))`. -/ +private theorem createi_stage_eq {K : Std.Usize} {T F : Type} [Inhabited T] + (inst : CoreModels.core.ops.function.Fn F Std.Usize T) (c : F) + (g : Nat → Result T) + (hcall : ∀ k : Nat, k < K.val → + inst.FnMutInst.call_mut c ⟨BitVec.ofNat _ k⟩ = (do let a ← g k; .ok (a, c))) + (hok : ∀ k : Nat, k < K.val → ∃ v, g k = .ok v) : + hacspec_ml_kem.parameters.createi K inst c + = .ok ⟨(List.range K.val).map (fun k => resGet (g k)), + by simp [List.length_map, List.length_range]⟩ := by + unfold hacspec_ml_kem.parameters.createi + have hpure : ∀ k : Nat, k < K.val → + inst.FnMutInst.call_mut c ⟨BitVec.ofNat _ k⟩ = .ok (resGet (g k), c) := by + intro k hk + rw [hcall k hk] + obtain ⟨v, hv⟩ := hok k hk + rw [hv]; rfl + exact libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq K inst.FnMutInst c + (fun k => resGet (g k)) hpure + +private theorem array_make_lane {K : Std.Usize} {α : Type} [Inhabited α] + (l : List α) (h : l.length = K.val) (i : Nat) (hi : i < K.val) : + (Std.Array.make K l h).val[i]! = l[i]! := rfl + +/-- Lane of a `(List.range K).map f` array at `i < K`. -/ +private theorem range_map_lane {α : Type} [Inhabited α] + (K : Nat) (f : Nat → α) (i : Nat) (hi : i < K) : + ((List.range K).map f)[i]! = f i := by + rw [getElem!_pos _ i (by simp [List.length_map, List.length_range, hi])] + rw [List.getElem_map, List.getElem_range] + +set_option maxHeartbeats 1600000 in +/-- **PART A.** The hacspec `compute_vector_u lm rvec evec` reduces to a vector + `W` whose row `i` is the per-row `multiply_vectors lm[i] rvec → ntt_inverse → + add_polynomials (·) evec[i]` chain. Stages reduced via `createi_stage_eq`. -/ +private theorem compute_vector_u_hacspec_eq {K : Std.Usize} + (lm : Std.Array (Std.Array Poly256L K) K) + (rvec evec : Std.Array Poly256L K) + (W : Std.Array Poly256L K) + (hWlen : W.length = K.val) + (hrow : ∀ i : Nat, i < K.val → + (do + let prod ← hacspec_ml_kem.matrix.multiply_vectors (lm.val[i]!) rvec + let inv ← hacspec_ml_kem.invert_ntt.ntt_inverse prod + hacspec_ml_kem.matrix.add_polynomials inv (evec.val[i]!)) + = .ok (W.val[i]!)) : + hacspec_ml_kem.matrix.compute_vector_u lm rvec evec = .ok W := by + unfold hacspec_ml_kem.matrix.compute_vector_u + -- per-row witnesses: prodᵢ, invᵢ from `hrow`. + have hprod_ok : ∀ i : Nat, i < K.val → + ∃ p, hacspec_ml_kem.matrix.multiply_vectors (lm.val[i]!) rvec = .ok p := by + intro i hi + have h := hrow i hi + match hmv : hacspec_ml_kem.matrix.multiply_vectors (lm.val[i]!) rvec with + | .ok p => exact ⟨p, rfl⟩ + | .fail e => rw [hmv] at h; simp only [Aeneas.Std.bind_tc_fail] at h; exact absurd h (by simp) + | .div => rw [hmv] at h; simp only [Aeneas.Std.bind_tc_div] at h; exact absurd h (by simp) + -- Stage 0: transpose lm = .ok T, extractColL T i = lm[i]. + -- (transpose_row_eq gives a concrete T; we use a generic T below.) + -- Stage 1: multiply_matrix_by_column T rvec. + set T : Std.Array (Std.Array Poly256L K) K := + ⟨(List.range K.val).map (fun j => + (⟨(List.range K.val).map (fun i' => (lm.val[i']!).val[j]!), + by simp [List.length_map, List.length_range]⟩ : Std.Array Poly256L K)), + by simp [List.length_map, List.length_range]⟩ with hT_def + have hT_eq : hacspec_ml_kem.matrix.transpose lm = .ok T := transpose_row_eq lm + have hcolT : ∀ i : Nat, i < K.val → + extractColL T (⟨BitVec.ofNat _ i⟩ : Std.Usize) = lm.val[i]! := by + intro i hi + obtain ⟨T', hT'_eq, hcol⟩ := extractColL_transpose_eq lm (⟨BitVec.ofNat _ i⟩ : Std.Usize) + (by rw [ofNat_toNat_eq i hi]; exact hi) + rw [hT_eq] at hT'_eq + have : T' = T := (Result.ok.inj hT'_eq).symm + rw [this] at hcol + rw [hcol, ofNat_toNat_eq i hi] + rw [hT_eq]; simp only [Aeneas.Std.bind_tc_ok] + -- product := multiply_matrix_by_column T rvec, reduced via createi_stage_eq. + set g_prod : Nat → Result Poly256L := + fun k => hacspec_ml_kem.matrix.multiply_matrix_by_column_at T rvec ⟨BitVec.ofNat _ k⟩ + with hg_prod_def + have hg_prod_ok : ∀ k : Nat, k < K.val → ∃ v, g_prod k = .ok v := by + intro k hk + show ∃ v, hacspec_ml_kem.matrix.multiply_matrix_by_column_at T rvec + (⟨BitVec.ofNat _ k⟩ : Std.Usize) = .ok v + have hkv : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := ofNat_toNat_eq k hk + rw [mmbc_at_eq_multiply_vectorsL T rvec ⟨BitVec.ofNat _ k⟩ (by rw [hkv]; exact hk)] + rw [hcolT k hk] + exact hprod_ok k hk + have h_prod_stage : + hacspec_ml_kem.matrix.multiply_matrix_by_column T rvec + = .ok ⟨(List.range K.val).map (fun k => resGet (g_prod k)), + by simp [List.length_map, List.length_range]⟩ := by + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column + apply createi_stage_eq _ _ g_prod _ hg_prod_ok + intro k hk + show hacspec_ml_kem.matrix.multiply_matrix_by_column.closure.Insts.CoreOpsFunctionFnMutTupleUsizeArrayFieldElement256.call_mut + (T, rvec) ⟨BitVec.ofNat _ k⟩ = (do let a ← g_prod k; .ok (a, (T, rvec))) + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column.closure.Insts.CoreOpsFunctionFnMutTupleUsizeArrayFieldElement256.call_mut + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256.call + rfl + rw [h_prod_stage]; simp only [Aeneas.Std.bind_tc_ok] + set P : Std.Array Poly256L K := + ⟨(List.range K.val).map (fun k => resGet (g_prod k)), + by simp [List.length_map, List.length_range]⟩ with hP_def + have hP_at : ∀ i : Nat, i < K.val → P.val[i]! = resGet (g_prod i) := by + intro i hi; rw [hP_def]; exact range_map_lane K.val _ i hi + -- Stage 2: createi(ntt_inverse) P. + set g_inv : Nat → Result Poly256L := + fun k => hacspec_ml_kem.invert_ntt.ntt_inverse (P.val[k]!) with hg_inv_def + have hg_inv_ok : ∀ k : Nat, k < K.val → ∃ v, g_inv k = .ok v := by + intro k hk + show ∃ v, hacspec_ml_kem.invert_ntt.ntt_inverse (P.val[k]!) = .ok v + rw [hP_at k hk] + -- resGet (g_prod k) = prodₖ, and ntt_inverse prodₖ succeeds by hrow. + obtain ⟨p, hp⟩ := hprod_ok k hk + have hgp : g_prod k = .ok p := by + show hacspec_ml_kem.matrix.multiply_matrix_by_column_at T rvec (⟨BitVec.ofNat _ k⟩ : Std.Usize) = .ok p + rw [mmbc_at_eq_multiply_vectorsL T rvec ⟨BitVec.ofNat _ k⟩ + (by rw [ofNat_toNat_eq k hk]; exact hk), hcolT k hk]; exact hp + rw [hgp] + show ∃ v, hacspec_ml_kem.invert_ntt.ntt_inverse p = .ok v + have h := hrow k hk + rw [hp] at h; simp only [Aeneas.Std.bind_tc_ok] at h + match hni : hacspec_ml_kem.invert_ntt.ntt_inverse p with + | .ok q => exact ⟨q, rfl⟩ + | .fail e => rw [hni] at h; simp only [Aeneas.Std.bind_tc_fail] at h; exact absurd h (by simp) + | .div => rw [hni] at h; simp only [Aeneas.Std.bind_tc_div] at h; exact absurd h (by simp) + have h_inv_stage : + hacspec_ml_kem.parameters.createi K + (hacspec_ml_kem.matrix.compute_vector_u.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256 K) P + = .ok ⟨(List.range K.val).map (fun k => resGet (g_inv k)), + by simp [List.length_map, List.length_range]⟩ := by + apply createi_stage_eq _ _ g_inv _ hg_inv_ok + intro k hk + show hacspec_ml_kem.matrix.compute_vector_u.closure.Insts.CoreOpsFunctionFnMutTupleUsizeArrayFieldElement256.call_mut + P ⟨BitVec.ofNat _ k⟩ = (do let a ← g_inv k; .ok (a, P)) + unfold hacspec_ml_kem.matrix.compute_vector_u.closure.Insts.CoreOpsFunctionFnMutTupleUsizeArrayFieldElement256.call_mut + unfold hacspec_ml_kem.matrix.compute_vector_u.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256.call + have hkv : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := ofNat_toNat_eq k hk + have h_idx_P : Std.Array.index_usize P (⟨BitVec.ofNat _ k⟩ : Std.Usize) + = .ok (P.val[k]!) := by + have hPlen : P.length = K.val := by + rw [hP_def]; show ((List.range K.val).map _).length = K.val + simp [List.length_map, List.length_range] + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq P (⟨BitVec.ofNat _ k⟩ : Std.Usize) + (by rw [hkv, hPlen]; exact hk) + rw [hkv] at this; exact this + rw [hg_inv_def] + show (do let a ← Std.Array.index_usize P (⟨BitVec.ofNat _ k⟩ : Std.Usize) + hacspec_ml_kem.invert_ntt.ntt_inverse a) >>= (fun a => .ok (a, P)) + = (do let a ← hacspec_ml_kem.invert_ntt.ntt_inverse (P.val[k]!); .ok (a, P)) + rw [h_idx_P]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_inv_stage]; simp only [Aeneas.Std.bind_tc_ok] + set PI : Std.Array Poly256L K := + ⟨(List.range K.val).map (fun k => resGet (g_inv k)), + by simp [List.length_map, List.length_range]⟩ with hPI_def + have hPI_at : ∀ i : Nat, i < K.val → PI.val[i]! = resGet (g_inv i) := by + intro i hi; rw [hPI_def]; exact range_map_lane K.val _ i hi + -- Stage 3: add_vectors PI evec = .ok W. + set g_add : Nat → Result Poly256L := + fun k => hacspec_ml_kem.matrix.add_polynomials (PI.val[k]!) (evec.val[k]!) with hg_add_def + have hg_add_ok : ∀ k : Nat, k < K.val → g_add k = .ok (W.val[k]!) := by + intro k hk + show hacspec_ml_kem.matrix.add_polynomials (PI.val[k]!) (evec.val[k]!) = .ok (W.val[k]!) + -- PI[k] = invₖ; combine the chain. + obtain ⟨p, hp⟩ := hprod_ok k hk + have hgp : g_prod k = .ok p := by + show hacspec_ml_kem.matrix.multiply_matrix_by_column_at T rvec (⟨BitVec.ofNat _ k⟩ : Std.Usize) = .ok p + rw [mmbc_at_eq_multiply_vectorsL T rvec ⟨BitVec.ofNat _ k⟩ + (by rw [ofNat_toNat_eq k hk]; exact hk), hcolT k hk]; exact hp + have h := hrow k hk + rw [hp] at h; simp only [Aeneas.Std.bind_tc_ok] at h + -- ntt_inverse p = .ok q for some q. + obtain ⟨q, hq⟩ : ∃ q, hacspec_ml_kem.invert_ntt.ntt_inverse p = .ok q := by + match hni : hacspec_ml_kem.invert_ntt.ntt_inverse p with + | .ok q => exact ⟨q, rfl⟩ + | .fail e => rw [hni] at h; simp only [Aeneas.Std.bind_tc_fail] at h; exact absurd h (by simp) + | .div => rw [hni] at h; simp only [Aeneas.Std.bind_tc_div] at h; exact absurd h (by simp) + have hPIk : PI.val[k]! = q := by + rw [hPI_at k hk] + show resGet (g_inv k) = q + show resGet (hacspec_ml_kem.invert_ntt.ntt_inverse (P.val[k]!)) = q + rw [hP_at k hk] + show resGet (hacspec_ml_kem.invert_ntt.ntt_inverse (resGet (g_prod k))) = q + rw [hgp] + show resGet (hacspec_ml_kem.invert_ntt.ntt_inverse p) = q + rw [hq]; rfl + rw [hPIk] + rw [hq] at h; simp only [Aeneas.Std.bind_tc_ok] at h + exact h + have hcall_add : ∀ k : Nat, k < K.val → + (hacspec_ml_kem.matrix.add_vectors.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256 K).FnMutInst.call_mut + (PI, evec) ⟨BitVec.ofNat _ k⟩ = (do let a ← g_add k; .ok (a, (PI, evec))) := by + intro k hk + show hacspec_ml_kem.matrix.add_vectors.closure.Insts.CoreOpsFunctionFnMutTupleUsizeArrayFieldElement256.call_mut + (PI, evec) ⟨BitVec.ofNat _ k⟩ = (do let a ← g_add k; .ok (a, (PI, evec))) + unfold hacspec_ml_kem.matrix.add_vectors.closure.Insts.CoreOpsFunctionFnMutTupleUsizeArrayFieldElement256.call_mut + unfold hacspec_ml_kem.matrix.add_vectors.closure.Insts.CoreOpsFunctionFnTupleUsizeArrayFieldElement256.call + have hkv : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := ofNat_toNat_eq k hk + have hPIlen : PI.length = K.val := by + rw [hPI_def]; show ((List.range K.val).map _).length = K.val + simp [List.length_map, List.length_range] + have hElen : evec.length = K.val := Std.Array.length_eq evec + have h_idx_PI : Std.Array.index_usize PI (⟨BitVec.ofNat _ k⟩ : Std.Usize) + = .ok (PI.val[k]!) := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq PI (⟨BitVec.ofNat _ k⟩ : Std.Usize) + (by rw [hkv, hPIlen]; exact hk) + rw [hkv] at this; exact this + have h_idx_E : Std.Array.index_usize evec (⟨BitVec.ofNat _ k⟩ : Std.Usize) + = .ok (evec.val[k]!) := by + have := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq evec (⟨BitVec.ofNat _ k⟩ : Std.Usize) + (by rw [hkv, hElen]; exact hk) + rw [hkv] at this; exact this + show (do let a2 ← Std.Array.index_usize PI (⟨BitVec.ofNat _ k⟩ : Std.Usize) + let a3 ← Std.Array.index_usize evec (⟨BitVec.ofNat _ k⟩ : Std.Usize) + hacspec_ml_kem.matrix.add_polynomials a2 a3) >>= (fun a => .ok (a, (PI, evec))) + = (do let a ← hacspec_ml_kem.matrix.add_polynomials (PI.val[k]!) (evec.val[k]!) + .ok (a, (PI, evec))) + rw [h_idx_PI]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_E]; simp only [Aeneas.Std.bind_tc_ok] + have h_add_stage : + hacspec_ml_kem.matrix.add_vectors PI evec = .ok W := by + unfold hacspec_ml_kem.matrix.add_vectors + rw [createi_stage_eq _ _ g_add hcall_add (fun k hk => ⟨W.val[k]!, hg_add_ok k hk⟩)] + apply congrArg Result.ok + apply Subtype.ext + show (List.range K.val).map (fun k => resGet (g_add k)) = W.val + have hWval_len : W.val.length = K.val := by + have := Std.Array.length_eq W; exact this + apply List.ext_getElem + · rw [List.length_map, List.length_range, hWval_len] + · intro k hk1 _ + have hk : k < K.val := by + have : k < ((List.range K.val).map _).length := hk1 + simpa [List.length_map, List.length_range] using this + rw [List.getElem_map, List.getElem_range] + rw [show resGet (g_add k) = resGet (.ok (W.val[k]!)) from by rw [hg_add_ok k hk]] + show W.val[k]! = W.val[k] + rw [getElem!_pos W.val k (by rw [hWval_len]; exact hk)] + exact h_add_stage + +/-- Lane of `lift_vec_slice v K` at `i < K`. -/ +private theorem lift_vec_slice_lane + (v : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (K : Std.Usize) (i : Nat) (hi : i < K.val) : + (lift_vec_slice v K).val[i]! = lift_poly v.val[i]! := by + show (Std.Array.make K ((List.range K.val).map (fun i => lift_poly v.val[i]!)) (by simp)).val[i]! + = lift_poly v.val[i]! + show ((List.range K.val).map (fun i => lift_poly v.val[i]!))[i]! = lift_poly v.val[i]! + exact range_map_lane K.val _ i hi + +end PartA + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do + +set_option maxHeartbeats 3200000 in +/-- **L7.2 MAIN theorem.** The top-level `matrix.compute_vector_u` glue. + PRE: `hK : K.val ≤ 4`, `h_K_pos : 1 ≤ K.val`, slice lengths, and per-lane + bounds on `r_as_ntt` (≤ 3328) / `error_1` (≤ 29439). Mirrors L7.4's + `compute_message_fc` glue + the impl `matrix.compute_vector_u` body. -/ +@[spec] +theorem compute_vector_u_fc + {Hasher : Type} (K : Std.Usize) + (hash_functionsHashInst : libcrux_iot_ml_kem.hash_functions.Hash Hasher) + (matrix_entry : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (seed : Slice Std.U8) + (r_as_ntt error_1 result : Slice + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (cache : Slice + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (accumulator : Std.Array Std.I32 256#usize) + (hK : K.val ≤ 4) + (h_K_pos : 1 ≤ K.val) + (h_seed_len : seed.length = 32) + (h_r_len : r_as_ntt.length = K.val) + (h_err_len : error_1.length = K.val) + (h_result_len : result.length = K.val) + (h_cache_len : cache.length = K.val) + (h_r_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((r_as_ntt.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_err_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((error_1.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 29439) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_vector_u + K (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst + matrix_entry seed r_as_ntt error_1 result scratch cache accumulator + ⦃ ⇓ p => ⌜ hacspec_ml_kem.matrix.compute_vector_u + (lift_matrix_from_seed seed K) + (lift_vec_slice r_as_ntt K) + (lift_vec_slice error_1 K) + = .ok (lift_vec_slice p.2.1 K) ⌝ ⦄ := by + set lm : Std.Array + (Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K) K := + lift_matrix_from_seed seed K with hlm_def + -- massert lengths. + have h_len_r : core.slice.Slice.len r_as_ntt = .ok K := by + unfold core.slice.Slice.len + apply congrArg Result.ok + apply Std.UScalar.eq_of_val_eq + show r_as_ntt.val.length = K.val + exact h_r_len + have h_len_e : core.slice.Slice.len error_1 = .ok K := by + unfold core.slice.Slice.len + apply congrArg Result.ok + apply Std.UScalar.eq_of_val_eq + show error_1.val.length = K.val + exact h_err_len + -- Step 0: i2 := classify 0 = 0#i32; acc1 := repeat 256 0. + set i2 : Std.I32 := 0#i32 with hi2_def + have h_classify : libcrux_secrets.traits.Classify.Blanket.classify (0#i32 : Std.I32) = .ok i2 := by + rw [hi2_def]; rfl + set acc1 : Std.Array Std.I32 256#usize := + Std.Array.repeat (256#usize : Std.Usize) i2 with h_acc1_def + have h_acc1_zero : ∀ n : Nat, n < 256 → (acc1.val[n]!).val = 0 := by + intro n hn + rw [h_acc1_def, Std.Array.repeat_val] + rw [getElem!_pos _ n (by rw [List.length_replicate]; exact hn)] + rw [List.getElem_replicate]; rfl + have h_acc1_natAbs : ∀ n : Nat, n < 256 → (acc1.val[n]!).val.natAbs = 0 := by + intro n hn; rw [h_acc1_zero n hn]; rfl + -- r_arr : Array Poly K from r_as_ntt. + set r_arr : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K := + ⟨r_as_ntt.val, by rw [← h_r_len]⟩ with h_r_arr_def + have h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]! := by + intro c hc; rfl + -- Acc budget for loop0: acc1[n]=0, K≤4 ⟹ K·2^25 ≤ 2^30. + have h_acc1_budget : ∀ n : Fin 256, + (acc1.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30 := by + intro n + rw [h_acc1_natAbs n.val n.isLt] + have hK4 : K.val * 2^25 ≤ 4 * 2^25 := Nat.mul_le_mul_right _ hK + have : (4 : Nat) * 2^25 ≤ 2^30 := by decide + omega + -- S1: row-0 column loop. + obtain ⟨⟨me1, cache1, acc2⟩, h_loop0_eq, h_row0⟩ := triple_exists_ok_fc + (compute_vector_u_loop0_fc hash_functionsHashInst matrix_entry seed r_as_ntt cache + r_arr acc1 h_seed_len h_r_len h_cache_len h_r_arr h_r_bnd h_acc1_budget) + dsimp only at h_loop0_eq h_row0 + -- cache1 length preservation (same impl call, deterministic). + have h_cache1_len : cache1.length = K.val := by + obtain ⟨v, hv_eq, hv_len⟩ := triple_exists_ok_fc + (compute_vector_u_loop0_cache_len_fc hash_functionsHashInst matrix_entry seed r_as_ntt cache + r_arr acc1 h_seed_len h_r_len h_cache_len h_r_arr h_r_bnd h_acc1_budget) + rw [h_loop0_eq] at hv_eq + have : v = (me1, cache1, acc2) := (Result.ok.inj hv_eq).symm + rw [this] at hv_len; exact hv_len + -- row-0 acc-bridge: multiply_vectors lm[0] (lift r) = .ok (scaleZ 2285 (mont_strip (poly_reducing acc2))). + have h_bridge0 := compute_vector_u_row0_acc_bridge seed r_as_ntt r_arr cache cache1 acc1 acc2 + h_acc1_zero h_r_arr h_r_bnd h_row0 + -- Destructure row0_inv once. + obtain ⟨_h_ex, h_acc2_bnd_raw, h_cache_done, _h_cache_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + Row0FillFC.row0_inv, ← List.getElem!_eq_getElem?_getD] using h_row0 + -- cache-post bridge for loop1: from row0_inv conjunct (3) + h_r_arr. + have h_cache_post : ∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post (r_as_ntt.val[c]!) (cache1.val[c]!) := by + intro c hc + have := h_cache_done c hc + rw [h_r_arr c hc] at this; exact this + -- acc2 bound for the row-0 reducing step: ≤ 2^16*3328. + have h_acc2_bnd : ∀ n : Nat, n < 256 → (acc2.val[n]!).val.natAbs ≤ 2^16 * 3328 := by + intro n hn + have hb := h_acc2_bnd_raw n hn + rw [h_acc1_natAbs n hn] at hb + have hK4 : K.val * 2^25 ≤ 4 * 2^25 := Nat.mul_le_mul_right _ hK + have h2 : (4 : Nat) * 2^25 ≤ 2^16 * 3328 := by decide + omega + set acc_slice : Slice Std.I32 := Aeneas.Std.Array.to_slice acc2 with h_acc_slice_def + have h_acc_slice_len : acc_slice.length = 256 := by + rw [h_acc_slice_def, Aeneas.Std.Array.length_to_slice]; rfl + have h_acc_slice_val : acc_slice.val = acc2.val := Aeneas.Std.Array.val_to_slice acc2 + have h_acc_slice_bnd : ∀ n : Nat, n < 256 → + (acc_slice.val[n]!).val.natAbs ≤ 2^16 * 3328 := by + intro n hn; rw [h_acc_slice_val]; exact h_acc2_bnd n hn + -- (a) index_mut result 0 → (result[0]!, result.set 0). + set pre0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + result.val[0]! with h_pre0_def + have h0lt : (0 : Nat) < K.val := h_K_pos + have h_idx_result0 : Aeneas.Std.Slice.index_usize result (0#usize : Std.Usize) = .ok pre0 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq result 0#usize + (by show (0#usize : Std.Usize).val < result.length; rw [h_result_len]; exact h0lt) + have h_imt_result0 : Aeneas.Std.Slice.index_mut_usize result (0#usize : Std.Usize) + = .ok (pre0, result.set 0#usize) := by + unfold Aeneas.Std.Slice.index_mut_usize + rw [h_idx_result0]; rfl + -- (b) reducing step. + obtain ⟨result1, h_result1_eq, h_result1_mont, h_result1_lane_bnd⟩ := + triple_exists_ok_fc + (poly_reducing_from_i32_array_fc acc_slice pre0 h_acc_slice_len h_acc_slice_bnd) + have h_result1_lift : lift_poly result1 + = Impl.mont_strip_pure (Spec.poly_reducing_from_i32_array_pure acc_slice) := by + rw [← h_result1_mont, Impl.mont_strip_lift_poly_mont_eq_lift_poly] + set rslice1 : Slice _ := result.set 0#usize result1 with h_rslice1_def + have h_rslice1_at0 : rslice1.val[0]! = result1 := by + rw [h_rslice1_def] + simpa [Aeneas.Std.Slice.getElem!_Nat_eq] using + Aeneas.Std.Slice.getElem!_Nat_set_eq result 0#usize 0 result1 + ⟨rfl, by show (0:Nat) < result.length; rw [h_result_len]; exact h0lt⟩ + have h_rslice1_len : rslice1.length = K.val := by + rw [h_rslice1_def, Aeneas.Std.Slice.set_length]; exact h_result_len + have h_pre2_eq : rslice1.val[0]! = result1 := h_rslice1_at0 + have h_idx_rslice1 : Aeneas.Std.Slice.index_usize rslice1 (0#usize : Std.Usize) = .ok result1 := by + rw [← h_rslice1_at0] + exact libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq rslice1 0#usize + (by show (0#usize : Std.Usize).val < rslice1.length; rw [h_rslice1_len]; exact h0lt) + have h_imt_rslice1 : Aeneas.Std.Slice.index_mut_usize rslice1 (0#usize : Std.Usize) + = .ok (result1, rslice1.set 0#usize) := by + unfold Aeneas.Std.Slice.index_mut_usize + rw [h_idx_rslice1]; rfl + -- (c) invert step. + have h_result1_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((result1.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 13312 := by + intro chunk hchunk k hk + have := h_result1_lane_bnd chunk hchunk k hk; omega + obtain ⟨⟨result2, scratch1⟩, h_inv_eq, h_result2_lift, h_result2_bnd⟩ := + triple_exists_ok_fc + (invert_ntt_montgomery_fc (K := K) result1 scratch h_result1_bnd) + dsimp only at h_inv_eq h_result2_lift h_result2_bnd + set rslice2 : Slice _ := rslice1.set 0#usize result2 with h_rslice2_def + have h_rslice2_at0 : rslice2.val[0]! = result2 := by + rw [h_rslice2_def] + simpa [Aeneas.Std.Slice.getElem!_Nat_eq] using + Aeneas.Std.Slice.getElem!_Nat_set_eq rslice1 0#usize 0 result2 + ⟨rfl, by show (0:Nat) < rslice1.length; rw [h_rslice1_len]; exact h0lt⟩ + have h_rslice2_len : rslice2.length = K.val := by + rw [h_rslice2_def, Aeneas.Std.Slice.set_length]; exact h_rslice1_len + have h_idx_rslice2 : Aeneas.Std.Slice.index_usize rslice2 (0#usize : Std.Usize) = .ok result2 := by + rw [← h_rslice2_at0] + exact libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq rslice2 0#usize + (by show (0#usize : Std.Usize).val < rslice2.length; rw [h_rslice2_len]; exact h0lt) + have h_imt_rslice2 : Aeneas.Std.Slice.index_mut_usize rslice2 (0#usize : Std.Usize) + = .ok (result2, rslice2.set 0#usize) := by + unfold Aeneas.Std.Slice.index_mut_usize + rw [h_idx_rslice2]; rfl + -- (d) index error_1 0 → error_1[0]!. + set err0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + error_1.val[0]! with h_err0_def + have h_idx_err0 : Aeneas.Std.Slice.index_usize error_1 (0#usize : Std.Usize) = .ok err0 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq error_1 0#usize + (by show (0#usize : Std.Usize).val < error_1.length; rw [h_err_len]; exact h0lt) + -- (e) add_error_reduce step. + have h_result2_self_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((result2.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro chunk hchunk ℓ hℓ + have := h_result2_bnd chunk hchunk ℓ hℓ; omega + have h_err0_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((err0.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439 := + fun chunk hchunk ℓ hℓ => h_err_bnd 0 h0lt ⟨chunk, hchunk⟩ ⟨ℓ, hℓ⟩ + obtain ⟨row0poly, h_add0_eq, h_row0poly_lift⟩ := + triple_exists_ok_fc + (add_error_reduce_fc result2 err0 h_result2_self_bnd h_err0_bnd) + set s1 : Slice _ := rslice2.set 0#usize row0poly with h_s1_def + have h_s1_at0 : s1.val[0]! = row0poly := by + rw [h_s1_def] + simpa [Aeneas.Std.Slice.getElem!_Nat_eq] using + Aeneas.Std.Slice.getElem!_Nat_set_eq rslice2 0#usize 0 row0poly + ⟨rfl, by show (0:Nat) < rslice2.length; rw [h_rslice2_len]; exact h0lt⟩ + have h_s1_len : s1.length = K.val := by + rw [h_s1_def, Aeneas.Std.Slice.set_length]; exact h_rslice2_len + -- row-0 row_spec: row_spec lm r_as_ntt error_1 0 = .ok (lift_poly row0poly). + have h_row_spec0 : AllRowsFillFC.row_spec lm r_as_ntt error_1 0 = .ok (lift_poly row0poly) := by + unfold AllRowsFillFC.row_spec + have hA : hacspec_ml_kem.matrix.multiply_vectors (lm.val[0]!) (lift_vec_slice r_as_ntt K) + = .ok (scaleZ 2285 (lift_poly result1)) := by + rw [hlm_def, h_result1_lift, h_acc_slice_def] + exact h_bridge0 + rw [hA]; simp only [Aeneas.Std.bind_tc_ok] + rw [compute_vector_u_ntt_inverse_eq result1 result2 h_result2_lift.symm] + simp only [Aeneas.Std.bind_tc_ok] + rw [← h_err0_def] + exact compute_vector_u_add_eq result2 err0 row0poly h_row0poly_lift.symm + -- S2: outer rows loop [1, K). Budget: acc2 re-zeroed per row (loop1 ignores acc2 content). + obtain ⟨⟨me2, result3, scratch2, acc3⟩, h_loop1_eq, h_rows⟩ := triple_exists_ok_fc + (compute_vector_u_loop1_fc hash_functionsHashInst i2 me1 seed r_as_ntt error_1 s1 + scratch1 cache1 acc2 r_arr 1#usize hK + (by show 1 ≤ (1#usize : Std.Usize).val; rfl) + (by show (1#usize : Std.Usize).val ≤ K.val; exact h_K_pos) h_seed_len h_r_len + h_cache1_len h_s1_len h_err_len (by rw [hi2_def]; rfl) + h_r_arr h_r_bnd h_err_bnd h_cache_post) + dsimp only at h_loop1_eq h_rows + -- Destructure rows_inv: done rows [1,K) + unchanged rows + length. + obtain ⟨h_rows_done, h_rows_undone, h_result3_len⟩ := by + simpa [AllRowsFillFC.rows_inv, Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + ← List.getElem!_eq_getElem?_getD] using h_rows + -- result3[0] = s1[0] = row0poly (loop1 leaves row 0 since start = 1). + have h_result3_at0 : result3.val[0]! = row0poly := by + have := h_rows_undone 0 h0lt (Or.inl (by decide)) + rw [this, h_s1_at0] + -- W := lift_vec_slice result3 K; per-row row_spec = .ok W[r]. + set W : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K := + lift_vec_slice result3 K with hW_def + have hW_at : ∀ r : Nat, r < K.val → W.val[r]! = lift_poly result3.val[r]! := by + intro r hr; rw [hW_def]; exact lift_vec_slice_lane result3 K r hr + have h_row_spec_all : ∀ r : Nat, r < K.val → + AllRowsFillFC.row_spec lm r_as_ntt error_1 r = .ok (W.val[r]!) := by + intro r hr + rw [hW_at r hr] + by_cases h0 : r = 0 + · subst h0; rw [h_result3_at0]; exact h_row_spec0 + · have hr1 : (1#usize : Std.Usize).val ≤ r := by + have h1v : (1#usize : Std.Usize).val = 1 := rfl + rw [h1v]; omega + have := h_rows_done r hr1 hr + rw [hlm_def]; exact this + -- Bridge row_spec → PART A `hrow` (evec = lift_vec_slice error_1 K). + have hrow : ∀ i : Nat, i < K.val → + (do + let prod ← hacspec_ml_kem.matrix.multiply_vectors (lm.val[i]!) (lift_vec_slice r_as_ntt K) + let inv ← hacspec_ml_kem.invert_ntt.ntt_inverse prod + hacspec_ml_kem.matrix.add_polynomials inv ((lift_vec_slice error_1 K).val[i]!)) + = .ok (W.val[i]!) := by + intro i hi + have h := h_row_spec_all i hi + unfold AllRowsFillFC.row_spec at h + rw [lift_vec_slice_lane error_1 K i hi] + exact h + -- PART A: hacspec compute_vector_u lm (lift r) (lift e) = .ok W. + have h_hacspec : hacspec_ml_kem.matrix.compute_vector_u lm + (lift_vec_slice r_as_ntt K) (lift_vec_slice error_1 K) = .ok W := + compute_vector_u_hacspec_eq lm (lift_vec_slice r_as_ntt K) (lift_vec_slice error_1 K) W + (Std.Array.length_eq W) hrow + -- Reduce the impl do-block to `.ok (me2, result3, scratch2, cache1, acc3)`. + apply triple_of_ok_fc (v := (me2, result3, scratch2, cache1, acc3)) + · unfold libcrux_iot_ml_kem.matrix.compute_vector_u + rw [h_len_r]; simp only [Aeneas.Std.bind_tc_ok, Aeneas.Std.massert] + rw [h_len_e]; simp only [Aeneas.Std.bind_tc_ok] + rw [show libcrux_secrets.traits.Classify.Blanket.classify (0#i32 : Std.I32) + = Aeneas.Std.Result.ok i2 from h_classify] + simp only [Aeneas.Std.bind_tc_ok] + rw [show (Std.Array.repeat (256#usize : Std.Usize) i2) = acc1 from rfl] + rw [h_loop0_eq]; simp only [Aeneas.Std.bind_tc_ok] + show (do + let s ← Aeneas.Std.lift (Aeneas.Std.Array.to_slice acc2) + let (pre, index_mut_back) ← Aeneas.Std.Slice.index_mut_usize result 0#usize + let pre1 ← libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst s pre + let result1 := index_mut_back pre1 + let (pre2, index_mut_back1) ← Aeneas.Std.Slice.index_mut_usize result1 0#usize + let (pre3, scratch1) ← libcrux_iot_ml_kem.invert_ntt.invert_ntt_montgomery + K portable_ops_inst pre2 scratch + let result2 := index_mut_back1 pre3 + let (pre4, index_mut_back2) ← Aeneas.Std.Slice.index_mut_usize result2 0#usize + let pre5 ← Aeneas.Std.Slice.index_usize error_1 0#usize + let pre6 ← libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce + portable_ops_inst pre4 pre5 + let s1 := index_mut_back2 pre6 + let (matrix_entry2, result3, scratch2, accumulator3) ← + libcrux_iot_ml_kem.matrix.compute_vector_u_loop1 + K portable_ops_inst hash_functionsHashInst i2 { start := 1#usize, «end» := K } + me1 seed r_as_ntt error_1 s1 scratch1 cache1 acc2 + .ok (matrix_entry2, result3, scratch2, cache1, accumulator3)) + = .ok (me2, result3, scratch2, cache1, acc3) + rw [show Aeneas.Std.lift (Aeneas.Std.Array.to_slice acc2) = Aeneas.Std.Result.ok acc_slice + from by rw [h_acc_slice_def]; rfl] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_result0]; simp only [Aeneas.Std.bind_tc_ok] + rw [show libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + (vectortraitsOperationsInst := portable_ops_inst) acc_slice pre0 + = Aeneas.Std.Result.ok result1 from h_result1_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [show (result.set 0#usize) result1 = rslice1 from rfl, h_imt_rslice1] + simp only [Aeneas.Std.bind_tc_ok] + rw [show libcrux_iot_ml_kem.invert_ntt.invert_ntt_montgomery + K (vectortraitsOperationsInst := portable_ops_inst) result1 scratch + = Aeneas.Std.Result.ok (result2, scratch1) from h_inv_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [show (rslice1.set 0#usize) result2 = rslice2 from rfl, h_imt_rslice2] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_err0]; simp only [Aeneas.Std.bind_tc_ok] + rw [show libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce + (vectortraitsOperationsInst := portable_ops_inst) result2 err0 + = Aeneas.Std.Result.ok row0poly from h_add0_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [show (rslice2.set 0#usize) row0poly = s1 from rfl] + rw [h_loop1_eq]; simp only [Aeneas.Std.bind_tc_ok] + · -- Spec equation: hacspec compute_vector_u ... = .ok (lift_vec_slice result3 K). + show hacspec_ml_kem.matrix.compute_vector_u (lift_matrix_from_seed seed K) + (lift_vec_slice r_as_ntt K) (lift_vec_slice error_1 K) + = .ok (lift_vec_slice result3 K) + rw [← hlm_def, ← hW_def] + exact h_hacspec + +/-- +info: 'libcrux_iot_ml_kem.Matrix.ComputeVectorU.FC.compute_vector_u_fc' depends on axioms: [propext, + Classical.choice, + Quot.sound, + sample_matrix_entry_fc] +-/ +#guard_msgs in +#print axioms compute_vector_u_fc + +end libcrux_iot_ml_kem.Matrix.ComputeVectorU.FC \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeVectorU/Hacspec.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeVectorU/Hacspec.lean new file mode 100644 index 00000000..8813584f --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeVectorU/Hacspec.lean @@ -0,0 +1,747 @@ +/- + # `Matrix/ComputeVectorU/Hacspec.lean` — L7.2 `D''` tail bridge. + + The hacspec↔pure equational bridge for the *add* (error-injection) tail of + the `compute_vector_u` decomposition, directly analogous to the proven `D` + lemma `ComputeMessage.sub_polynomials_scaleZ_eq`. + + * **D'' (hacspec side)** — `add_polynomials (scaleZ 512 b) e + = add_error_reduce_pure b e` for canonical `b` lanes (the `·512` factor + matches the impl's fused Montgomery `·1441` correction, since + `1441 · 169 ≡ 512` in `ZMod 3329`; `Bridges.glue_1441_169`). + + The factor `512` is `#eval`-validated (PBT, 2026-05-29). Per-lane + characterization mirrors `Bridges.zmodOfFE_subtract_reduce_pure_lane`. + + Local copies of the `private` lane-access / canonicity helpers from + `ComputeMessage` / `Bridges` are re-derived here (this file imports only + FCTargets/Common/Bridges, so the `private` originals are out of scope). +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeAsPlusE +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeMessage.Bridges + +set_option mvcgen.warning false +set_option linter.unusedVariables false + +namespace libcrux_iot_ml_kem.Matrix.ComputeVectorU.Hacspec +open libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeMessage.Bridges +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeAsPlusE libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt +open libcrux_iot_ml_kem.Spec.Pure (Canonical) +section AddPolyScaleZ + +/-! ### Local lane-access / canonicity helpers (copies of the `private` + originals in `ComputeMessage` / `Bridges`). -/ + +/-- Generic `Std.Array.make … (range m).map f` lane access (local copy). -/ +private theorem mkN_map_lane'' {α : Type} [Inhabited α] {n : Std.Usize} {m : Nat} + (f : Nat → α) (k : Nat) (hk : k < m) + (hlen : ((List.range m).map f).length = n.val) : + (Std.Array.make n ((List.range m).map f) hlen).val[k]! = f k := by + show ((List.range m).map f)[k]! = f k + have h_len : ((List.range m).map f).length = m := by simp + rw [getElem!_pos _ k (by rw [h_len]; exact hk)] + simp + +/-- `chunk_at` lane access (local copy). -/ +private theorem chunk_at_lane'' + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (k ℓ : Nat) (hℓ : ℓ < 16) : + (Spec.chunk_at p k).val[ℓ]! = p.val[16 * k + ℓ]! := by + unfold Spec.chunk_at + show ((List.range 16).map (fun j => p.val[16 * k + j]!))[ℓ]! = p.val[16 * k + ℓ]! + have h_len : ((List.range 16).map (fun j => p.val[16 * k + j]!)).length = 16 := by simp + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map, List.getElem_range] + +/-- Lane access for a 16-chunk flatten shape (local copy of + `flatten_chunk_map_lane`). -/ +private theorem flatten_chunk_map_lane'' + (H : Nat → Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (j : Nat) (hj : j < 256) + (h : ((List.range 16).map H).length = (16#usize).val) : + (Spec.flatten_chunks (Std.Array.make 16#usize ((List.range 16).map H) h)).val[j]! + = (H (j / 16)).val[j % 16]! := by + have hk : j / 16 < 16 := by omega + unfold Spec.flatten_chunks + rw [mkN_map_lane'' _ j hj] + rw [mkN_map_lane'' H (j / 16) hk] + +/-- Canonical round-trip (local copy). -/ +private theorem canon_feOfZMod'' (z : ZMod 3329) : Canonical (feOfZMod z) := by + unfold Canonical feOfZMod hacspec_ml_kem.parameters.FIELD_MODULUS + show (BitVec.ofNat 16 z.val).toNat < _ + rw [BitVec.toNat_ofNat] + have hz : z.val < 3329 := ZMod.val_lt z + have : z.val % 2 ^ 16 = z.val := Nat.mod_eq_of_lt (by omega) + simp only [this]; simpa using hz + +/-- Canonical round-trip (local copy of `feOfZMod_zmodOfFE_of_canon'`). -/ +private theorem feOfZMod_zmodOfFE_of_canon'' + (fe : hacspec_ml_kem.parameters.FieldElement) (h : Canonical fe) : + feOfZMod (zmodOfFE fe) = fe := by + have h' : fe.val.val < 3329 := by + unfold Canonical hacspec_ml_kem.parameters.FIELD_MODULUS at h; simpa using h + unfold feOfZMod zmodOfFE + have hzval : ((fe.val.val : ZMod 3329)).val = fe.val.val := ZMod.val_natCast_of_lt h' + rw [hzval] + have hfeval : fe.val.val < 2 ^ 16 := by + have h_p : (3329 : Nat) ≤ 2 ^ 16 := by decide + omega + have hfebv : BitVec.ofNat 16 fe.val.val = fe.val.bv := by + apply BitVec.eq_of_toNat_eq + rw [BitVec.toNat_ofNat] + show fe.val.val % 2 ^ 16 = fe.val.bv.toNat + rw [Nat.mod_eq_of_lt hfeval]; rfl + show ({ val := ⟨BitVec.ofNat 16 fe.val.val⟩ } : + hacspec_ml_kem.parameters.FieldElement) = fe + rw [hfebv] + +/-- `scaleZ c p` is canonical per lane (local copy of `canonArr_scaleZ'`). -/ +private theorem canonArr_scaleZ'' (c : ZMod 3329) + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (j : Nat) (hj : j < 256) : Canonical ((scaleZ c p).val[j]!) := by + unfold scaleZ + rw [mkN_map_lane'' (fun k => feOfZMod (c * zmodOfFE (p.val[k]!))) j hj _] + exact canon_feOfZMod'' _ + +/-- Two canonical 256-arrays with equal `zmodOfFE` lanes are equal (local copy + of `eq_of_zmod_lane_canon`). -/ +private theorem eq_of_zmod_lane_canon'' + (u v : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hcu : ∀ j : Nat, j < 256 → Canonical (u.val[j]!)) + (hcv : ∀ j : Nat, j < 256 → Canonical (v.val[j]!)) + (hz : ∀ j : Nat, j < 256 → zmodOfFE (u.val[j]!) = zmodOfFE (v.val[j]!)) : + u = v := by + apply Subtype.ext + apply List.ext_getElem + · rw [Aeneas.Std.Array.length_eq u, Aeneas.Std.Array.length_eq v] + · intro j hj1 _hj2 + have hj : j < 256 := by rw [Aeneas.Std.Array.length_eq u] at hj1; simpa using hj1 + have heq : u.val[j]! = v.val[j]! := by + rw [← feOfZMod_zmodOfFE_of_canon'' (u.val[j]!) (hcu j hj), + ← feOfZMod_zmodOfFE_of_canon'' (v.val[j]!) (hcv j hj), hz j hj] + have huj : u.val[j]! = u.val[j] := + getElem!_pos u.val j (by rw [Aeneas.Std.Array.length_eq u]; exact hj) + have hvj : v.val[j]! = v.val[j] := + getElem!_pos v.val j (by rw [Aeneas.Std.Array.length_eq v]; exact hj) + rw [← huj, ← hvj]; exact heq + +/-! ### The genuinely-new per-lane characterization of `add_error_reduce_pure`. -/ + +/-- Per-lane characterization of `Spec.add_error_reduce_pure`: for `j < 256` + and canonical `b[j]`, + `zmodOfFE ((add_error_reduce_pure b e)[j]) = 512 · zmodOfFE (b[j]) + zmodOfFE (e[j])`. + The impl's fused Montgomery `·1441` correction equals `·512` in `ZMod 3329` + since `1441 · 169 ≡ 512` (`glue_1441_169`). Mirrors + `Bridges.zmodOfFE_subtract_reduce_pure_lane`. -/ +private theorem zmodOfFE_add_error_reduce_pure_lane + (b e : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (j : Nat) (hj : j < 256) + (hb : libcrux_iot_ml_kem.Spec.Pure.Canonical (b.val[j]!)) : + zmodOfFE ((Spec.add_error_reduce_pure b e).val[j]!) + = 512 * zmodOfFE (b.val[j]!) + zmodOfFE (e.val[j]!) := by + have hℓ : j % 16 < 16 := Nat.mod_lt _ (by decide) + have hjeq : 16 * (j / 16) + j % 16 = j := by omega + unfold Spec.add_error_reduce_pure + rw [flatten_chunk_map_lane'' (fun k => Spec.chunk_add_error_reduce_pure + (Spec.chunk_at b k) (Spec.chunk_at e k)) j hj (by simp)] + unfold Spec.chunk_add_error_reduce_pure + rw [mkN_map_lane'' _ (j % 16) hℓ] + -- lane = add_pure (mul_pure (chunk_at b k)[ℓ] (lift_fe_mont 1441)) (chunk_at e k)[ℓ] + rw [chunk_at_lane'' b (j / 16) (j % 16) hℓ, chunk_at_lane'' e (j / 16) (j % 16) hℓ] + rw [hjeq] + rw [zmodOfFE_add_pure] + rw [zmodOfFE_mul_pure] + rw [zmodOfFE_lift_fe_mont] + have h1441 : (((1441#i16 : Std.I16).val : ZMod 3329)) = 1441 := by decide + rw [h1441] + have h512 : (1441 : ZMod 3329) * 169 = 512 := glue_1441_169 + rw [show (zmodOfFE (b.val[j]!) * (1441 * 169) : ZMod 3329) + = 512 * zmodOfFE (b.val[j]!) by rw [h512]; ring] + +/-! ### D'' — the `add_polynomials ∘ scaleZ` bridge. + + Mirrors `ComputeMessage.sub_polynomials_scaleZ_eq` exactly: the createi + reduction `matrix_add_polynomials_eq_ok` (public, FCTargets) gives the + per-lane `add_pure` array; `eq_of_zmod_lane_canon''` reduces equality to + canonicity + per-lane `zmodOfFE`. -/ +set_option maxHeartbeats 1000000 in +theorem add_polynomials_scaleZ_eq + (b e : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hb : ∀ j : Nat, j < 256 → + libcrux_iot_ml_kem.Spec.Pure.Canonical (b.val[j]!)) : + hacspec_ml_kem.matrix.add_polynomials (scaleZ 512 b) e + = .ok (Spec.add_error_reduce_pure b e) := by + have hc : ∀ k : Nat, k < 256 → Canonical ((scaleZ 512 b).val[k]!) := + fun k hk => canonArr_scaleZ'' 512 b k hk + rw [Stage4MatrixAddFC.matrix_add_polynomials_eq_ok (scaleZ 512 b) e] + -- The reduced LHS array (set L); show it equals `add_error_reduce_pure b e`. + set L : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + ⟨(List.range 256).map (fun k => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((scaleZ 512 b).val[k]!) (e.val[k]!)), + by simp [List.length_map, List.length_range]⟩ with hL_def + have hL_lane : ∀ j : Nat, j < 256 → + L.val[j]! = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((scaleZ 512 b).val[j]!) (e.val[j]!) := by + intro j hj + show ((List.range 256).map (fun k => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((scaleZ 512 b).val[k]!) (e.val[k]!)))[j]! = _ + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj])] + rw [List.getElem_map, List.getElem_range] + apply congrArg Result.ok + apply eq_of_zmod_lane_canon'' + · -- L lanes canonical + intro j hj + rw [hL_lane j hj] + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure _ _ + · -- add_error_reduce_pure lanes canonical + intro j hj + have hℓ : j % 16 < 16 := Nat.mod_lt _ (by decide) + have hjeq : 16 * (j / 16) + j % 16 = j := by omega + unfold Spec.add_error_reduce_pure + rw [flatten_chunk_map_lane'' (fun k => Spec.chunk_add_error_reduce_pure + (Spec.chunk_at b k) (Spec.chunk_at e k)) j hj (by simp)] + unfold Spec.chunk_add_error_reduce_pure + rw [mkN_map_lane'' _ (j % 16) hℓ] + exact libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure _ _ + · -- per-lane zmodOfFE equality + intro j hj + rw [hL_lane j hj] + rw [zmodOfFE_add_pure] + rw [scaleZ_lane 512 b j hj] + rw [zmodOfFE_add_error_reduce_pure_lane b e j hj (hb j hj)] + +end AddPolyScaleZ + +/-! ## L7.2-P0.2 — matrix-column product = vector product over the column. + + `multiply_matrix_by_column_at m vec i` (loop body folds + `add_polynomials result (multiply_ntts m[j][i] vec[j])`) is structurally + identical to `multiply_vectors (extractCol m i) vec` (loop body folds + `add_polynomials result (multiply_ntts (extractCol m i)[j] vec[j])`) since + `(extractCol m i)[j] = m[j][i]`. We reduce BOTH loops, via two + `loop_range_spec_usize` applications with the SAME invariant, to a common + per-lane fold `mcol_result_at_step`, then combine. + + `multiply_ntts` / `add_polynomials` are treated opaquely: only the index + fact `(extractCol m i)[j]! = m[j]![i]!` is needed to align the bodies. -/ +section MatrixColEqVectors + +open hacspec_ml_kem.parameters (FieldElement) + +/-- Polynomial as a 256-lane field-element array (avoids repeating the + `256#usize` literal in nested binder types, which re-triggers the + `n#usize` macro tactic — SKILL §9.13 / nested-index trap). -/ +private abbrev Poly256 := Std.Array FieldElement 256#usize + +/-- Local copy of the `private triple_of_ok_fc`. -/ +private theorem triple_of_ok_fc' {α : Type} {x : Result α} {v : α} + {P : α → Prop} (hx : x = .ok v) (hp : P v) : + ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄ := by + subst hx; simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +/-- Local copy of the `private triple_exists_ok_fc`. -/ +private theorem triple_exists_ok_fc' {α : Type} {x : Result α} {P : α → Prop} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) : + ∃ v, x = .ok v ∧ P v := by + match hx : x with + | .ok v => exact ⟨v, rfl, (by subst hx; simpa [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply] using h)⟩ + | .fail _ => exact absurd h (by simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div => exact absurd h (by simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply]) + +/-- Column `i` extracted from matrix `m`: lane `j` is `m[j][i]`. -/ +private noncomputable def extractCol {K : Std.Usize} + (m : Std.Array (Std.Array (Poly256) K) K) + (i : Std.Usize) : Std.Array (Poly256) K := + Std.Array.make K ((List.range K.val).map (fun j => (m.val[j]!).val[i.val]!)) + (by simp [List.length_map, List.length_range]) + +/-- Lane access for `extractCol`: `(extractCol m i)[j]! = m[j]![i]!` for `j < K`. -/ +private theorem extractCol_lane {K : Std.Usize} + (m : Std.Array (Std.Array (Poly256) K) K) + (i : Std.Usize) (j : Nat) (hj : j < K.val) : + (extractCol m i).val[j]! = (m.val[j]!).val[i.val]! := by + unfold extractCol + show ((List.range K.val).map (fun j => (m.val[j]!).val[i.val]!))[j]! = _ + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj])] + simp [List.getElem_map, List.getElem_range] + +/-- Per-lane partial sum produced by the matrix-column / vector loop at step `k`: + the `add_polynomials`-product foldl, lane `ℓ`, seeded at `⟨0#u16⟩`. + Folds the raw `multiply_ntts_pure` lane of `col[c]` against `vec[c]`. -/ +private noncomputable def mcol_lane_at_step {K : Std.Usize} + (col vec : Std.Array (Poly256) K) + (k : Nat) (ℓ : Nat) : FieldElement := + (List.range k).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.multiply_ntts_pure (col.val[c]!) (vec.val[c]!)).val[ℓ]!)) + ({ val := 0#u16 } : FieldElement) + +/-- The per-step accumulator array: lane `ℓ` is `mcol_lane_at_step ... k ℓ`. -/ +private noncomputable def mcol_result_at_step {K : Std.Usize} + (col vec : Std.Array (Poly256) K) + (k : Nat) : Poly256 := + ⟨(List.range 256).map (fun ℓ => mcol_lane_at_step col vec k ℓ), + by simp [List.length_map, List.length_range]⟩ + +private theorem mcol_result_at_step_val_lane {K : Std.Usize} + (col vec : Std.Array (Poly256) K) + (k : Nat) (ℓ : Nat) (hℓ : ℓ < 256) : + (mcol_result_at_step col vec k).val[ℓ]! = mcol_lane_at_step col vec k ℓ := by + unfold mcol_result_at_step + show ((List.range 256).map (fun ℓ' => mcol_lane_at_step col vec k ℓ'))[ℓ]! = _ + rw [getElem!_pos _ ℓ (by simp [List.length_map, List.length_range, hℓ])] + rw [List.getElem_map, List.getElem_range] + +private theorem mcol_lane_at_step_zero {K : Std.Usize} + (col vec : Std.Array (Poly256) K) (ℓ : Nat) : + mcol_lane_at_step col vec 0 ℓ = ({ val := 0#u16 } : FieldElement) := by + unfold mcol_lane_at_step + rw [List.range_zero, List.foldl_nil] + +private theorem mcol_lane_at_step_succ {K : Std.Usize} + (col vec : Std.Array (Poly256) K) (k : Nat) (ℓ : Nat) : + mcol_lane_at_step col vec (k + 1) ℓ + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (mcol_lane_at_step col vec k ℓ) + ((Spec.multiply_ntts_pure (col.val[k]!) (vec.val[k]!)).val[ℓ]!) := by + unfold mcol_lane_at_step + rw [List.range_succ, List.foldl_append, List.foldl_cons, List.foldl_nil] + +/-- The accumulator-advance fact shared by both loop bodies: one column step + `add_polynomials acc (multiply_ntts col[k] vec[k])` evaluates to + `.ok (mcol_result_at_step col vec (k+1))` when `acc = mcol_result_at_step col vec k`. -/ +private theorem mcol_step_add_eq {K : Std.Usize} + (col vec : Std.Array (Poly256) K) (k : Nat) : + hacspec_ml_kem.matrix.add_polynomials (mcol_result_at_step col vec k) + (Spec.multiply_ntts_pure (col.val[k]!) (vec.val[k]!)) + = .ok (mcol_result_at_step col vec (k + 1)) := by + rw [Stage4MatrixAddFC.matrix_add_polynomials_eq_ok] + apply congrArg Result.ok + apply Subtype.ext + show (List.range 256).map (fun n => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (mcol_result_at_step col vec k).val[n]! + (Spec.multiply_ntts_pure (col.val[k]!) (vec.val[k]!)).val[n]!) + = (mcol_result_at_step col vec (k + 1)).val + unfold mcol_result_at_step + apply List.map_congr_left + intro n hn_mem + have hn_lt : n < 256 := List.mem_range.mp hn_mem + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (mcol_result_at_step col vec k).val[n]! + (Spec.multiply_ntts_pure (col.val[k]!) (vec.val[k]!)).val[n]! + = mcol_lane_at_step col vec (k + 1) n + rw [mcol_result_at_step_val_lane _ _ _ _ hn_lt, mcol_lane_at_step_succ] + +/-- The shared step-body `multiply_ntts` reduction. -/ +private theorem mcol_mult_eq (a1 a2 : Poly256) : + hacspec_ml_kem.ntt.multiply_ntts a1 a2 = .ok (Spec.multiply_ntts_pure a1 a2) := by + unfold Spec.multiply_ntts_pure + rw [HelpersFC.multiply_ntts_eq_pure_array] + +set_option maxHeartbeats 16000000 in +set_option maxRecDepth 1000 in +/-- **Triple B.** `multiply_vectors col vec` reduces to `mcol_result_at_step col vec K`. + Direct specialization of the `multiply_vectors_eq` pattern, but on raw + FE-array-typed `col`/`vec` (no `lift_vec`). -/ +private theorem multiply_vectors_eq_mcol {K : Std.Usize} + (col vec : Std.Array (Poly256) K) : + hacspec_ml_kem.matrix.multiply_vectors col vec + = .ok (mcol_result_at_step col vec K.val) := by + unfold hacspec_ml_kem.matrix.multiply_vectors + unfold hacspec_ml_kem.parameters.FieldElement.new + simp only [bind_tc_ok] + have h_triple : ⦃ ⌜ True ⌝ ⦄ + hacspec_ml_kem.matrix.multiply_vectors_loop + ({ start := 0#usize, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + col vec + (Std.Array.repeat (256#usize : Std.Usize) ({ val := 0#u16 } : FieldElement)) + ⦃ ⇓ r => ⌜ r = mcol_result_at_step col vec K.val ⌝ ⦄ := by + unfold hacspec_ml_kem.matrix.multiply_vectors_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun p : CoreModels.core.ops.range.Range Std.Usize × + Poly256 => + hacspec_ml_kem.matrix.multiply_vectors_loop.body col vec p.1 p.2) + (β := Poly256) + (Std.Array.repeat (256#usize : Std.Usize) ({ val := 0#u16 } : FieldElement)) + 0#usize K + (fun k result => pure (result = mcol_result_at_step col vec k.val)) + (Nat.zero_le _) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + apply Subtype.ext + rw [Std.Array.repeat_val] + unfold mcol_result_at_step + show List.replicate 256 _ = (List.range 256).map _ + apply List.ext_getElem + · rw [List.length_replicate, List.length_map, List.length_range] + intro n h_n_lhs _ + have h_n_lt : n < 256 := by + rw [List.length_replicate] at h_n_lhs; exact h_n_lhs + rw [List.getElem_replicate, List.getElem_map, List.getElem_range] + show _ = mcol_lane_at_step col vec 0 n + rw [mcol_lane_at_step_zero]) + ?_) + · rw [PostCond.entails_noThrow] + intro r hh + have h_eq : (pure (r = mcol_result_at_step col vec K.val) : Result Prop).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_eq + · intro acc k _h_ge h_le hinv + have h_acc_eq : acc = mcol_result_at_step col vec k.val := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using hinv + subst h_acc_eq + unfold hacspec_ml_kem.matrix.multiply_vectors_loop.body + by_cases h_lt : k.val < K.val + · have h_iter_step : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ ∃ s : Std.Usize, s.val = k.val + 1 ∧ + r = (some k, ({ start := s, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc' h_iter_step + obtain ⟨s_iter, hs_iter_val, hv_iter_pair⟩ := hv_iter_post + have hlen_col : col.length = K.val := Std.Array.length_eq col + have hlen_vec : vec.length = K.val := Std.Array.length_eq vec + have h_idx_a1 : Aeneas.Std.Array.index_usize col k = .ok (col.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq col k + (by rw [hlen_col]; exact h_lt) + have h_idx_a2 : Aeneas.Std.Array.index_usize vec k = .ok (vec.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq vec k + (by rw [hlen_vec]; exact h_lt) + have h_body : + (fun p : CoreModels.core.ops.range.Range Std.Usize × + Poly256 => + hacspec_ml_kem.matrix.multiply_vectors_loop.body col vec p.1 p.2) + ({ start := k, «end» := K }, mcol_result_at_step col vec k.val) + = .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + mcol_result_at_step col vec (k.val + 1))) := by + show hacspec_ml_kem.matrix.multiply_vectors_loop.body col vec + { start := k, «end» := K } (mcol_result_at_step col vec k.val) = _ + unfold hacspec_ml_kem.matrix.multiply_vectors_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_pair] at hv_iter_eq + rw [hv_iter_eq] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let a ← Aeneas.Std.Array.index_usize col k + let a1' ← Aeneas.Std.Array.index_usize vec k + let product ← hacspec_ml_kem.ntt.multiply_ntts a a1' + let result1 ← hacspec_ml_kem.matrix.add_polynomials + (mcol_result_at_step col vec k.val) product + Aeneas.Std.Result.ok (ControlFlow.cont + (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), result1))) + : Result _) = _ + rw [h_idx_a1] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_a2] + simp only [Aeneas.Std.bind_tc_ok] + rw [mcol_mult_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [mcol_step_add_eq] + simp only [Aeneas.Std.bind_tc_ok] + apply triple_of_ok_fc' h_body + refine ⟨h_lt, rfl, hs_iter_val, ?_⟩ + show (pure (mcol_result_at_step col vec (k.val + 1) + = mcol_result_at_step col vec s_iter.val) : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + rw [hs_iter_val] + rfl + · have hk_ge : k.val ≥ K.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = K.val := by omega + have h_iter_none : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ r = ((none : Option Std.Usize), + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun hlt => absurd hlt (Nat.not_lt.mpr hk_ge)) + (fun _ => by dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc' h_iter_none + have h_body : + (fun p : CoreModels.core.ops.range.Range Std.Usize × + Poly256 => + hacspec_ml_kem.matrix.multiply_vectors_loop.body col vec p.1 p.2) + ({ start := k, «end» := K }, mcol_result_at_step col vec k.val) + = .ok (ControlFlow.done (mcol_result_at_step col vec k.val)) := by + show hacspec_ml_kem.matrix.multiply_vectors_loop.body col vec + { start := k, «end» := K } (mcol_result_at_step col vec k.val) = _ + unfold hacspec_ml_kem.matrix.multiply_vectors_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_post] at hv_iter_eq + rw [hv_iter_eq] + rfl + apply triple_of_ok_fc' h_body + show (pure (mcol_result_at_step col vec k.val + = mcol_result_at_step col vec K.val) : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + rw [hk_eq] + rfl + obtain ⟨v, hv_eq, hv_post⟩ := triple_exists_ok_fc' h_triple + rw [hv_eq, hv_post] + +set_option maxHeartbeats 16000000 in +set_option maxRecDepth 1000 in +/-- **Triple A (matrix).** `multiply_matrix_by_column_at m vec i` reduces to + `mcol_result_at_step (extractCol m i) vec K`. Same invariant/structure as + Triple B; the only difference is the loop body's extra `index_usize m j`/ + `index_usize a i` steps, aligned via `extractCol_lane`. -/ +private theorem multiply_matrix_by_column_at_eq_mcol {K : Std.Usize} + (m : Std.Array (Std.Array (Poly256) K) K) + (vec : Std.Array (Poly256) K) (i : Std.Usize) + (hi : i.val < K.val) : + hacspec_ml_kem.matrix.multiply_matrix_by_column_at m vec i + = .ok (mcol_result_at_step (extractCol m i) vec K.val) := by + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at + unfold hacspec_ml_kem.parameters.FieldElement.new + simp only [bind_tc_ok] + have h_triple : ⦃ ⌜ True ⌝ ⦄ + hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop + ({ start := 0#usize, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + m vec i + (Std.Array.repeat (256#usize : Std.Usize) ({ val := 0#u16 } : FieldElement)) + ⦃ ⇓ r => ⌜ r = mcol_result_at_step (extractCol m i) vec K.val ⌝ ⦄ := by + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun p : CoreModels.core.ops.range.Range Std.Usize × + Poly256 => + hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body m vec i p.1 p.2) + (β := Poly256) + (Std.Array.repeat (256#usize : Std.Usize) ({ val := 0#u16 } : FieldElement)) + 0#usize K + (fun k result => pure (result = mcol_result_at_step (extractCol m i) vec k.val)) + (Nat.zero_le _) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + apply Subtype.ext + rw [Std.Array.repeat_val] + unfold mcol_result_at_step + show List.replicate 256 _ = (List.range 256).map _ + apply List.ext_getElem + · rw [List.length_replicate, List.length_map, List.length_range] + intro n h_n_lhs _ + have h_n_lt : n < 256 := by + rw [List.length_replicate] at h_n_lhs; exact h_n_lhs + rw [List.getElem_replicate, List.getElem_map, List.getElem_range] + show _ = mcol_lane_at_step (extractCol m i) vec 0 n + rw [mcol_lane_at_step_zero]) + ?_) + · rw [PostCond.entails_noThrow] + intro r hh + have h_eq : (pure (r = mcol_result_at_step (extractCol m i) vec K.val) + : Result Prop).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_eq + · intro acc k _h_ge h_le hinv + have h_acc_eq : acc = mcol_result_at_step (extractCol m i) vec k.val := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using hinv + subst h_acc_eq + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body + by_cases h_lt : k.val < K.val + · have h_iter_step : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ ∃ s : Std.Usize, s.val = k.val + 1 ∧ + r = (some k, ({ start := s, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc' h_iter_step + obtain ⟨s_iter, hs_iter_val, hv_iter_pair⟩ := hv_iter_post + -- index_usize m k = m[k]!; index_usize (m[k]!) i = m[k]![i]! = (extractCol m i)[k]! + have hlen_m : m.length = K.val := Std.Array.length_eq m + have hlen_vec : vec.length = K.val := Std.Array.length_eq vec + have hlen_mk : (m.val[k.val]!).length = K.val := Std.Array.length_eq (m.val[k.val]!) + have h_idx_mk : Aeneas.Std.Array.index_usize m k = .ok (m.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq m k + (by rw [hlen_m]; exact h_lt) + have h_col_eq : (extractCol m i).val[k.val]! = (m.val[k.val]!).val[i.val]! := + extractCol_lane m i k.val h_lt + have h_idx_a1 : + Aeneas.Std.Array.index_usize (m.val[k.val]!) i = .ok ((extractCol m i).val[k.val]!) := by + rw [h_col_eq] + exact libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq (m.val[k.val]!) i + (by rw [hlen_mk]; exact hi) + have h_idx_a2 : Aeneas.Std.Array.index_usize vec k = .ok (vec.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq vec k + (by rw [hlen_vec]; exact h_lt) + have h_body : + (fun p : CoreModels.core.ops.range.Range Std.Usize × + Poly256 => + hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body m vec i p.1 p.2) + ({ start := k, «end» := K }, mcol_result_at_step (extractCol m i) vec k.val) + = .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + mcol_result_at_step (extractCol m i) vec (k.val + 1))) := by + show hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body m vec i + { start := k, «end» := K } + (mcol_result_at_step (extractCol m i) vec k.val) = _ + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_pair] at hv_iter_eq + rw [hv_iter_eq] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let a ← Aeneas.Std.Array.index_usize m k + let a1 ← Aeneas.Std.Array.index_usize a i + let a2 ← Aeneas.Std.Array.index_usize vec k + let product ← hacspec_ml_kem.ntt.multiply_ntts a1 a2 + let result1 ← hacspec_ml_kem.matrix.add_polynomials + (mcol_result_at_step (extractCol m i) vec k.val) product + Aeneas.Std.Result.ok (ControlFlow.cont + (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), result1))) + : Result _) = _ + rw [h_idx_mk] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_a1] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_a2] + simp only [Aeneas.Std.bind_tc_ok] + rw [mcol_mult_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [mcol_step_add_eq] + simp only [Aeneas.Std.bind_tc_ok] + apply triple_of_ok_fc' h_body + refine ⟨h_lt, rfl, hs_iter_val, ?_⟩ + show (pure (mcol_result_at_step (extractCol m i) vec (k.val + 1) + = mcol_result_at_step (extractCol m i) vec s_iter.val) + : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + rw [hs_iter_val] + rfl + · have hk_ge : k.val ≥ K.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = K.val := by omega + have h_iter_none : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ r = ((none : Option Std.Usize), + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun hlt => absurd hlt (Nat.not_lt.mpr hk_ge)) + (fun _ => by dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc' h_iter_none + have h_body : + (fun p : CoreModels.core.ops.range.Range Std.Usize × + Poly256 => + hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body m vec i p.1 p.2) + ({ start := k, «end» := K }, mcol_result_at_step (extractCol m i) vec k.val) + = .ok (ControlFlow.done (mcol_result_at_step (extractCol m i) vec k.val)) := by + show hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body m vec i + { start := k, «end» := K } + (mcol_result_at_step (extractCol m i) vec k.val) = _ + unfold hacspec_ml_kem.matrix.multiply_matrix_by_column_at_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_post] at hv_iter_eq + rw [hv_iter_eq] + rfl + apply triple_of_ok_fc' h_body + show (pure (mcol_result_at_step (extractCol m i) vec k.val + = mcol_result_at_step (extractCol m i) vec K.val) + : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + rw [hk_eq] + rfl + obtain ⟨v, hv_eq, hv_post⟩ := triple_exists_ok_fc' h_triple + rw [hv_eq, hv_post] + +/-- **L7.2-P0.2.** `multiply_matrix_by_column_at m vec i` equals + `multiply_vectors (extractCol m i) vec`: the matrix-column product is the + vector product over the extracted column. Both reduce to the same + `mcol_result_at_step` fold. -/ +theorem multiply_matrix_by_column_at_eq_multiply_vectors {K : Std.Usize} + (m : Std.Array (Std.Array (Poly256) K) K) + (vec : Std.Array (Poly256) K) + (i : Std.Usize) (hi : i.val < K.val) : + hacspec_ml_kem.matrix.multiply_matrix_by_column_at m vec i + = hacspec_ml_kem.matrix.multiply_vectors (extractCol m i) vec := by + rw [multiply_matrix_by_column_at_eq_mcol m vec i hi, multiply_vectors_eq_mcol] + +end MatrixColEqVectors + +end libcrux_iot_ml_kem.Matrix.ComputeVectorU.Hacspec \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeVectorU/Impl.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeVectorU/Impl.lean new file mode 100644 index 00000000..88feceac --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Matrix/ComputeVectorU/Impl.lean @@ -0,0 +1,2580 @@ +/- + # `Matrix/ComputeVectorU/Impl.lean` — L7.2 row-0 fill-loop FC. + + Loop FC for `matrix.compute_vector_u_loop0`: the row-0 + column loop of `matrix.compute_vector_u`. Iterates over `j ∈ [0, K)`; each + step SAMPLES the matrix entry `matrix_entry1 = sample_matrix_entry seed 0 j` + (rather than reading a stored slice, as L7.1 does), reads `r_as_ntt[j]`, + `cache.index_mut j`, runs `accumulating_ntt_multiply_fill_cache` to add + column j's contribution to the I32[256] accumulator AND populate + `cache[j]`, then stores the new cache chunk. + + Mirrors the L7.1 sibling `compute_As_plus_e_loop0_fc` + + `compute_As_plus_e_loop0_step_lemma_fc` + `Stage1FillCacheFC.row0_inv`. + Structural deltas vs L7.1: + + * `r_as_ntt`, `cache` are `Slice` (not `Array K`); we use + `Slice.index_usize` / `Slice.index_mut_usize` / `Slice.set`. + * the matrix entry per column `c` comes from `sample_matrix_entry`, whose + only stable characterization is the AXIOM `sample_matrix_entry_fc`: `lift_poly (sample seed 0 c) = + (lift_matrix_from_seed seed K).val[0]!.val[c]!` (ROW-major). Since the + sampled polys are NOT retained by the impl, the acc invariant cannot use + `lift_chunk_mont (matrix_entry.coefs[j])` of a stored poly (as L7.1 does + with the input slice `matrix_A`); instead conjunct (1) characterizes the + matrix factor by the canonical `Spec.chunk_at` of the axiom-pinned + `(lift_matrix_from_seed seed K).val[0]!.val[c]!`. The step lemma bridges + the impl-side mont-domain `accumulating_ntt_multiply_poly_post` (which + uses `lift_chunk_mont (matrix_entry1.coefs[j])`) to this canonical + matrix factor via the axiom + the `chunk_at_lift_poly` identity, leaving + the `r` factor in `lift_chunk_mont` exactly as L6.3c emits it. + * the loop additionally threads `matrix_entry` (the last-sampled poly). + * `cache` characterization conjunct (3) mirrors L7.1 verbatim — it depends + only on `r_as_ntt[c]`, not on the matrix entry. +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply +import LibcruxIotMlKem.Matrix.Common +import LibcruxIotMlKem.Matrix.ComputeAsPlusE +import LibcruxIotMlKem.Sampling +import LibcruxIotMlKem.Serialize +import LibcruxIotMlKem.Matrix.ComputeMessage.Impl +import LibcruxIotMlKem.Matrix.ComputeMessage.Hacspec +import LibcruxIotMlKem.Matrix.ComputeVectorU.Hacspec + +namespace libcrux_iot_ml_kem.Matrix.ComputeVectorU.Impl +open libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeMessage.Bridges libcrux_iot_ml_kem.Matrix.ComputeMessage.Hacspec libcrux_iot_ml_kem.Matrix.ComputeMessage.Impl libcrux_iot_ml_kem.Matrix.ComputeVectorU.Hacspec +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeAsPlusE libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Sampling libcrux_iot_ml_kem.Serialize libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt + +set_option mvcgen.warning false +set_option linter.unusedVariables false + +/-- Local copy of FCTargets' `private triple_exists_ok_fc`. -/ +private theorem triple_exists_ok_fc {α : Type} {x : Result α} {P : α → Prop} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) : + ∃ v, x = .ok v ∧ P v := by + match hx : x with + | .ok v => exact ⟨v, rfl, (by subst hx; simpa [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h)⟩ + | .fail _ => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +/-- Local copy of FCTargets' `private triple_of_ok_fc`. -/ +private theorem triple_of_ok_fc {α : Type} {x : Result α} {v : α} + {P : α → Prop} (hx : x = .ok v) (hp : P v) : + ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄ := by + subst hx; simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +/-- Local re-derivation of FCTargets' `private chunk_at_lift_poly_fc`: `Spec.chunk_at (lift_poly re) k = lift_chunk re.coefs[k]`. -/ +private theorem chunk_at_lift_poly_local + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Nat) (hk : k < 16) : + Spec.chunk_at (lift_poly re) k = lift_chunk (re.coefficients.val[k]!) := by + unfold Spec.chunk_at lift_poly lift_chunk + apply Subtype.ext + have h_chunk_len : (re.coefficients.val[k]!).elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + show (List.range 16).map + (fun j => ((List.range 256).map + (fun j' => lift_fe (re.coefficients.val[j' / 16]!).elements.val[j' % 16]!))[16 * k + j]!) + = (re.coefficients.val[k]!).elements.val.map lift_fe + apply List.ext_getElem + · simp + · intro i hi1 _hi2 + have hi : i < 16 := by + have : i < ((List.range 16).map _).length := hi1 + simpa using this + have h_idx_lt : 16 * k + i < 256 := by + have hk' : k ≤ 15 := by omega + have : 16 * k ≤ 16 * 15 := Nat.mul_le_mul_left _ hk' + omega + have h_list_len : ((List.range 256).map (fun j => + lift_fe ((re.coefficients.val[j / 16]!).elements.val[j % 16]!))).length = 256 := by + simp + rw [List.getElem_map, List.getElem_map, List.getElem_range] + rw [getElem!_pos _ (16 * k + i) (by rw [h_list_len]; exact h_idx_lt)] + rw [List.getElem_map, List.getElem_range] + have h_div : (16 * k + i) / 16 = k := by omega + have h_mod : (16 * k + i) % 16 = i := by omega + rw [h_div, h_mod] + congr 1 + rw [getElem!_pos _ i (by rw [h_chunk_len]; exact hi)] + +/-! ## §L7.2-loop0 — row-0 column-loop scaffolding (namespace `Row0FillFC`). + + Mirrors `Stage1FillCacheFC`. The matrix factor in conjunct (1) + is the canonical `Spec.chunk_at (lm.val[c]!) j` of the axiom-pinned + row-0 matrix row `lm = (lift_matrix_from_seed seed K).val[0]!`, NOT a + `lift_chunk_mont` of a stored poly. -/ + +namespace Row0FillFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +abbrev Acc := UseCacheFC.Acc +abbrev Poly := UseCacheFC.Poly + +/-- 5-conjunct invariant for the row-0 column loop of `compute_vector_u`, + in the RESOLVED all-mont/existential form. + + `lm0` is the row-0 matrix row `(lift_matrix_from_seed seed K).val[0]!` + (a `K`-array of `FieldElement 256` polys). Because the impl SAMPLES the + matrix entry each column and then DISCARDS the sampled poly (only the + buffer + r-side cache survive), the accumulator characterization cannot + reference the canonical `Spec.chunk_at (lm0[c]) j` (that is off by a + factor `2285`: `chunk_at (lift_poly p) j = 2285 · lift_chunk_mont p[j]`). + Instead we existentially quantify over the ACTUAL sampled polys + `mp : Array Poly K`, tie them to the canonical matrix row via the axiom + (`lift_poly (mp[c]) = lm0[c]`), and characterize the accumulator in the + SAME all-mont form L7.1 uses (`lift_chunk_mont (mp[c].coefs[j])`). + + Tracks: + (∃ mp) for `c < k`: `lift_poly (mp[c]) = lm0[c]` ∧ per-lane bound 3328. + (1) accumulator: for each (chunk j, lane ℓ), `mont_reduce_pure (lift_fe_int + acc[16j+ℓ].val)` equals init plus the all-mont sum of column + contributions `ntt_multiply_pure_no_acc (lift_chunk_mont mp[c].coefs[j]) + (lift_chunk_mont r[c].coefs[j]) zetas` from columns `[0, k)`. + (2) accumulator bound: `|acc[n]| ≤ |acc_init[n]| + k · 2^25`. + (3) cache populated for `c < k`: `accumulating_ntt_multiply_poly_cache_post + r_arr[c] cache[c]`. + (4) cache unchanged for `c ∈ [k, K)`: `cache[c] = cache_init[c]`. -/ +def row0_inv {K : Std.Usize} + (lm0 : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K) + (r_arr : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (cache_init : Slice + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (acc_init : Acc) : + Std.Usize → Acc → + Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) → + Result Prop := + fun k acc cache => pure ( + (∃ mp : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K, + (∀ c : Nat, c < k.val → + lift_poly (mp.val[c]!) = lm0.val[c]! + ∧ (∀ a : Fin 16, ∀ b : Fin 16, + ((mp.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328)) + ∧ (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = (List.range k.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)))) + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + k.val * 2^25) + ∧ (∀ c : Nat, c < k.val → + accumulating_ntt_multiply_poly_cache_post + (r_arr.val[c]!) (cache.val[c]!)) + ∧ (∀ c : Nat, k.val ≤ c → c < K.val → + cache.val[c]! = cache_init.val[c]!)) + +/-- Step-post for `loop_range_spec_usize` over `(matrix_entry, cache, acc)`. -/ +def row0_step_post {K : Std.Usize} + (lm0 : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K) + (r_arr : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (cache_init : Slice + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (acc_init : Acc) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) × + (Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) × + Acc) + ((libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) × + (Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) × + Acc)) : + Prop := + match r with + | .cont (iter', _matrix_entry', cache', acc') => + k.val < K.val ∧ iter'.«end» = K + ∧ iter'.start.val = k.val + 1 + ∧ (row0_inv lm0 r_arr cache_init acc_init iter'.start acc' cache').holds + ∧ cache'.length = K.val + | .done y => (row0_inv lm0 r_arr cache_init acc_init K y.2.2 y.2.1).holds + ∧ y.2.1.length = K.val + +end Row0FillFC + +-- Memory hygiene (rule 1 / SKILL §5.7 Idiom 2). Mirrors `L7_1a_irreducible` +-- — heavy POST predicates and the per-column forward dep are +-- made locally irreducible across the step lemma + outer Triple so that +-- elaboration does not whnf-explode through the 4-conjunct `row0_inv` body or +-- the nested accumulator characterization. -- we do NOT mark +-- `Row0FillFC.row0_inv` / `row0_step_post` irreducible — keeping them reducible +-- preserves the `simpa`-based destructure of `h_inv`. +section L7_2a_irreducible +attribute [local irreducible] accumulating_ntt_multiply_poly_cache_post +attribute [local irreducible] accumulating_ntt_multiply_poly_post +attribute [local irreducible] Spec.ntt_multiply_pure_no_acc +attribute [local irreducible] Spec.mont_reduce_pure + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for the row-0 SAMPLED column loop of + `compute_vector_u`. Given the `row0_inv` invariant at step k and the + strengthened PRE bounds, executing one body iteration of + `matrix.compute_vector_u_loop0.body` produces the `row0_step_post` + (either `.cont` advancing the invariant to k+1 or `.done` capping at K). + + Mirrors `compute_As_plus_e_loop0_step_lemma_fc` with two + structural deltas: + 1. The matrix entry is SAMPLED via `sample_matrix_entry` (whose only stable + characterization is the axiom `sample_matrix_entry_fc`), not read from a + stored slice. We `triple_exists_ok_fc` the axiom at `(i, j) = (0, k)`, + obtaining `me1` with `lift_poly me1 = lm0[k]` and per-lane bounds, and + store `me1` in the invariant's existential witness `mp.set k me1`. Hence + the column-k accumulator term matches the all-mont form VERBATIM (no + `chunk_at` bridge needed). + 2. `r_as_ntt`, `cache` are `Slice` — use `Slice.index_usize` / + `Slice.index_mut_usize` / `Slice.set`. The carried `r_arr : Array Poly K` + is bridged to `r_as_ntt[k]` via `h_r_arr`. -/ +private theorem compute_vector_u_loop0_step_lemma_fc + {K : Std.Usize} {Hasher : Type} + (hash_functionsHashInst : libcrux_iot_ml_kem.hash_functions.Hash Hasher) + (matrix_entry0 : Row0FillFC.Poly) + (seed : Slice Std.U8) + (r_as_ntt cache_init : Slice + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (r_arr : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : Row0FillFC.Acc) + (h_seed_len : seed.length = 32) + (h_r_len : r_as_ntt.length = K.val) + (h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]!) + (h_r_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((r_as_ntt.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, + (acc_init.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30) + (matrix_entry : Row0FillFC.Poly) + (acc : Row0FillFC.Acc) + (cache : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (k : Std.Usize) (h_le : k.val ≤ K.val) + (h_cache_len : cache.length = K.val) + (h_inv : (Row0FillFC.row0_inv (lift_matrix_from_seed seed K).val[0]! r_arr cache_init acc_init + k acc cache).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_vector_u_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst seed r_as_ntt + { start := k, «end» := K } matrix_entry cache acc + ⦃ ⇓ r => ⌜ Row0FillFC.row0_step_post (lift_matrix_from_seed seed K).val[0]! r_arr cache_init + acc_init k r ⌝ ⦄ := by + set lm0 : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K := + (lift_matrix_from_seed seed K).val[0]! with hlm0_def + have h_acc_len : acc.length = 256 := Std.Array.length_eq acc + have h_acc_init_len : acc_init.length = 256 := Std.Array.length_eq acc_init + -- Destructure the 4-conjunct invariant (the first is the ∃-witness pack). + obtain ⟨⟨mp, h_mp_agree, h_inv_acc⟩, h_inv_acc_bnd, h_inv_cache_done, h_inv_cache_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.matrix.compute_vector_u_loop0.body + by_cases h_lt : k.val < K.val + · -- `Some k` branch. + have hK_pos : 0 < K.val := Nat.lt_of_le_of_lt (Nat.zero_le _) h_lt + -- (1) IteratorRange.next reduces to .ok (some k, { start := s_iter, end := K }). + have h_iter_step : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ ∃ s : Std.Usize, s.val = k.val + 1 ∧ + r = (some k, + ({ start := s, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_step + obtain ⟨s_iter, hs_iter_val, hv_iter_pair⟩ := hv_iter_post + -- (2) Sample the matrix entry at (i, j) = (0, k) via the axiom. + have h_0K : (0#usize : Std.Usize).val < K.val := by + show (0 : Nat) < K.val; omega + obtain ⟨me1, h_me_eq, h_me_lift, h_me_bnd⟩ := + triple_exists_ok_fc + (sample_matrix_entry_fc hash_functionsHashInst matrix_entry seed 0#usize k K + h_seed_len h_0K h_lt) + -- h_me_lift : lift_poly me1 = (lift_matrix_from_seed seed K).val[0].val[k] + have h_me_lift' : lift_poly me1 = lm0.val[k.val]! := by + rw [hlm0_def]; exact h_me_lift + -- (3) Slice.index_usize r_as_ntt k reduces to .ok r_as_ntt[k.val]!. + set t_r : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + r_as_ntt.val[k.val]! with ht_r_def + have h_idx_r : Aeneas.Std.Slice.index_usize r_as_ntt k = .ok t_r := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq r_as_ntt k + (by show k.val < r_as_ntt.length; rw [h_r_len]; exact h_lt) + -- (4) Slice.index_mut_usize cache k splits into (cache[k]!, cache.set k). + set t_cache : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + cache.val[k.val]! with ht_cache_def + have h_idx_cache : Aeneas.Std.Slice.index_usize cache k = .ok t_cache := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq cache k + (by show k.val < cache.length; rw [h_cache_len]; exact h_lt) + have h_imt_cache : Aeneas.Std.Slice.index_mut_usize cache k + = .ok (t_cache, cache.set k) := by + unfold Aeneas.Std.Slice.index_mut_usize + rw [h_idx_cache]; rfl + -- (5) Apply L6.3c per-column forward dep at column k. + have h_me_bnd' : ∀ a : Fin 16, ∀ b : Fin 16, + ((me1.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328 := + fun a b => h_me_bnd a.val a.isLt b.val b.isLt + have h_t_r_bnd : ∀ a : Fin 16, ∀ b : Fin 16, + ((t_r.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328 := + fun a b => h_r_bnd k.val h_lt a b + -- Current acc bound ≤ 2^30: combine inv conjunct (2) with budget PRE. + have h_acc_cur_bnd : ∀ n : Fin 256, (acc.val[n.val]!).val.natAbs ≤ 2^30 := by + intro n + have hb := h_inv_acc_bnd n.val n.isLt + have hp := h_acc_bnd n + have hk_le : k.val * 2^25 ≤ K.val * 2^25 := Nat.mul_le_mul_right _ h_le + omega + obtain ⟨p_pair, h_p_eq, h_p_bnd_rel, h_p_acc_post, h_p_cache_post⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_fill_cache_poly_fc me1 t_r t_cache acc + h_me_bnd' h_t_r_bnd h_acc_cur_bnd) + set acc1 : Row0FillFC.Acc := p_pair.1 with hacc1_def + set cache_chunk1 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + p_pair.2 with hcc1_def + -- (6) cache1 := cache.set k cache_chunk1. + set cache1 : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) := + cache.set k cache_chunk1 with hcache1_def + have h_cache1_at : cache1.val[k.val]! = cache_chunk1 := by + rw [hcache1_def] + simpa [Aeneas.Std.Slice.getElem!_Nat_eq] using + Aeneas.Std.Slice.getElem!_Nat_set_eq cache k k.val cache_chunk1 + ⟨rfl, by show k.val < cache.length; rw [h_cache_len]; exact h_lt⟩ + have h_cache1_ne : ∀ j : Nat, j ≠ k.val → + cache1.val[j]! = cache.val[j]! := by + intro j hj + rw [hcache1_def] + simpa [Aeneas.Std.Slice.getElem!_Nat_eq] using + Aeneas.Std.Slice.getElem!_Nat_set_ne cache k j cache_chunk1 (fun h => hj h.symm) + -- (7) Body equation. + have h_body : + libcrux_iot_ml_kem.matrix.compute_vector_u_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst seed r_as_ntt + { start := k, «end» := K } matrix_entry cache acc + = .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), me1, cache1, acc1)) := by + unfold libcrux_iot_ml_kem.matrix.compute_vector_u_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_pair] at hv_iter_eq + rw [hv_iter_eq] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let matrix_entry1 ← + libcrux_iot_ml_kem.matrix.sample_matrix_entry portable_ops_inst + hash_functionsHashInst matrix_entry seed 0#usize k + let pre ← Aeneas.Std.Slice.index_usize r_as_ntt k + let (pre1, index_mut_back) ← Aeneas.Std.Slice.index_mut_usize cache k + let (accumulator1, pre2) ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache + portable_ops_inst matrix_entry1 pre acc pre1 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + matrix_entry1, index_mut_back pre2, accumulator1))) + : Result _) = _ + rw [h_me_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_r] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_cache] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let (accumulator1, pre2) ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache + portable_ops_inst me1 t_r acc t_cache + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + me1, (cache.set k) pre2, accumulator1))) + : Result _) = _ + rw [h_p_eq] + simp only [Aeneas.Std.bind_tc_ok] + rfl + apply triple_of_ok_fc h_body + -- (8) Discharge the step_post. + show Row0FillFC.row0_step_post lm0 r_arr cache_init acc_init k + (.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), me1, cache1, acc1)) + refine ⟨h_lt, rfl, hs_iter_val, ?_, + by rw [hcache1_def, Aeneas.Std.Slice.set_length]; exact h_cache_len⟩ + -- (9) Re-establish `row0_inv` at s_iter (= k+1). + show (Row0FillFC.row0_inv lm0 r_arr cache_init acc_init s_iter acc1 cache1).holds + unfold Row0FillFC.row0_inv + -- New existential witness: mp.set k me1. + set mp1 : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K := + mp.set k me1 with hmp1_def + have h_mp_len : mp.length = K.val := Std.Array.length_eq mp + have h_mp1_at : mp1.val[k.val]! = me1 := by + rw [hmp1_def] + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq mp k k.val me1 + ⟨rfl, by rw [h_mp_len]; exact h_lt⟩ + have h_mp1_ne : ∀ j : Nat, j ≠ k.val → mp1.val[j]! = mp.val[j]! := by + intro j hj + rw [hmp1_def] + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne mp k j me1 (fun h => hj h.symm) + -- column-k r factor: r_arr[k] = r_as_ntt[k] = t_r. + have h_r_arr_k : r_arr.val[k.val]! = t_r := by + rw [ht_r_def]; exact h_r_arr k.val h_lt + have h_inv_pure : + (∃ mp' : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K, + (∀ c : Nat, c < s_iter.val → + lift_poly (mp'.val[c]!) = lm0.val[c]! + ∧ (∀ a : Fin 16, ∀ b : Fin 16, + ((mp'.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328)) + ∧ (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = (List.range s_iter.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp'.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)))) + ∧ (∀ n : Nat, n < 256 → + (acc1.val[n]!).val.natAbs + ≤ (acc_init.val[n]!).val.natAbs + s_iter.val * 2^25) + ∧ (∀ c : Nat, c < s_iter.val → + accumulating_ntt_multiply_poly_cache_post + (r_arr.val[c]!) (cache1.val[c]!)) + ∧ (∀ c : Nat, s_iter.val ≤ c → c < K.val → + cache1.val[c]! = cache_init.val[c]!) := by + refine ⟨⟨mp1, ?_, ?_⟩, ?_, ?_, ?_⟩ + · -- agreement at columns [0, s_iter). + intro c hc + rw [hs_iter_val] at hc + rcases Nat.lt_succ_iff_lt_or_eq.mp hc with hc_lt | hc_eq + · -- c < k: mp1[c] = mp[c], use h_mp_agree. + have hc_ne : c ≠ k.val := by omega + rw [h_mp1_ne c hc_ne] + exact h_mp_agree c hc_lt + · -- c = k: mp1[k] = me1, use h_me_lift' + h_me_bnd. + subst hc_eq + rw [h_mp1_at] + exact ⟨h_me_lift', fun a b => h_me_bnd a.val a.isLt b.val b.isLt⟩ + · -- (a) Accumulator characterization at s_iter = k+1. + intro j hj ℓ hℓ + have h_step_acc : + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (me1.coefficients.val[j]!)) + (lift_chunk_mont (t_r.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!) := by + have := h_p_acc_post + unfold accumulating_ntt_multiply_poly_post at this + exact this j hj ℓ hℓ + have h_ih := h_inv_acc j hj ℓ hℓ + rw [h_step_acc, h_ih] + have hs_iter_eq : s_iter.val = k.val + 1 := hs_iter_val + rw [hs_iter_eq] + rw [List.range_succ, List.foldl_append] + -- Generic foldl congruence: over a list whose elements are all < k, the + -- step functions using `mp` and `mp1` agree (since `mp1[c] = mp[c]` for c < k). + have h_foldl_congr : ∀ (L : List Nat) (init : hacspec_ml_kem.parameters.FieldElement), + (∀ c ∈ L, c < k.val) → + L.foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp1.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + init + = L.foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + init := by + intro L + induction L with + | nil => intro init _; rfl + | cons hd tl ih => + intro init hmem + have hhd : hd < k.val := hmem hd (List.mem_cons_self) + have htl : ∀ c ∈ tl, c < k.val := fun c hc => hmem c (List.mem_cons_of_mem hd hc) + have hhd_ne : hd ≠ k.val := by omega + simp only [List.foldl_cons] + rw [h_mp1_ne hd hhd_ne] + exact ih _ htl + rw [h_foldl_congr (List.range k.val) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)) + (fun c hc => by simpa using hc)] + -- Now: column-k term me1/t_r matches mp1[k]/r_arr[k]. + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((List.range k.val).foldl _ _) _ + = (List.foldl _ ((List.range k.val).foldl _ _) [k.val]) + rw [List.foldl_cons, List.foldl_nil] + rw [h_mp1_at, h_r_arr_k] + · -- (b) Bound. + intro n hn + have h_p_bnd_n := h_p_bnd_rel ⟨n, hn⟩ + have h_p_bnd_n' : (acc1.val[n]!).val.natAbs ≤ (acc.val[n]!).val.natAbs + 2^25 := + h_p_bnd_n + have h_inv_n := h_inv_acc_bnd n hn + have hs_iter_eq : s_iter.val = k.val + 1 := hs_iter_val + rw [hs_iter_eq] + have h_arith : (k.val + 1) * 2^25 = k.val * 2^25 + 2^25 := by ring + rw [h_arith] + linarith [h_p_bnd_n', h_inv_n] + · -- (c) Cache populated for [0, s_iter). + intro c hc + rw [hs_iter_val] at hc + rcases Nat.lt_succ_iff_lt_or_eq.mp hc with hc_lt | hc_eq + · have hc_ne : c ≠ k.val := by omega + rw [h_cache1_ne c hc_ne] + exact h_inv_cache_done c hc_lt + · subst hc_eq + rw [h_cache1_at, h_r_arr_k] + exact h_p_cache_post + · -- (d) Cache unchanged for [s_iter, K). + intro c hc_ge hc_lt + rw [hs_iter_val] at hc_ge + have hc_ne : c ≠ k.val := by omega + rw [h_cache1_ne c hc_ne] + have hc_ge_k : k.val ≤ c := by omega + exact h_inv_cache_undone c hc_ge_k hc_lt + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ K, done. + have hk_ge : k.val ≥ K.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = K.val := by omega + have h_iter_none : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ r = ((none : Option Std.Usize), + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun hlt => absurd hlt (Nat.not_lt.mpr hk_ge)) + (fun _ => by dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_none + have h_body : + libcrux_iot_ml_kem.matrix.compute_vector_u_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst seed r_as_ntt + { start := k, «end» := K } matrix_entry cache acc + = .ok (ControlFlow.done (matrix_entry, cache, acc)) := by + unfold libcrux_iot_ml_kem.matrix.compute_vector_u_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_post] at hv_iter_eq + rw [hv_iter_eq] + rfl + apply triple_of_ok_fc h_body + show Row0FillFC.row0_step_post lm0 r_arr cache_init acc_init k + (.done (matrix_entry, cache, acc)) + show (Row0FillFC.row0_inv lm0 r_arr cache_init acc_init K acc cache).holds + ∧ cache.length = K.val + refine ⟨?_, h_cache_len⟩ + unfold Row0FillFC.row0_inv + show (pure _ : Result Prop).holds + have h_inv_pure : + (∃ mp' : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K, + (∀ c : Nat, c < K.val → + lift_poly (mp'.val[c]!) = lm0.val[c]! + ∧ (∀ a : Fin 16, ∀ b : Fin 16, + ((mp'.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328)) + ∧ (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = (List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp'.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)))) + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs + ≤ (acc_init.val[n]!).val.natAbs + K.val * 2^25) + ∧ (∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post + (r_arr.val[c]!) (cache.val[c]!)) + ∧ (∀ c : Nat, K.val ≤ c → c < K.val → + cache.val[c]! = cache_init.val[c]!) := by + refine ⟨⟨mp, ?_, ?_⟩, ?_, ?_, ?_⟩ + · intro c hc + exact h_mp_agree c (by rw [hk_eq]; exact hc) + · intro j hj ℓ hℓ + have h_eq := h_inv_acc j hj ℓ hℓ + have h_rng : (List.range k.val) = (List.range K.val) := by rw [hk_eq] + rw [h_rng] at h_eq + exact h_eq + · intro n hn + have h_b := h_inv_acc_bnd n hn + have h_arith : k.val * 2^25 = K.val * 2^25 := by rw [hk_eq] + rw [h_arith] at h_b + exact h_b + · intro c hc + exact h_inv_cache_done c (by rw [hk_eq]; exact hc) + · intro c hc_ge hc_lt; omega + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +/-- L7.2 Stage 1 — `matrix.compute_vector_u_loop0`: the row-0 SAMPLED column + loop of `compute_vector_u`. Iterates over `j ∈ [0, K)`; each step SAMPLES + the matrix entry via `sample_matrix_entry seed 0 j` (axiomatized), reads + `r_as_ntt[j]`, runs `accumulating_ntt_multiply_fill_cache` to add column j's + contribution to the I32 accumulator AND populate `cache[j]`. + + POST: the RESOLVED all-mont/existential `row0_inv` holds at k = K — i.e. + there exists a `K`-array `mp` of sampled polys with `lift_poly mp[c] = + (lift_matrix_from_seed seed K).val[0].val[c]` (axiom-pinned, ROW-major) such + that for all (j, ℓ) ∈ [0, 16)², `mont_reduce_pure (lift_fe_int acc[16j+ℓ])` + equals the K-column all-mont sum of `ntt_multiply_pure_no_acc` outputs over + `lift_chunk_mont mp[c].coefs[j]` × `lift_chunk_mont r_arr[c].coefs[j]`, + plus the per-column cache population. This is the form consumed + by the downstream `compute_vector_u` acc-bridge. + + PRE: `seed.length = 32` (axiom requirement), `r_as_ntt.length = K`, the + array/slice bridge `h_r_arr`, the standard 16×16 bound (3328) on `r_as_ntt`, + and the accumulator BUDGET `(acc[n]).val.natAbs + K·2^25 ≤ 2^30`. The matrix + bound is supplied internally by the sample axiom's POST. + + Mirrors `compute_As_plus_e_loop0_fc`. -/ +theorem compute_vector_u_loop0_fc {K : Std.Usize} {Hasher : Type} + (hash_functionsHashInst : libcrux_iot_ml_kem.hash_functions.Hash Hasher) + (matrix_entry : Row0FillFC.Poly) (seed : Slice Std.U8) + (r_as_ntt cache : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (r_arr : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (accumulator : Row0FillFC.Acc) + (h_seed_len : seed.length = 32) + (h_r_len : r_as_ntt.length = K.val) + (h_cache_len : cache.length = K.val) + (h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]!) + (h_r_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((r_as_ntt.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, (accumulator.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_vector_u_loop0 + (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst + { start := 0#usize, «end» := K } matrix_entry seed r_as_ntt cache accumulator + ⦃ ⇓ p => ⌜ (Row0FillFC.row0_inv (lift_matrix_from_seed seed K).val[0]! r_arr cache accumulator + K p.2.2 p.2.1).holds ⌝ ⦄ := by + set lm0 : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K := + (lift_matrix_from_seed seed K).val[0]! with hlm0_def + -- Combined invariant: `row0_inv` ∧ cache-length-preservation (the latter is + -- needed to discharge the per-iteration `Slice.index_mut_usize` bound, which + -- the `row0_inv` does not carry; `Slice.set` preserves length). + set inv2 : Std.Usize → + ((libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) × + (Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) × + Row0FillFC.Acc) → Result Prop := + fun k p => pure ((Row0FillFC.row0_inv lm0 r_arr cache accumulator k p.2.2 p.2.1).holds + ∧ p.2.1.length = K.val) with hinv2_def + unfold libcrux_iot_ml_kem.matrix.compute_vector_u_loop0 + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, p) => + libcrux_iot_ml_kem.matrix.compute_vector_u_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst seed r_as_ntt + iter1 p.1 p.2.1 p.2.2) + (β := (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) × + (Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) × + Row0FillFC.Acc) + (matrix_entry, cache, accumulator) + 0#usize K + inv2 + (by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + rw [h0]; exact Nat.zero_le _) + (by + -- Base case at k = 0. + rw [hinv2_def] + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, h_cache_len⟩ + show (Row0FillFC.row0_inv lm0 r_arr cache accumulator 0#usize accumulator cache).holds + unfold Row0FillFC.row0_inv + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨⟨Std.Array.repeat K matrix_entry, ?_, ?_⟩, ?_, ?_, ?_⟩ + · intro c hc + have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0'] at hc + exact absurd hc (Nat.not_lt_zero c) + · intro j hj ℓ hℓ + have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0'] + show Spec.mont_reduce_pure _ = (List.range 0).foldl _ _ + simp [List.range_zero, List.foldl_nil] + · intro n _; have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0']; omega + · intro c hc + have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0'] at hc + exact absurd hc (Nat.not_lt_zero c) + · intro c _ _; trivial) + ?_) + · -- Post entailment: extract the `row0_inv` part of the combined invariant at K. + rw [PostCond.entails_noThrow] + intro r hh + rw [hinv2_def] at hh + have h_pair : (Row0FillFC.row0_inv lm0 r_arr cache accumulator K r.2.2 r.2.1).holds + ∧ r.2.1.length = K.val := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using hh + show (Row0FillFC.row0_inv lm0 r_arr cache accumulator K r.2.2 r.2.1).holds + exact h_pair.1 + · -- Step entailment. + intro p k _h_ge h_le hinv + rw [hinv2_def] at hinv + have hinv_pair : (Row0FillFC.row0_inv lm0 r_arr cache accumulator k p.2.2 p.2.1).holds + ∧ p.2.1.length = K.val := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using hinv + obtain ⟨hinv_row0, hinv_clen⟩ := hinv_pair + have h_step := compute_vector_u_loop0_step_lemma_fc + hash_functionsHashInst matrix_entry seed r_as_ntt cache r_arr accumulator + h_seed_len h_r_len h_r_arr h_r_bnd h_acc_bnd p.1 p.2.2 p.2.1 k h_le hinv_clen + (by rw [hlm0_def] at hinv_row0; exact hinv_row0) + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', me', cache', acc'⟩ | y + · have hP : Row0FillFC.row0_step_post lm0 r_arr cache accumulator k + (.cont (iter', me', cache', acc')) := by + rw [hlm0_def] + simpa [Std.Do.SPred.down_pure] using hh + obtain ⟨h_klt, h_end, h_start, h_inv', h_clen'⟩ := by + simpa [Row0FillFC.row0_step_post] using hP + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + refine ⟨h_klt, h_end, h_start, ?_⟩ + rw [hinv2_def] + exact (by + show (pure ((Row0FillFC.row0_inv lm0 r_arr cache accumulator iter'.start acc' cache').holds + ∧ cache'.length = K.val) : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using + (⟨h_inv', h_clen'⟩ : + (Row0FillFC.row0_inv lm0 r_arr cache accumulator iter'.start acc' cache').holds + ∧ cache'.length = K.val)) + · have hP : Row0FillFC.row0_step_post lm0 r_arr cache accumulator k + (.done (y.1, y.2.1, y.2.2)) := by + rw [hlm0_def] + simpa [Std.Do.SPred.down_pure] using hh + obtain ⟨h_done_inv, h_done_clen⟩ := by + simpa [Row0FillFC.row0_step_post] using hP + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + rw [hinv2_def] + show (pure ((Row0FillFC.row0_inv lm0 r_arr cache accumulator K y.2.2 y.2.1).holds + ∧ y.2.1.length = K.val) : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using + (⟨h_done_inv, h_done_clen⟩ : + (Row0FillFC.row0_inv lm0 r_arr cache accumulator K y.2.2 y.2.1).holds + ∧ y.2.1.length = K.val) + +end L7_2a_irreducible + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow in +set_option maxHeartbeats 1600000 in +/-- **L7.2 Stage 1 cache-length companion.** `compute_vector_u_loop0` preserves + the cache length. Reuses the private per-iteration step lemma + `compute_vector_u_loop0_step_lemma_fc` (whose `row0_step_post` carries + `cache'.length = K`), threading a `row0_inv ∧ length` invariant identical to + `compute_vector_u_loop0_fc`'s `inv2`. Needed by the L7.2 main glue to + discharge the `cache.length = K` PRE of the row-i USE-CACHE loop. -/ +theorem compute_vector_u_loop0_cache_len_fc {K : Std.Usize} {Hasher : Type} + (hash_functionsHashInst : libcrux_iot_ml_kem.hash_functions.Hash Hasher) + (matrix_entry : Row0FillFC.Poly) (seed : Slice Std.U8) + (r_as_ntt cache : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (r_arr : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (accumulator : Row0FillFC.Acc) + (h_seed_len : seed.length = 32) + (h_r_len : r_as_ntt.length = K.val) + (h_cache_len : cache.length = K.val) + (h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]!) + (h_r_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((r_as_ntt.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, (accumulator.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_vector_u_loop0 + (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst + { start := 0#usize, «end» := K } matrix_entry seed r_as_ntt cache accumulator + ⦃ ⇓ p => ⌜ p.2.1.length = K.val ⌝ ⦄ := by + set lm0 : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K := + (lift_matrix_from_seed seed K).val[0]! with hlm0_def + set inv2 : Std.Usize → + ((libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) × + (Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) × + Row0FillFC.Acc) → Result Prop := + fun k p => pure ((Row0FillFC.row0_inv lm0 r_arr cache accumulator k p.2.2 p.2.1).holds + ∧ p.2.1.length = K.val) with hinv2_def + unfold libcrux_iot_ml_kem.matrix.compute_vector_u_loop0 + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, p) => + libcrux_iot_ml_kem.matrix.compute_vector_u_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst seed r_as_ntt + iter1 p.1 p.2.1 p.2.2) + (β := (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) × + (Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) × + Row0FillFC.Acc) + (matrix_entry, cache, accumulator) + 0#usize K + inv2 + (by have h0 : (0#usize : Std.Usize).val = 0 := rfl; rw [h0]; exact Nat.zero_le _) + (by + rw [hinv2_def] + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, h_cache_len⟩ + show (Row0FillFC.row0_inv lm0 r_arr cache accumulator 0#usize accumulator cache).holds + unfold Row0FillFC.row0_inv + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨⟨Std.Array.repeat K matrix_entry, ?_, ?_⟩, ?_, ?_, ?_⟩ + · intro c hc + have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0'] at hc; exact absurd hc (Nat.not_lt_zero c) + · intro j hj ℓ hℓ + have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0'] + show Spec.mont_reduce_pure _ = (List.range 0).foldl _ _ + simp [List.range_zero, List.foldl_nil] + · intro n _; have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0']; omega + · intro c hc + have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0'] at hc; exact absurd hc (Nat.not_lt_zero c) + · intro c _ _; trivial) + ?_) + · rw [PostCond.entails_noThrow] + intro r hh + rw [hinv2_def] at hh + have h_pair : (Row0FillFC.row0_inv lm0 r_arr cache accumulator K r.2.2 r.2.1).holds + ∧ r.2.1.length = K.val := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using hh + exact h_pair.2 + · intro p k _h_ge h_le hinv + rw [hinv2_def] at hinv + have hinv_pair : (Row0FillFC.row0_inv lm0 r_arr cache accumulator k p.2.2 p.2.1).holds + ∧ p.2.1.length = K.val := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using hinv + obtain ⟨hinv_row0, hinv_clen⟩ := hinv_pair + have h_step := compute_vector_u_loop0_step_lemma_fc + hash_functionsHashInst matrix_entry seed r_as_ntt cache r_arr accumulator + h_seed_len h_r_len h_r_arr h_r_bnd h_acc_bnd p.1 p.2.2 p.2.1 k h_le hinv_clen + (by rw [hlm0_def] at hinv_row0; exact hinv_row0) + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', me', cache', acc'⟩ | y + · have hP : Row0FillFC.row0_step_post lm0 r_arr cache accumulator k + (.cont (iter', me', cache', acc')) := by + rw [hlm0_def]; simpa [Std.Do.SPred.down_pure] using hh + obtain ⟨h_klt, h_end, h_start, h_inv', h_clen'⟩ := by + simpa [Row0FillFC.row0_step_post] using hP + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + refine ⟨h_klt, h_end, h_start, ?_⟩ + rw [hinv2_def] + show (pure ((Row0FillFC.row0_inv lm0 r_arr cache accumulator iter'.start acc' cache').holds + ∧ cache'.length = K.val) : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using + (⟨h_inv', h_clen'⟩ : + (Row0FillFC.row0_inv lm0 r_arr cache accumulator iter'.start acc' cache').holds + ∧ cache'.length = K.val) + · have hP : Row0FillFC.row0_step_post lm0 r_arr cache accumulator k + (.done (y.1, y.2.1, y.2.2)) := by + rw [hlm0_def]; simpa [Std.Do.SPred.down_pure] using hh + obtain ⟨h_done_inv, h_done_clen⟩ := by + simpa [Row0FillFC.row0_step_post] using hP + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + rw [hinv2_def] + show (pure ((Row0FillFC.row0_inv lm0 r_arr cache accumulator K y.2.2 y.2.1).holds + ∧ y.2.1.length = K.val) : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using + (⟨h_done_inv, h_done_clen⟩ : + (Row0FillFC.row0_inv lm0 r_arr cache accumulator K y.2.2 y.2.1).holds + ∧ y.2.1.length = K.val) + +/-! ## §L7.2 — row-0 acc-bridge (REUSES L7.4 `compute_message_acc_bridge`). -/ + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do + +/-- Local abbrev for a single 256-lane field-element poly (matrix-row entry). + Factored to keep the `256#usize` literal out of statement signatures + (SKILL §7.7 macro-retrigger trap). -/ +private abbrev FEPoly := Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize + +/-- Helper: the `lift_vec` of the existential witness `mp` collapses to the + canonical matrix row `lm0`, given the per-column agreement `h_agree` + (`lift_poly mp[c] = lm0[c]` for `c < K`). Both are `Std.Array … K`; reduce + to the lists `mp.val.map lift_poly = lm0.val` by `List.ext_getElem`. -/ +private theorem lift_vec_mp_eq {K : Std.Usize} + (mp : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (lm0 : Std.Array FEPoly K) + (h_agree : ∀ c : Nat, c < K.val → lift_poly (mp.val[c]!) = lm0.val[c]!) : + lift_vec mp = lm0 := by + apply Subtype.ext + show mp.val.map lift_poly = lm0.val + have h_mp_len : mp.val.length = K.val := Std.Array.length_eq mp + have h_lm0_len : lm0.val.length = K.val := Std.Array.length_eq lm0 + apply List.ext_getElem + · rw [List.length_map, h_mp_len, h_lm0_len] + · intro i hi1 _hi2 + have hi : i < K.val := by + have : i < (mp.val.map lift_poly).length := hi1 + rw [List.length_map, h_mp_len] at this; exact this + rw [List.getElem_map] + have h_lhs : lift_poly (mp.val[i]) = lift_poly (mp.val[i]!) := by + rw [getElem!_pos mp.val i (by rw [h_mp_len]; exact hi)] + have h_rhs : lm0.val[i] = lm0.val[i]! := by + rw [getElem!_pos lm0.val i (by rw [h_lm0_len]; exact hi)] + rw [h_lhs, h_rhs]; exact h_agree i hi + +/-- Helper: `lift_vec` of the carried `r_arr` equals the `lift_vec_slice` of + the impl `Slice` `r_as_ntt`, given the per-column tie `h_r_arr` + (`r_arr[c] = r_as_ntt[c]` for `c < K`). -/ +private theorem lift_vec_r_arr_eq {K : Std.Usize} + (r_as_ntt : Slice + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (r_arr : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]!) : + lift_vec r_arr = lift_vec_slice r_as_ntt K := by + apply Subtype.ext + show r_arr.val.map lift_poly = (List.range K.val).map (fun i => lift_poly r_as_ntt.val[i]!) + have h_r_len : r_arr.val.length = K.val := Std.Array.length_eq r_arr + apply List.ext_getElem + · rw [List.length_map, h_r_len, List.length_map, List.length_range] + · intro i hi1 _hi2 + have hi : i < K.val := by + have : i < (r_arr.val.map lift_poly).length := hi1 + rw [List.length_map, h_r_len] at this; exact this + rw [List.getElem_map, List.getElem_map, List.getElem_range] + have h_lhs : lift_poly (r_arr.val[i]) = lift_poly (r_arr.val[i]!) := by + rw [getElem!_pos r_arr.val i (by rw [h_r_len]; exact hi)] + rw [h_lhs, h_r_arr i hi] + +set_option maxHeartbeats 1000000 in +/-- **L7.2 row-0 acc-bridge.** Reconciles the hacspec `multiply_vectors` of the + axiom-pinned row-0 matrix row against the loop0 accumulator scaled by + `R = 2285`. A thin wrapper that REUSES L7.4 `compute_message_acc_bridge`: + the existential witness `mp` of `row0_inv` supplies the secret-as-ntt array, + `r_arr` the u-as-ntt array, and `row0_inv`'s conjuncts (1)+(2) are exactly + `S1LoopFC.loop_inv mp r_arr`'s two conjuncts. The two vector args are + rewritten via `lift_vec_mp_eq` / `lift_vec_r_arr_eq`. -/ +theorem compute_vector_u_row0_acc_bridge {K : Std.Usize} + (seed : Slice Std.U8) + (r_as_ntt : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (r_arr : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (cache_init cache2 : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (acc_init acc2 : Std.Array Std.I32 256#usize) + (h_acc_init_zero : ∀ n : Nat, n < 256 → (acc_init.val[n]!).val = 0) + (h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]!) + (h_r_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((r_as_ntt.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_row0 : (Row0FillFC.row0_inv (lift_matrix_from_seed seed K).val[0]! r_arr cache_init acc_init + K acc2 cache2).holds) : + hacspec_ml_kem.matrix.multiply_vectors + ((lift_matrix_from_seed seed K).val[0]!) (lift_vec_slice r_as_ntt K) + = .ok (scaleZ 2285 (Impl.mont_strip_pure + (Spec.poly_reducing_from_i32_array_pure (Aeneas.Std.Array.to_slice acc2)))) := by + set lm0 : Std.Array FEPoly K := (lift_matrix_from_seed seed K).val[0]! with hlm0_def + -- Destructure `row0_inv`'s 4 conjuncts; the first is the ∃-witness pack. + obtain ⟨⟨mp, h_mp_agree, h_inv_acc⟩, h_inv_bnd, _h_cache_done, _h_cache_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_row0 + -- `h_inv_acc` (mont foldl) and `h_inv_bnd` (bound) are exactly + -- `S1LoopFC.loop_inv mp r_arr acc_init K acc2`'s two conjuncts. + have h_char : (S1LoopFC.loop_inv mp r_arr acc_init K acc2).holds := by + show (pure + ((∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc2.val[16 * j + ℓ]!).val) + = (List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val))) + ∧ (∀ n : Nat, n < 256 → + (acc2.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + K.val * 2^25)) + : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using + (⟨h_inv_acc, h_inv_bnd⟩ : _ ∧ _) + -- secret-side bounds from the ∃-witness `mp`'s per-lane bound (conjunct 1.2). + have h_secret_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((mp.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328 := by + intro k i j; exact (h_mp_agree k.val k.isLt).2 i j + -- u-side bounds from `h_r_bnd` rewritten through `h_r_arr`. + have h_u_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((r_arr.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328 := by + intro k i j; rw [h_r_arr k.val k.isLt]; exact h_r_bnd k.val k.isLt i j + -- Apply the L7.4 bridge on `(mp, r_arr)`. + have h_bridge := + compute_message_acc_bridge mp r_arr acc_init acc2 h_acc_init_zero h_secret_bnd h_u_bnd h_char + -- Rewrite the two vector args: `lift_vec mp = lm0`, `lift_vec r_arr = lift_vec_slice r_as_ntt K`. + have h_mp_vec : lift_vec mp = lm0 := + lift_vec_mp_eq mp lm0 (fun c hc => (h_mp_agree c hc).1) + have h_r_vec : lift_vec r_arr = lift_vec_slice r_as_ntt K := + lift_vec_r_arr_eq r_as_ntt r_arr h_r_arr + rw [h_mp_vec, h_r_vec] at h_bridge + rw [hlm0_def] + exact h_bridge + +/-! ## §L7.2-loop1-loop0 — row-i (i ≥ 1) SAMPLED column-loop scaffolding + (namespace `RowIFillFC`). + + The USE-CACHE variant of the row-0 column loop. Combines: + * the EXISTENTIAL/sample machinery of `Row0FillFC` (the matrix entry is + SAMPLED via `sample_matrix_entry seed i j`, axiomatized; the discarded + sampled polys are threaded through an existential witness `mp`), and + * the USE-CACHE structure of `Stage2UseCacheFC` (the cache is INPUT only — read + via `Slice.index_usize`, never mutated; the per-column forward dep is + `accumulating_ntt_multiply_use_cache_poly_fc`, which requires the + cache-post PRE `accumulating_ntt_multiply_poly_cache_post (r[c]) (cache[c])`). + + Mirrors `Row0FillFC` minus the two cache-state conjuncts (3)/(4), with the + matrix row pinned to `(lift_matrix_from_seed seed K).val[i.val]!` (ROW i, + not row 0). The loop carries the 2-tuple `(matrix_entry, acc)`. -/ + +namespace RowIFillFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +abbrev Acc := UseCacheFC.Acc +abbrev Poly := UseCacheFC.Poly + +/-- 2-conjunct invariant for the row-i (i ≥ 1) SAMPLED column loop of + `compute_vector_u`, in the RESOLVED all-mont/existential form. + + `lm_i` is the row-i matrix row `(lift_matrix_from_seed seed K).val[i.val]!`. + As in `Row0FillFC.row0_inv`, because the impl SAMPLES and DISCARDS the matrix + entry each column, we existentially quantify over the ACTUAL sampled polys + `mp : Array Poly K`, tie them to the canonical matrix row via the axiom + (`lift_poly (mp[c]) = lm_i[c]`), and characterize the accumulator in the + all-mont form. NO cache conjuncts (cache is read-only here). -/ +def row_i_inv {K : Std.Usize} + (lm_i : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K) + (r_arr : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : Acc) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + (∃ mp : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K, + (∀ c : Nat, c < k.val → + lift_poly (mp.val[c]!) = lm_i.val[c]! + ∧ (∀ a : Fin 16, ∀ b : Fin 16, + ((mp.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328)) + ∧ (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = (List.range k.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)))) + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + k.val * 2^25)) + +/-- Step-post for `loop_range_spec_usize` over `(matrix_entry, acc)` + (2-tuple, no cache). -/ +def row_i_step_post {K : Std.Usize} + (lm_i : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K) + (r_arr : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : Acc) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) × + Acc) + ((libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) × + Acc)) : + Prop := + match r with + | .cont (iter', _matrix_entry', acc') => + k.val < K.val ∧ iter'.«end» = K + ∧ iter'.start.val = k.val + 1 + ∧ (row_i_inv lm_i r_arr acc_init iter'.start acc').holds + | .done y => (row_i_inv lm_i r_arr acc_init K y.2).holds + +end RowIFillFC + +-- Memory hygiene (rule 1 / SKILL §5.7 Idiom 2). Mirrors `L7_2a_irreducible` +-- and `L7_1b_irreducible`. we +-- do NOT mark `RowIFillFC.row_i_inv` / `row_i_step_post` irreducible. +section L7_2b_irreducible +attribute [local irreducible] accumulating_ntt_multiply_poly_post +attribute [local irreducible] accumulating_ntt_multiply_poly_cache_post +attribute [local irreducible] Spec.ntt_multiply_pure_no_acc +attribute [local irreducible] Spec.mont_reduce_pure + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for the row-i (i ≥ 1) SAMPLED column loop of + `compute_vector_u`. Combines `compute_vector_u_loop0_step_lemma_fc`'s + existential/sample/foldl machinery with `compute_As_plus_e_loop1_loop0_step_lemma_fc`'s + use-cache structure: + 1. The matrix entry is SAMPLED at `(i, k)` via `sample_matrix_entry_fc` + (NOT `(0, k)`), pinned to `lm_i[k] = (lift_matrix_from_seed seed K).val[i.val].val[k]`. + 2. NO cache mutation — `cache` is read via `Slice.index_usize` only. + 3. The per-column forward dep is `accumulating_ntt_multiply_use_cache_poly_fc`, + which needs the cache-post PRE at column k: `accumulating_ntt_multiply_poly_cache_post + (r_as_ntt[k]!) (cache[k]!)` (passed through `h_cache`). -/ +private theorem compute_vector_u_loop1_loop0_step_lemma_fc + {K : Std.Usize} {Hasher : Type} + (hash_functionsHashInst : libcrux_iot_ml_kem.hash_functions.Hash Hasher) + (seed : Slice Std.U8) + (r_as_ntt cache : Slice + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (r_arr : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init : RowIFillFC.Acc) + (i : Std.Usize) (h_i : i.val < K.val) + (h_seed_len : seed.length = 32) + (h_r_len : r_as_ntt.length = K.val) + (h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]!) + (h_r_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((r_as_ntt.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, + (acc_init.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30) + (h_cache : ∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post (r_as_ntt.val[c]!) (cache.val[c]!)) + (matrix_entry : RowIFillFC.Poly) + (acc : RowIFillFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ K.val) + (h_cache_len : cache.length = K.val) + (h_inv : (RowIFillFC.row_i_inv (lift_matrix_from_seed seed K).val[i.val]! r_arr acc_init + k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_vector_u_loop1_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst seed r_as_ntt + cache i { start := k, «end» := K } matrix_entry acc + ⦃ ⇓ r => ⌜ RowIFillFC.row_i_step_post (lift_matrix_from_seed seed K).val[i.val]! r_arr + acc_init k r ⌝ ⦄ := by + set lm_i : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K := + (lift_matrix_from_seed seed K).val[i.val]! with hlm_i_def + have h_acc_len : acc.length = 256 := Std.Array.length_eq acc + have h_acc_init_len : acc_init.length = 256 := Std.Array.length_eq acc_init + -- Destructure the 2-conjunct invariant (the first is the ∃-witness pack). + obtain ⟨⟨mp, h_mp_agree, h_inv_acc⟩, h_inv_acc_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.matrix.compute_vector_u_loop1_loop0.body + by_cases h_lt : k.val < K.val + · -- `Some k` branch. + have hK_pos : 0 < K.val := Nat.lt_of_le_of_lt (Nat.zero_le _) h_lt + -- (1) IteratorRange.next reduces to .ok (some k, { start := s_iter, end := K }). + have h_iter_step : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ ∃ s : Std.Usize, s.val = k.val + 1 ∧ + r = (some k, + ({ start := s, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_step + obtain ⟨s_iter, hs_iter_val, hv_iter_pair⟩ := hv_iter_post + -- (2) Sample the matrix entry at (i, k) via the axiom. + obtain ⟨me1, h_me_eq, h_me_lift, h_me_bnd⟩ := + triple_exists_ok_fc + (sample_matrix_entry_fc hash_functionsHashInst matrix_entry seed i k K + h_seed_len h_i h_lt) + -- h_me_lift : lift_poly me1 = (lift_matrix_from_seed seed K).val[i].val[k] + have h_me_lift' : lift_poly me1 = lm_i.val[k.val]! := by + rw [hlm_i_def]; exact h_me_lift + -- (3) Slice.index_usize r_as_ntt k reduces to .ok r_as_ntt[k.val]!. + set t_r : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + r_as_ntt.val[k.val]! with ht_r_def + have h_idx_r : Aeneas.Std.Slice.index_usize r_as_ntt k = .ok t_r := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq r_as_ntt k + (by show k.val < r_as_ntt.length; rw [h_r_len]; exact h_lt) + -- (4) Slice.index_usize cache k reduces to .ok cache[k.val]! (READ, not mut). + set t_cache : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + cache.val[k.val]! with ht_cache_def + have h_idx_cache : Aeneas.Std.Slice.index_usize cache k = .ok t_cache := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq cache k + (by show k.val < cache.length; rw [h_cache_len]; exact h_lt) + -- (5) Apply L6.3c per-column use-cache forward dep at column k. + have h_me_bnd' : ∀ a : Fin 16, ∀ b : Fin 16, + ((me1.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328 := + fun a b => h_me_bnd a.val a.isLt b.val b.isLt + have h_t_r_bnd : ∀ a : Fin 16, ∀ b : Fin 16, + ((t_r.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328 := + fun a b => h_r_bnd k.val h_lt a b + -- Cache-post hypothesis at column k. + have h_cache_at_k : accumulating_ntt_multiply_poly_cache_post t_r t_cache := + h_cache k.val h_lt + -- Current acc bound ≤ 2^30: combine inv conjunct (2) with budget PRE. + have h_acc_cur_bnd : ∀ n : Fin 256, (acc.val[n.val]!).val.natAbs ≤ 2^30 := by + intro n + have hb := h_inv_acc_bnd n.val n.isLt + have hp := h_acc_bnd n + have hk_le : k.val * 2^25 ≤ K.val * 2^25 := Nat.mul_le_mul_right _ h_le + omega + obtain ⟨acc1, h_acc1_eq, h_acc1_bnd_rel, h_acc1_post⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_use_cache_poly_fc me1 t_r t_cache acc + h_me_bnd' h_t_r_bnd h_acc_cur_bnd h_cache_at_k) + -- (6) Body equation. + have h_body : + libcrux_iot_ml_kem.matrix.compute_vector_u_loop1_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst seed r_as_ntt + cache i { start := k, «end» := K } matrix_entry acc + = .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), me1, acc1)) := by + unfold libcrux_iot_ml_kem.matrix.compute_vector_u_loop1_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_pair] at hv_iter_eq + rw [hv_iter_eq] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let matrix_entry1 ← + libcrux_iot_ml_kem.matrix.sample_matrix_entry portable_ops_inst + hash_functionsHashInst matrix_entry seed i k + let pre ← Aeneas.Std.Slice.index_usize r_as_ntt k + let pre1 ← Aeneas.Std.Slice.index_usize cache k + let accumulator1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_use_cache + portable_ops_inst matrix_entry1 pre acc pre1 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + matrix_entry1, accumulator1))) + : Result _) = _ + rw [h_me_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_r] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_cache] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_acc1_eq] + rfl + apply triple_of_ok_fc h_body + -- (7) Discharge the step_post. + show RowIFillFC.row_i_step_post lm_i r_arr acc_init k + (.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), me1, acc1)) + refine ⟨h_lt, rfl, hs_iter_val, ?_⟩ + -- (8) Re-establish `row_i_inv` at s_iter (= k+1). + show (RowIFillFC.row_i_inv lm_i r_arr acc_init s_iter acc1).holds + unfold RowIFillFC.row_i_inv + -- New existential witness: mp.set k me1. + set mp1 : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K := + mp.set k me1 with hmp1_def + have h_mp_len : mp.length = K.val := Std.Array.length_eq mp + have h_mp1_at : mp1.val[k.val]! = me1 := by + rw [hmp1_def] + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq mp k k.val me1 + ⟨rfl, by rw [h_mp_len]; exact h_lt⟩ + have h_mp1_ne : ∀ j : Nat, j ≠ k.val → mp1.val[j]! = mp.val[j]! := by + intro j hj + rw [hmp1_def] + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne mp k j me1 (fun h => hj h.symm) + -- column-k r factor: r_arr[k] = r_as_ntt[k] = t_r. + have h_r_arr_k : r_arr.val[k.val]! = t_r := by + rw [ht_r_def]; exact h_r_arr k.val h_lt + have h_inv_pure : + (∃ mp' : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K, + (∀ c : Nat, c < s_iter.val → + lift_poly (mp'.val[c]!) = lm_i.val[c]! + ∧ (∀ a : Fin 16, ∀ b : Fin 16, + ((mp'.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328)) + ∧ (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = (List.range s_iter.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp'.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)))) + ∧ (∀ n : Nat, n < 256 → + (acc1.val[n]!).val.natAbs + ≤ (acc_init.val[n]!).val.natAbs + s_iter.val * 2^25) := by + refine ⟨⟨mp1, ?_, ?_⟩, ?_⟩ + · -- agreement at columns [0, s_iter). + intro c hc + rw [hs_iter_val] at hc + rcases Nat.lt_succ_iff_lt_or_eq.mp hc with hc_lt | hc_eq + · -- c < k: mp1[c] = mp[c], use h_mp_agree. + have hc_ne : c ≠ k.val := by omega + rw [h_mp1_ne c hc_ne] + exact h_mp_agree c hc_lt + · -- c = k: mp1[k] = me1, use h_me_lift' + h_me_bnd. + subst hc_eq + rw [h_mp1_at] + exact ⟨h_me_lift', fun a b => h_me_bnd a.val a.isLt b.val b.isLt⟩ + · -- (a) Accumulator characterization at s_iter = k+1. + intro j hj ℓ hℓ + have h_step_acc : + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (me1.coefficients.val[j]!)) + (lift_chunk_mont (t_r.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!) := by + have := h_acc1_post + unfold accumulating_ntt_multiply_poly_post at this + exact this j hj ℓ hℓ + have h_ih := h_inv_acc j hj ℓ hℓ + rw [h_step_acc, h_ih] + have hs_iter_eq : s_iter.val = k.val + 1 := hs_iter_val + rw [hs_iter_eq] + rw [List.range_succ, List.foldl_append] + -- Generic foldl congruence: over a list whose elements are all < k, the + -- step functions using `mp` and `mp1` agree (since `mp1[c] = mp[c]` for c < k). + have h_foldl_congr : ∀ (L : List Nat) (init : hacspec_ml_kem.parameters.FieldElement), + (∀ c ∈ L, c < k.val) → + L.foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp1.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + init + = L.foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + init := by + intro L + induction L with + | nil => intro init _; rfl + | cons hd tl ih => + intro init hmem + have hhd : hd < k.val := hmem hd (List.mem_cons_self) + have htl : ∀ c ∈ tl, c < k.val := fun c hc => hmem c (List.mem_cons_of_mem hd hc) + have hhd_ne : hd ≠ k.val := by omega + simp only [List.foldl_cons] + rw [h_mp1_ne hd hhd_ne] + exact ih _ htl + rw [h_foldl_congr (List.range k.val) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)) + (fun c hc => by simpa using hc)] + -- Now: column-k term me1/t_r matches mp1[k]/r_arr[k]. + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((List.range k.val).foldl _ _) _ + = (List.foldl _ ((List.range k.val).foldl _ _) [k.val]) + rw [List.foldl_cons, List.foldl_nil] + rw [h_mp1_at, h_r_arr_k] + · -- (b) Bound. + intro n hn + have h_acc1_bnd_n := h_acc1_bnd_rel ⟨n, hn⟩ + have h_acc1_bnd_n' : (acc1.val[n]!).val.natAbs ≤ (acc.val[n]!).val.natAbs + 2^25 := + h_acc1_bnd_n + have h_inv_n := h_inv_acc_bnd n hn + have hs_iter_eq : s_iter.val = k.val + 1 := hs_iter_val + rw [hs_iter_eq] + have h_arith : (k.val + 1) * 2^25 = k.val * 2^25 + 2^25 := by ring + rw [h_arith] + linarith [h_acc1_bnd_n, h_inv_n] + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ K, done. + have hk_ge : k.val ≥ K.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = K.val := by omega + have h_iter_none : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ r = ((none : Option Std.Usize), + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun hlt => absurd hlt (Nat.not_lt.mpr hk_ge)) + (fun _ => by dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_none + have h_body : + libcrux_iot_ml_kem.matrix.compute_vector_u_loop1_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst seed r_as_ntt + cache i { start := k, «end» := K } matrix_entry acc + = .ok (ControlFlow.done (matrix_entry, acc)) := by + unfold libcrux_iot_ml_kem.matrix.compute_vector_u_loop1_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_post] at hv_iter_eq + rw [hv_iter_eq] + rfl + apply triple_of_ok_fc h_body + show RowIFillFC.row_i_step_post lm_i r_arr acc_init k + (.done (matrix_entry, acc)) + show (RowIFillFC.row_i_inv lm_i r_arr acc_init K acc).holds + unfold RowIFillFC.row_i_inv + show (pure _ : Result Prop).holds + have h_inv_pure : + (∃ mp' : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K, + (∀ c : Nat, c < K.val → + lift_poly (mp'.val[c]!) = lm_i.val[c]! + ∧ (∀ a : Fin 16, ∀ b : Fin 16, + ((mp'.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328)) + ∧ (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = (List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp'.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)))) + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs + ≤ (acc_init.val[n]!).val.natAbs + K.val * 2^25) := by + refine ⟨⟨mp, ?_, ?_⟩, ?_⟩ + · intro c hc + exact h_mp_agree c (by rw [hk_eq]; exact hc) + · intro j hj ℓ hℓ + have h_eq := h_inv_acc j hj ℓ hℓ + have h_rng : (List.range k.val) = (List.range K.val) := by rw [hk_eq] + rw [h_rng] at h_eq + exact h_eq + · intro n hn + have h_b := h_inv_acc_bnd n hn + have h_arith : k.val * 2^25 = K.val * 2^25 := by rw [hk_eq] + rw [h_arith] at h_b + exact h_b + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +/-- L7.2 Stage 2 — `matrix.compute_vector_u_loop1_loop0`: the row-i (i ≥ 1) + SAMPLED column loop of `compute_vector_u` (USE-CACHE variant). Iterates over + `j ∈ [0, K)`; each step SAMPLES the matrix entry via `sample_matrix_entry + seed i j` (axiomatized), reads `r_as_ntt[j]` and `cache[j]` (read-only), and + runs `accumulating_ntt_multiply_use_cache` to add column j's contribution to + the I32 accumulator. + + POST: the RESOLVED all-mont/existential `row_i_inv` holds at k = K — i.e. + there exists a `K`-array `mp` of sampled polys with `lift_poly mp[c] = + (lift_matrix_from_seed seed K).val[i].val[c]` (axiom-pinned, ROW-major) such + that for all (j, ℓ) ∈ [0, 16)², `mont_reduce_pure (lift_fe_int acc[16j+ℓ])` + equals the K-column all-mont sum of `ntt_multiply_pure_no_acc` outputs. + + PRE: `seed.length = 32`, `r_as_ntt.length = K`, `cache.length = K`, the + array/slice bridge `h_r_arr`, the 16×16 bound (3328) on `r_as_ntt`, the + accumulator BUDGET, and the cache-post hypothesis `h_cache` (the cache was + populated by Stage 1's row-0 column loop and is consumed read-only). + + Mirrors `compute_As_plus_e_loop1_loop0_fc` + the local + `compute_vector_u_loop0_fc`. -/ +theorem compute_vector_u_loop1_loop0_fc {K : Std.Usize} {Hasher : Type} + (hash_functionsHashInst : libcrux_iot_ml_kem.hash_functions.Hash Hasher) + (matrix_entry : RowIFillFC.Poly) (seed : Slice Std.U8) + (r_as_ntt cache : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (r_arr : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (accumulator : RowIFillFC.Acc) + (i : Std.Usize) (h_i : i.val < K.val) + (h_seed_len : seed.length = 32) + (h_r_len : r_as_ntt.length = K.val) + (h_cache_len : cache.length = K.val) + (h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]!) + (h_r_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((r_as_ntt.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, (accumulator.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30) + (h_cache : ∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post (r_as_ntt.val[c]!) (cache.val[c]!)) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_vector_u_loop1_loop0 + (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst + { start := 0#usize, «end» := K } matrix_entry seed r_as_ntt cache accumulator i + ⦃ ⇓ p => ⌜ (RowIFillFC.row_i_inv (lift_matrix_from_seed seed K).val[i.val]! r_arr accumulator + K p.2).holds ⌝ ⦄ := by + set lm_i : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K := + (lift_matrix_from_seed seed K).val[i.val]! with hlm_i_def + -- Combined invariant: `row_i_inv` ∧ cache-length-preservation (needed to + -- discharge the per-iteration `Slice.index_usize cache` bound; the cache is + -- never mutated so its length is constant, but the `row_i_inv` + -- does not carry it). + set inv2 : Std.Usize → + ((libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) × + RowIFillFC.Acc) → Result Prop := + fun k p => pure ((RowIFillFC.row_i_inv lm_i r_arr accumulator k p.2).holds) with hinv2_def + unfold libcrux_iot_ml_kem.matrix.compute_vector_u_loop1_loop0 + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, p) => + libcrux_iot_ml_kem.matrix.compute_vector_u_loop1_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst seed r_as_ntt + cache i iter1 p.1 p.2) + (β := (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) × + RowIFillFC.Acc) + (matrix_entry, accumulator) + 0#usize K + inv2 + (by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + rw [h0]; exact Nat.zero_le _) + (by + -- Base case at k = 0. + rw [hinv2_def] + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + show (RowIFillFC.row_i_inv lm_i r_arr accumulator 0#usize accumulator).holds + unfold RowIFillFC.row_i_inv + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨⟨Std.Array.repeat K matrix_entry, ?_, ?_⟩, ?_⟩ + · intro c hc + have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0'] at hc + exact absurd hc (Nat.not_lt_zero c) + · intro j hj ℓ hℓ + have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0'] + show Spec.mont_reduce_pure _ = (List.range 0).foldl _ _ + simp [List.range_zero, List.foldl_nil] + · intro n _; have h0' : (0#usize : Std.Usize).val = 0 := rfl + rw [h0']; omega) + ?_) + · -- Post entailment: extract the `row_i_inv` at K. + rw [PostCond.entails_noThrow] + intro r hh + rw [hinv2_def] at hh + have h_inv_holds : (RowIFillFC.row_i_inv lm_i r_arr accumulator K r.2).holds := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using hh + show (RowIFillFC.row_i_inv lm_i r_arr accumulator K r.2).holds + exact h_inv_holds + · -- Step entailment. + intro p k _h_ge h_le hinv + rw [hinv2_def] at hinv + have hinv_row : (RowIFillFC.row_i_inv lm_i r_arr accumulator k p.2).holds := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using hinv + have h_step := compute_vector_u_loop1_loop0_step_lemma_fc + hash_functionsHashInst seed r_as_ntt cache r_arr accumulator i h_i + h_seed_len h_r_len h_r_arr h_r_bnd h_acc_bnd h_cache p.1 p.2 k h_le h_cache_len + (by rw [hlm_i_def] at hinv_row; exact hinv_row) + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', me', acc'⟩ | y + · have hP : RowIFillFC.row_i_step_post lm_i r_arr accumulator k + (.cont (iter', me', acc')) := by + rw [hlm_i_def] + simpa [Std.Do.SPred.down_pure] using hh + obtain ⟨h_klt, h_end, h_start, h_inv'⟩ := by + simpa [RowIFillFC.row_i_step_post] using hP + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + refine ⟨h_klt, h_end, h_start, ?_⟩ + rw [hinv2_def] + show (pure ((RowIFillFC.row_i_inv lm_i r_arr accumulator iter'.start acc').holds) + : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv' + · have hP : RowIFillFC.row_i_step_post lm_i r_arr accumulator k + (.done (y.1, y.2)) := by + rw [hlm_i_def] + simpa [Std.Do.SPred.down_pure] using hh + have h_done_inv : (RowIFillFC.row_i_inv lm_i r_arr accumulator K y.2).holds := by + simpa [RowIFillFC.row_i_step_post] using hP + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + rw [hinv2_def] + show (pure ((RowIFillFC.row_i_inv lm_i r_arr accumulator K y.2).holds) + : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_done_inv + +end L7_2b_irreducible + +/-! ## §L7.2 — finalize glue (F1/F2) + row-i acc-bridge. + + The two scalar-glue lemmas (F1/F2) mirror the L7.4 FC chain + (`FC/ComputeMessage.lean` lines 230-249) restricted to the C+B+compose + head (F1) and the D'' tail (F2). The row-i acc-bridge mirrors + `compute_vector_u_row0_acc_bridge` for the cache-free `RowIFillFC.row_i_inv` + (2 conjuncts) and matrix row `i`. -/ + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do + +/-- `scaleZ c p` lanes are `feOfZMod _`, hence canonical (local copy of the + `private canonArr_scaleZ'` in ComputeMessage/Hacspec; mirrors the + `FC/ComputeMessage.lean` private helper). -/ +private theorem scaleZ_canon (c : ZMod 3329) + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (j : Nat) (hj : j < 256) : + libcrux_iot_ml_kem.Spec.Pure.Canonical ((scaleZ c p).val[j]!) := by + unfold scaleZ + show libcrux_iot_ml_kem.Spec.Pure.Canonical + (((List.range 256).map (fun k => feOfZMod (c * zmodOfFE (p.val[k]!))))[j]!) + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj])] + rw [List.getElem_map, List.getElem_range] + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical feOfZMod + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] + show (BitVec.ofNat 16 ((c * zmodOfFE (p.val[j]!)).val)).toNat < 3329 + set z := c * zmodOfFE (p.val[j]!) + have h_lt16 : z.val < 2 ^ 16 := by have := ZMod.val_lt z; omega + rw [BitVec.toNat_ofNat, Nat.mod_eq_of_lt h_lt16] + exact ZMod.val_lt _ + +/-- `lift_poly x` lanes are `lift_fe _ = feOfZMod _`, hence canonical (local + copy mirroring the `FC/ComputeMessage.lean` private helper). -/ +private theorem lift_poly_canon + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (j : Nat) (hj : j < 256) : + libcrux_iot_ml_kem.Spec.Pure.Canonical ((lift_poly re).val[j]!) := by + unfold lift_poly + show libcrux_iot_ml_kem.Spec.Pure.Canonical + (((List.range 256).map (fun i => + lift_fe (re.coefficients.val[i / 16]!).elements.val[i % 16]!))[j]!) + rw [getElem!_pos _ j (by simp [List.length_map, List.length_range, hj])] + rw [List.getElem_map, List.getElem_range] + unfold lift_fe libcrux_iot_ml_kem.Spec.Pure.Canonical feOfZMod + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] + show (⟨BitVec.ofNat 16 ((i16_to_spec_fe_plain + (re.coefficients.val[j / 16]!).elements.val[j % 16]!).val)⟩ : Std.U16).val < 3329 + show (BitVec.ofNat 16 ((i16_to_spec_fe_plain + (re.coefficients.val[j / 16]!).elements.val[j % 16]!).val)).toNat < 3329 + set z := i16_to_spec_fe_plain (re.coefficients.val[j / 16]!).elements.val[j % 16]! + have h_lt16 : z.val < 2 ^ 16 := by + have := ZMod.val_lt z; omega + rw [BitVec.toNat_ofNat, Nat.mod_eq_of_lt h_lt16] + exact ZMod.val_lt _ + +/-- **L7.2 F1 (ntt_inverse glue).** Mirrors the L7.4 C+B+compose head + (`FC/ComputeMessage.lean` 230-243) minus the subtract tail: + `ntt_inverse (scaleZ 2285 result1) = .ok (scaleZ 512 result2)` given the + invert-pure tie `invert_pure result1 = result2`. -/ +theorem compute_vector_u_ntt_inverse_eq + (result1 result2 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_inv : Spec.invert_ntt_montgomery_pure (lift_poly result1) = lift_poly result2) : + hacspec_ml_kem.invert_ntt.ntt_inverse (scaleZ 2285 (lift_poly result1)) + = .ok (scaleZ 512 (lift_poly result2)) := by + -- C: ntt_inverse (scaleZ 2285 x) = .ok (scaleZ 3303 (invert_pure (scaleZ 2285 x))). + rw [ntt_inverse_eq_scaleZ_invert_pure (scaleZ 2285 (lift_poly result1)) + (fun j hj => scaleZ_canon 2285 (lift_poly result1) j hj)] + -- B: invert_pure (scaleZ 2285 x) = scaleZ 2285 (invert_pure x). + rw [invert_ntt_montgomery_pure_scaleZ 2285 (lift_poly result1) + (fun j hj => lift_poly_canon result1 j hj)] + -- compose: scaleZ 3303 (scaleZ 2285 y) = scaleZ 512 y. + rw [scaleZ_compose 3303 2285 (Spec.invert_ntt_montgomery_pure (lift_poly result1)), + glue_3303_2285] + -- tie: invert_pure result1 = result2. + rw [h_inv] + +/-- **L7.2 F2 (add glue).** Trivial wrapper of `add_polynomials_scaleZ_eq` + (D''): `add_polynomials (scaleZ 512 result2) error_i = .ok result_i` given + the add-error-reduce tie. -/ +theorem compute_vector_u_add_eq + (result2 error_i result_i : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_tail : Spec.add_error_reduce_pure (lift_poly result2) (lift_poly error_i) + = lift_poly result_i) : + hacspec_ml_kem.matrix.add_polynomials (scaleZ 512 (lift_poly result2)) (lift_poly error_i) + = .ok (lift_poly result_i) := by + rw [add_polynomials_scaleZ_eq (lift_poly result2) (lift_poly error_i) + (fun j hj => lift_poly_canon result2 j hj), h_tail] + +set_option maxHeartbeats 1000000 in +/-- **L7.2 row-i acc-bridge.** Mirror of `compute_vector_u_row0_acc_bridge` for + the cache-free `RowIFillFC.row_i_inv` (2 conjuncts) and matrix row `i`. The + `∃ mp` witness of `row_i_inv` supplies the secret-as-ntt array; `r_arr` the + u-as-ntt array; the two `row_i_inv` conjuncts are exactly + `S1LoopFC.loop_inv mp r_arr`'s two conjuncts. The two vector args are + rewritten via `lift_vec_mp_eq` / `lift_vec_r_arr_eq`. -/ +theorem compute_vector_u_rowi_acc_bridge {K : Std.Usize} + (seed : Slice Std.U8) + (r_as_ntt : Slice (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (r_arr : Std.Array (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) + (acc_init acc2 : Std.Array Std.I32 256#usize) + (i : Std.Usize) + (h_acc_init_zero : ∀ n : Nat, n < 256 → (acc_init.val[n]!).val = 0) + (h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]!) + (h_r_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((r_as_ntt.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_rowi : (RowIFillFC.row_i_inv (lift_matrix_from_seed seed K).val[i.val]! r_arr acc_init + K acc2).holds) : + hacspec_ml_kem.matrix.multiply_vectors + ((lift_matrix_from_seed seed K).val[i.val]!) (lift_vec_slice r_as_ntt K) + = .ok (scaleZ 2285 (Impl.mont_strip_pure + (Spec.poly_reducing_from_i32_array_pure (Aeneas.Std.Array.to_slice acc2)))) := by + set lm_i : Std.Array FEPoly K := (lift_matrix_from_seed seed K).val[i.val]! with hlm_i_def + -- Destructure `row_i_inv`'s 2 conjuncts; the first is the ∃-witness pack. + obtain ⟨⟨mp, h_mp_agree, h_inv_acc⟩, h_inv_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_rowi + -- `h_inv_acc` (mont foldl) and `h_inv_bnd` (bound) are exactly + -- `S1LoopFC.loop_inv mp r_arr acc_init K acc2`'s two conjuncts. + have h_char : (S1LoopFC.loop_inv mp r_arr acc_init K acc2).holds := by + show (pure + ((∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc2.val[16 * j + ℓ]!).val) + = (List.range K.val).foldl + (fun s c => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure s + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (mp.val[c]!.coefficients.val[j]!)) + (lift_chunk_mont (r_arr.val[c]!.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val))) + ∧ (∀ n : Nat, n < 256 → + (acc2.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + K.val * 2^25)) + : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using + (⟨h_inv_acc, h_inv_bnd⟩ : _ ∧ _) + -- secret-side bounds from the ∃-witness `mp`'s per-lane bound (conjunct 1.2). + have h_secret_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((mp.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328 := by + intro k i j; exact (h_mp_agree k.val k.isLt).2 i j + -- u-side bounds from `h_r_bnd` rewritten through `h_r_arr`. + have h_u_bnd : ∀ k : Fin K.val, ∀ i j : Fin 16, + ((r_arr.val[k.val]!.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328 := by + intro k i j; rw [h_r_arr k.val k.isLt]; exact h_r_bnd k.val k.isLt i j + -- Apply the L7.4 bridge on `(mp, r_arr)`. + have h_bridge := + compute_message_acc_bridge mp r_arr acc_init acc2 h_acc_init_zero h_secret_bnd h_u_bnd h_char + -- Rewrite the two vector args: `lift_vec mp = lm_i`, `lift_vec r_arr = lift_vec_slice r_as_ntt K`. + have h_mp_vec : lift_vec mp = lm_i := + lift_vec_mp_eq mp lm_i (fun c hc => (h_mp_agree c hc).1) + have h_r_vec : lift_vec r_arr = lift_vec_slice r_as_ntt K := + lift_vec_r_arr_eq r_as_ntt r_arr h_r_arr + rw [h_mp_vec, h_r_vec] at h_bridge + rw [hlm_i_def] + exact h_bridge + +/-! ## §L7.2 Stage 3 — outer rows loop FC (`compute_vector_u_loop1`). + + The OUTER rows loop `[start, K)` of `compute_vector_u`. Each row `i1` does: + re-zero the accumulator (`Array.repeat 256 i_zero`, `i_zero = classify 0`), + run the USE-CACHE inner column loop (`compute_vector_u_loop1_loop0`), then + the per-row finalize `reducing_from_i32_array → invert_ntt_montgomery → + add_error_reduce`, storing `result[i1] := result_poly`. + + Mirrors `compute_As_plus_e_loop1_fc` structurally — the + rows loop `loop_range_spec_usize` wrapper + the re-zero-per-row pattern + + the (done-rows ∧ unchanged-rows) `rows_inv`. L7.1's per-row finalize is + reducing→add (no invert); ours is reducing→INVERT→add (the L7.4 glue WALK, + `FC/ComputeMessage.lean` 168-251, adapted to add-error instead of subtract). + + The per-row hacspec value is captured by `AllRowsFillFC.row_spec` (a `Result` + do-block: multiply_vectors → ntt_inverse → add_polynomials), and the + invariant says `row_spec lm r_as_ntt error_1 r = .ok (lift_poly result[r])` + for completed rows `r ∈ [start, k)`, with all other rows unchanged. -/ + +namespace AllRowsFillFC + +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Matrix.Common libcrux_iot_ml_kem.Matrix.ComputeAsPlusE libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Sampling libcrux_iot_ml_kem.Serialize libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt + +/-- The 256-lane I32 accumulator carried by the rows loop. -/ +abbrev Acc := Std.Array Std.I32 256#usize + +/-- The portable-vector poly type (matrix row entry / scratch). -/ +abbrev Poly := + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- A single 256-lane field-element poly. -/ +abbrev FEPoly := Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize + +/-- The raw scratch vector carried by the rows loop (NOT a poly wrapper). -/ +abbrev Scratch := libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- The per-row hacspec value (as a `Result`): the matrix-row `r` times the + secret vector, ntt-inverted, plus the per-row error. -/ +noncomputable def row_spec {K : Std.Usize} + (lm : Std.Array (Std.Array FEPoly K) K) + (r_as_ntt error_1 : Slice Poly) (r : Nat) : Result FEPoly := + (do + let prod ← hacspec_ml_kem.matrix.multiply_vectors (lm.val[r]!) (lift_vec_slice r_as_ntt K) + let inv ← hacspec_ml_kem.invert_ntt.ntt_inverse prod + hacspec_ml_kem.matrix.add_polynomials inv (lift_poly (error_1.val[r]!))) + +/-- 2-conjunct invariant for the outer rows loop. Tracks: + (1) Per-completed-row characterization: for each row `r ∈ [start, k)`, + `row_spec lm r_as_ntt error_1 r = .ok (lift_poly result[r]!)`. + (2) Unchanged rows: for each `r ∈ [0, K)` with `r < start ∨ k ≤ r`, + `result[r]! = result_init[r]!`. + (3) Length preservation: `result.length = K.val` (needed to discharge the + per-iteration `Slice.index_mut result k` bound; the conjuncts (1)/(2) do not carry it). -/ +def rows_inv {K : Std.Usize} + (lm : Std.Array (Std.Array FEPoly K) K) + (r_as_ntt error_1 : Slice Poly) + (result_init : Slice Poly) (start : Std.Usize) : + Std.Usize → Slice Poly → Scratch → Acc → Result Prop := + fun k result _scratch _acc => pure ( + (∀ r : Nat, start.val ≤ r → r < k.val → + row_spec lm r_as_ntt error_1 r = .ok (lift_poly (result.val[r]!))) + ∧ (∀ r : Nat, r < K.val → (r < start.val ∨ k.val ≤ r) → + result.val[r]! = result_init.val[r]!) + ∧ result.length = K.val) + +/-- Step-post for `loop_range_spec_usize` over the loop's 4-carry + `(matrix_entry, result, scratch, accumulator)`. -/ +def rows_step_post {K : Std.Usize} + (lm : Std.Array (Std.Array FEPoly K) K) + (r_as_ntt error_1 : Slice Poly) + (result_init : Slice Poly) (start : Std.Usize) (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Poly × Slice Poly × Scratch × Acc) + (Poly × Slice Poly × Scratch × Acc)) : + Prop := + match r with + | .cont (iter', _me', result', scratch', acc') => + k.val < K.val ∧ iter'.«end» = K + ∧ iter'.start.val = k.val + 1 + ∧ (rows_inv lm r_as_ntt error_1 result_init start + iter'.start result' scratch' acc').holds + | .done y => (rows_inv lm r_as_ntt error_1 result_init start + K y.2.1 y.2.2.1 y.2.2.2).holds + +end AllRowsFillFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do + +set_option maxHeartbeats 16000000 in +/-- **L7.2 Stage 3 — per-row step lemma.** For the `Some i1` branch (row `k`) of + `compute_vector_u_loop1.body`: re-zero the accumulator, run the use-cache + inner column loop (`compute_vector_u_loop1_loop0_fc`), then the per-row + finalize WALK (`reducing_from_i32_array` → `invert_ntt_montgomery` → + `add_error_reduce`), and store `result[i1] := result_poly`. Re-establishes + `rows_inv` at `k+1`. Mirrors `compute_As_plus_e_loop1_step_lemma_fc` + the L7.4 glue finalize (`FC/ComputeMessage.lean` 168-251, + add-error instead of subtract). -/ +private theorem compute_vector_u_loop1_step_lemma_fc {K : Std.Usize} {Hasher : Type} + (hash_functionsHashInst : libcrux_iot_ml_kem.hash_functions.Hash Hasher) + (i_zero : Std.I32) + (seed : Slice Std.U8) + (r_as_ntt error_1 cache : Slice AllRowsFillFC.Poly) + (r_arr : Std.Array AllRowsFillFC.Poly K) + (result_init : Slice AllRowsFillFC.Poly) + (start : Std.Usize) + (hK : K.val ≤ 4) + (h_seed_len : seed.length = 32) (h_r_len : r_as_ntt.length = K.val) + (h_cache_len : cache.length = K.val) (h_err_len : error_1.length = K.val) + (h_i_zero : i_zero.val = 0) + (h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]!) + (h_r_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((r_as_ntt.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_err_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((error_1.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 29439) + (h_cache : ∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post (r_as_ntt.val[c]!) (cache.val[c]!)) + (matrix_entry : AllRowsFillFC.Poly) (result : Slice AllRowsFillFC.Poly) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (accumulator : AllRowsFillFC.Acc) + (k : Std.Usize) (h_ge : start.val ≤ k.val) (h_le : k.val ≤ K.val) + (h_inv : (AllRowsFillFC.rows_inv (lift_matrix_from_seed seed K) r_as_ntt error_1 + result_init start k result scratch accumulator).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_vector_u_loop1.body + K (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst + i_zero seed r_as_ntt error_1 cache { start := k, «end» := K } matrix_entry result scratch + accumulator + ⦃ ⇓ r => ⌜ AllRowsFillFC.rows_step_post (lift_matrix_from_seed seed K) r_as_ntt error_1 + result_init start k r ⌝ ⦄ := by + set lm : Std.Array + (Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K) K := + lift_matrix_from_seed seed K with hlm_def + -- Destructure the 3-conjunct invariant. + obtain ⟨h_inv_done, h_inv_undone, h_result_len⟩ := by + simpa [AllRowsFillFC.rows_inv, Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + ← List.getElem!_eq_getElem?_getD] using h_inv + have h_result_len : result.length = K.val := h_result_len + unfold libcrux_iot_ml_kem.matrix.compute_vector_u_loop1.body + by_cases h_lt : k.val < K.val + · -- `Some k` branch (row i1 = k). + -- (1) IteratorRange.next reduces to (some k, {start := s_iter, end := K}). + have h_iter_step : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ ∃ s : Std.Usize, s.val = k.val + 1 ∧ + r = (some k, + ({ start := s, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_step + obtain ⟨s_iter, hs_iter_val, hv_iter_pair⟩ := hv_iter_post + -- (2) Re-zeroed accumulator: acc1 := Array.repeat 256 i_zero, all-zero. + set acc1 : AllRowsFillFC.Acc := + Aeneas.Std.Array.repeat 256#usize i_zero with h_acc1_def + have h_acc1_get : ∀ n : Nat, n < 256 → acc1.val[n]! = i_zero := by + intro n hn + rw [h_acc1_def] + show (Aeneas.Std.Array.repeat 256#usize i_zero).val[n]! = i_zero + rw [Aeneas.Std.Array.repeat_val] + rw [getElem!_pos _ n (by rw [List.length_replicate]; exact hn)] + exact List.getElem_replicate _ + have h_acc1_zero : ∀ n : Nat, n < 256 → (acc1.val[n]!).val = 0 := by + intro n hn; rw [h_acc1_get n hn]; exact h_i_zero + have h_acc1_natAbs : ∀ n : Nat, n < 256 → (acc1.val[n]!).val.natAbs = 0 := by + intro n hn; rw [h_acc1_zero n hn]; rfl + have h_acc1_bnd : ∀ n : Fin 256, + (acc1.val[n.val]!).val.natAbs + K.val * 2^25 ≤ 2^30 := by + intro n + rw [h_acc1_natAbs n.val n.isLt] + have hK4 : K.val * 2^25 ≤ 4 * 2^25 := Nat.mul_le_mul_right _ hK + have : (4 : Nat) * 2^25 ≤ 2^30 := by decide + omega + -- (3) Run the use-cache inner column loop at row k with the zeroed acc. + have h_stage2 := + compute_vector_u_loop1_loop0_fc hash_functionsHashInst matrix_entry seed r_as_ntt cache + r_arr acc1 k h_lt h_seed_len h_r_len h_cache_len h_r_arr h_r_bnd h_acc1_bnd h_cache + obtain ⟨me_acc, h_me_acc_eq, h_rowi⟩ := triple_exists_ok_fc h_stage2 + set me1 : AllRowsFillFC.Poly := me_acc.1 with h_me1_def + set acc2 : AllRowsFillFC.Acc := me_acc.2 with h_acc2_def + -- (4) Bound on acc2 from row_i_inv conjunct 2 (acc1 zero, K ≤ 4). + have h_rowi' : (RowIFillFC.row_i_inv (lift_matrix_from_seed seed K).val[k.val]! r_arr acc1 + K acc2).holds := h_rowi + obtain ⟨_h_exists, h_acc2_bnd_raw⟩ := by + simpa [RowIFillFC.row_i_inv, Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + using h_rowi' + have h_acc2_bnd : ∀ n : Nat, n < 256 → (acc2.val[n]!).val.natAbs ≤ 2^16 * 3328 := by + intro n hn + have hb := h_acc2_bnd_raw n hn + simp only [← List.getElem!_eq_getElem?_getD] at hb + rw [h_acc1_natAbs n hn] at hb + have hK4 : K.val * 2^25 ≤ 4 * 2^25 := Nat.mul_le_mul_right _ hK + have h2 : (4 : Nat) * 2^25 ≤ 2^16 * 3328 := by decide + omega + -- (5) acc-bridge: multiply_vectors row k = .ok (scaleZ 2285 (mont_strip (poly_reducing ...))). + set acc_slice : Slice Std.I32 := Aeneas.Std.Array.to_slice acc2 with h_acc_slice_def + have h_acc_slice_len : acc_slice.length = 256 := by + rw [h_acc_slice_def, Aeneas.Std.Array.length_to_slice]; rfl + have h_acc_slice_val : acc_slice.val = acc2.val := + Aeneas.Std.Array.val_to_slice acc2 + have h_acc_slice_bnd : ∀ n : Nat, n < 256 → + (acc_slice.val[n]!).val.natAbs ≤ 2^16 * 3328 := by + intro n hn; rw [h_acc_slice_val]; exact h_acc2_bnd n hn + have h_bridge := + compute_vector_u_rowi_acc_bridge seed r_as_ntt r_arr acc1 acc2 k + h_acc1_zero h_r_arr h_r_bnd h_rowi' + -- (6) Index_mut result k → (result[k]!, result.set k). + set pre : AllRowsFillFC.Poly := result.val[k.val]! with h_pre_def + have h_idx_result : Aeneas.Std.Slice.index_usize result k = .ok pre := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq result k + (by show k.val < result.length; rw [h_result_len]; exact h_lt) + have h_imt_result : Aeneas.Std.Slice.index_mut_usize result k + = .ok (pre, result.set k) := by + unfold Aeneas.Std.Slice.index_mut_usize + rw [h_idx_result]; rfl + -- (7) reducing step: result1. + obtain ⟨result1, h_result1_eq, h_result1_mont, h_result1_lane_bnd⟩ := + triple_exists_ok_fc + (poly_reducing_from_i32_array_fc acc_slice pre h_acc_slice_len h_acc_slice_bnd) + have h_result1_lift : lift_poly result1 + = Impl.mont_strip_pure (Spec.poly_reducing_from_i32_array_pure acc_slice) := by + rw [← h_result1_mont, Impl.mont_strip_lift_poly_mont_eq_lift_poly] + -- result1 := result.set k result1 (the new slice after the reducing store). + set rslice1 : Slice AllRowsFillFC.Poly := result.set k result1 with h_rslice1_def + have h_rslice1_at : rslice1.val[k.val]! = result1 := by + rw [h_rslice1_def] + simpa [Aeneas.Std.Slice.getElem!_Nat_eq] using + Aeneas.Std.Slice.getElem!_Nat_set_eq result k k.val result1 + ⟨rfl, by show k.val < result.length; rw [h_result_len]; exact h_lt⟩ + have h_rslice1_len : rslice1.length = K.val := by + rw [h_rslice1_def] + show (result.set k result1).length = K.val + rw [Aeneas.Std.Slice.set_length]; exact h_result_len + -- (8) index_mut rslice1 k → (result1, rslice1.set k). + set pre2 : AllRowsFillFC.Poly := rslice1.val[k.val]! with h_pre2_def + have h_idx_rslice1 : Aeneas.Std.Slice.index_usize rslice1 k = .ok pre2 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq rslice1 k + (by show k.val < rslice1.length; rw [h_rslice1_len]; exact h_lt) + have h_imt_rslice1 : Aeneas.Std.Slice.index_mut_usize rslice1 k + = .ok (pre2, rslice1.set k) := by + unfold Aeneas.Std.Slice.index_mut_usize + rw [h_idx_rslice1]; rfl + have h_pre2_eq : pre2 = result1 := h_pre2_def.trans h_rslice1_at + -- (9) invert step: result2. + have h_result1_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((result1.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 13312 := by + intro chunk hchunk ℓ hℓ + have := h_result1_lane_bnd chunk hchunk ℓ hℓ; omega + obtain ⟨⟨result2, scratch1⟩, h_inv_eq, h_result2_lift, h_result2_bnd⟩ := + triple_exists_ok_fc + (invert_ntt_montgomery_fc (K := K) result1 scratch h_result1_bnd) + dsimp only at h_inv_eq h_result2_lift h_result2_bnd + -- result2 := rslice1.set k result2. + set rslice2 : Slice AllRowsFillFC.Poly := rslice1.set k result2 with h_rslice2_def + have h_rslice2_at : rslice2.val[k.val]! = result2 := by + rw [h_rslice2_def] + simpa [Aeneas.Std.Slice.getElem!_Nat_eq] using + Aeneas.Std.Slice.getElem!_Nat_set_eq rslice1 k k.val result2 + ⟨rfl, by show k.val < rslice1.length; rw [h_rslice1_len]; exact h_lt⟩ + have h_rslice2_len : rslice2.length = K.val := by + rw [h_rslice2_def] + show (rslice1.set k result2).length = K.val + rw [Aeneas.Std.Slice.set_length]; exact h_rslice1_len + -- (10) index_mut rslice2 k → (result2, rslice2.set k). + set pre4 : AllRowsFillFC.Poly := rslice2.val[k.val]! with h_pre4_def + have h_idx_rslice2 : Aeneas.Std.Slice.index_usize rslice2 k = .ok pre4 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq rslice2 k + (by show k.val < rslice2.length; rw [h_rslice2_len]; exact h_lt) + have h_imt_rslice2 : Aeneas.Std.Slice.index_mut_usize rslice2 k + = .ok (pre4, rslice2.set k) := by + unfold Aeneas.Std.Slice.index_mut_usize + rw [h_idx_rslice2]; rfl + have h_pre4_eq : pre4 = result2 := h_pre4_def.trans h_rslice2_at + -- (11) index error_1 k → error_1[k]!. + set err_k : AllRowsFillFC.Poly := error_1.val[k.val]! with h_err_k_def + have h_idx_err : Aeneas.Std.Slice.index_usize error_1 k = .ok err_k := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq error_1 k + (by show k.val < error_1.length; rw [h_err_len]; exact h_lt) + -- (12) add_error_reduce step: result_poly. + have h_result2_self_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((result2.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro chunk hchunk ℓ hℓ + have := h_result2_bnd chunk hchunk ℓ hℓ; omega + have h_err_k_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((err_k.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439 := + fun chunk hchunk ℓ hℓ => + h_err_bnd k.val h_lt ⟨chunk, hchunk⟩ ⟨ℓ, hℓ⟩ + obtain ⟨result_poly, h_add_eq, h_result_poly_lift⟩ := + triple_exists_ok_fc + (add_error_reduce_fc result2 err_k h_result2_self_bnd h_err_k_bnd) + -- result_poly slice: rnew := rslice2.set k result_poly. + set rnew : Slice AllRowsFillFC.Poly := rslice2.set k result_poly with h_rnew_def + have h_rnew_at : rnew.val[k.val]! = result_poly := by + rw [h_rnew_def] + simpa [Aeneas.Std.Slice.getElem!_Nat_eq] using + Aeneas.Std.Slice.getElem!_Nat_set_eq rslice2 k k.val result_poly + ⟨rfl, by show k.val < rslice2.length; rw [h_rslice2_len]; exact h_lt⟩ + have h_rnew_ne : ∀ j : Nat, j ≠ k.val → rnew.val[j]! = result.val[j]! := by + intro j hj + have e1 : rnew.val[j]! = rslice2.val[j]! := by + rw [h_rnew_def] + simpa [Aeneas.Std.Slice.getElem!_Nat_eq] using + Aeneas.Std.Slice.getElem!_Nat_set_ne rslice2 k j result_poly (fun h => hj h.symm) + have e2 : rslice2.val[j]! = rslice1.val[j]! := by + rw [h_rslice2_def] + simpa [Aeneas.Std.Slice.getElem!_Nat_eq] using + Aeneas.Std.Slice.getElem!_Nat_set_ne rslice1 k j result2 (fun h => hj h.symm) + have e3 : rslice1.val[j]! = result.val[j]! := by + rw [h_rslice1_def] + simpa [Aeneas.Std.Slice.getElem!_Nat_eq] using + Aeneas.Std.Slice.getElem!_Nat_set_ne result k j result1 (fun h => hj h.symm) + rw [e1, e2, e3] + -- (13) row_spec equation: row_spec lm r_as_ntt error_1 k = .ok (lift_poly result_poly). + have h_invrel : Spec.invert_ntt_montgomery_pure (lift_poly result1) = lift_poly result2 := + h_result2_lift.symm + have h_tailrel : Spec.add_error_reduce_pure (lift_poly result2) (lift_poly err_k) + = lift_poly result_poly := h_result_poly_lift.symm + have h_row_spec : AllRowsFillFC.row_spec lm r_as_ntt error_1 k.val + = .ok (lift_poly result_poly) := by + unfold AllRowsFillFC.row_spec + -- multiply_vectors = .ok (scaleZ 2285 (lift_poly result1)). + have hA : hacspec_ml_kem.matrix.multiply_vectors (lm.val[k.val]!) (lift_vec_slice r_as_ntt K) + = .ok (scaleZ 2285 (lift_poly result1)) := by + rw [hlm_def, h_result1_lift, h_acc_slice_def] + exact h_bridge + rw [hA]; simp only [Aeneas.Std.bind_tc_ok] + -- ntt_inverse (scaleZ 2285 (lift_poly result1)) = .ok (scaleZ 512 (lift_poly result2)). + rw [compute_vector_u_ntt_inverse_eq result1 result2 h_invrel] + simp only [Aeneas.Std.bind_tc_ok] + -- add_polynomials (scaleZ 512 (lift_poly result2)) (lift_poly err_k) = .ok (lift_poly result_poly). + rw [← h_err_k_def] + exact compute_vector_u_add_eq result2 err_k result_poly h_tailrel + -- (14) Body equation: reduce loop1.body to .ok (.cont (..., me1, rnew, scratch1, acc2)). + have h_body : + libcrux_iot_ml_kem.matrix.compute_vector_u_loop1.body + K (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst + i_zero seed r_as_ntt error_1 cache { start := k, «end» := K } matrix_entry result scratch + accumulator + = .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + me1, rnew, scratch1, acc2)) := by + unfold libcrux_iot_ml_kem.matrix.compute_vector_u_loop1.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_pair] at hv_iter_eq + rw [hv_iter_eq] + simp only [Aeneas.Std.bind_tc_ok] + -- Enter the `Some k` branch; inner column loop (acc folded to `acc1`). + show ((do + let (matrix_entry1, accumulator2) ← + libcrux_iot_ml_kem.matrix.compute_vector_u_loop1_loop0 + (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst + { start := 0#usize, «end» := K } matrix_entry seed r_as_ntt cache + (Aeneas.Std.Array.repeat 256#usize i_zero) k + let s ← Aeneas.Std.lift (Aeneas.Std.Array.to_slice accumulator2) + let (pre, index_mut_back) ← Aeneas.Std.Slice.index_mut_usize result k + let pre1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst s pre + let result1 := index_mut_back pre1 + let (pre2, index_mut_back1) ← Aeneas.Std.Slice.index_mut_usize result1 k + let (pre3, scratch1) ← + libcrux_iot_ml_kem.invert_ntt.invert_ntt_montgomery K portable_ops_inst pre2 scratch + let result2 := index_mut_back1 pre3 + let (pre4, index_mut_back2) ← Aeneas.Std.Slice.index_mut_usize result2 k + let pre5 ← Aeneas.Std.Slice.index_usize error_1 k + let pre6 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce + portable_ops_inst pre4 pre5 + let s1 := index_mut_back2 pre6 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + matrix_entry1, s1, scratch1, accumulator2))) + : Result _) = _ + rw [← h_acc1_def, h_me_acc_eq] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let s := Aeneas.Std.Array.to_slice me_acc.2 + let (pre, index_mut_back) ← Aeneas.Std.Slice.index_mut_usize result k + let pre1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst s pre + let result1 := index_mut_back pre1 + let (pre2, index_mut_back1) ← Aeneas.Std.Slice.index_mut_usize result1 k + let (pre3, scratch1) ← + libcrux_iot_ml_kem.invert_ntt.invert_ntt_montgomery K portable_ops_inst pre2 scratch + let result2 := index_mut_back1 pre3 + let (pre4, index_mut_back2) ← Aeneas.Std.Slice.index_mut_usize result2 k + let pre5 ← Aeneas.Std.Slice.index_usize error_1 k + let pre6 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce + portable_ops_inst pre4 pre5 + let s1 := index_mut_back2 pre6 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + me_acc.1, s1, scratch1, me_acc.2))) + : Result _) = _ + rw [← h_acc2_def, ← h_acc_slice_def, h_imt_result] + simp only [Aeneas.Std.bind_tc_ok] + -- Reducing store: `index_mut_back := result.set k`. + show ((do + let pre1 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + portable_ops_inst acc_slice pre + let result1 := (result.set k) pre1 + let (pre2, index_mut_back1) ← Aeneas.Std.Slice.index_mut_usize result1 k + let (pre3, scratch1) ← + libcrux_iot_ml_kem.invert_ntt.invert_ntt_montgomery K portable_ops_inst pre2 scratch + let result2 := index_mut_back1 pre3 + let (pre4, index_mut_back2) ← Aeneas.Std.Slice.index_mut_usize result2 k + let pre5 ← Aeneas.Std.Slice.index_usize error_1 k + let pre6 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce + portable_ops_inst pre4 pre5 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + me_acc.1, index_mut_back2 pre6, scratch1, acc2))) + : Result _) = _ + have h_red_eq : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + (vectortraitsOperationsInst := portable_ops_inst) acc_slice pre = .ok result1 := + h_result1_eq + rw [h_red_eq] + simp only [Aeneas.Std.bind_tc_ok] + rw [show (result.set k) result1 = rslice1 from rfl, h_imt_rslice1] + simp only [Aeneas.Std.bind_tc_ok] + -- Invert store: `index_mut_back1 := rslice1.set k`, applied to result2. + show ((do + let (pre3, scratch1) ← + libcrux_iot_ml_kem.invert_ntt.invert_ntt_montgomery K portable_ops_inst pre2 scratch + let result2 := (rslice1.set k) pre3 + let (pre4, index_mut_back2) ← Aeneas.Std.Slice.index_mut_usize result2 k + let pre5 ← Aeneas.Std.Slice.index_usize error_1 k + let pre6 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce + portable_ops_inst pre4 pre5 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + me_acc.1, index_mut_back2 pre6, scratch1, acc2))) + : Result _) = _ + rw [show pre2 = result1 from h_pre2_eq] + have h_inv_eq' : + libcrux_iot_ml_kem.invert_ntt.invert_ntt_montgomery + K (vectortraitsOperationsInst := portable_ops_inst) result1 scratch + = .ok (result2, scratch1) := h_inv_eq + rw [h_inv_eq'] + simp only [Aeneas.Std.bind_tc_ok] + -- Reduce the `(pre3, scratch1) := (result2, scratch1)` destructuring. + show ((do + let (pre4, index_mut_back2) ← + Aeneas.Std.Slice.index_mut_usize ((rslice1.set k) result2) k + let pre5 ← Aeneas.Std.Slice.index_usize error_1 k + let pre6 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce + portable_ops_inst pre4 pre5 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + me_acc.1, index_mut_back2 pre6, scratch1, acc2))) + : Result _) = _ + rw [show (rslice1.set k) result2 = rslice2 from rfl, h_imt_rslice2] + simp only [Aeneas.Std.bind_tc_ok] + -- Add-error store: `index_mut_back2 := rslice2.set k`, applied to result_poly. + show ((do + let pre5 ← Aeneas.Std.Slice.index_usize error_1 k + let pre6 ← + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce + portable_ops_inst pre4 pre5 + .ok (ControlFlow.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), + me_acc.1, (rslice2.set k) pre6, scratch1, acc2))) + : Result _) = _ + rw [show pre4 = result2 from h_pre4_eq] + rw [show (Aeneas.Std.Slice.index_usize error_1 k) = .ok err_k from h_idx_err] + simp only [Aeneas.Std.bind_tc_ok] + have h_add_eq' : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce + (vectortraitsOperationsInst := portable_ops_inst) result2 err_k = .ok result_poly := + h_add_eq + rw [h_add_eq'] + simp only [Aeneas.Std.bind_tc_ok] + rw [show (rslice2.set k) result_poly = rnew from rfl] + apply triple_of_ok_fc h_body + -- (15) Discharge step_post: rows_inv at s_iter (= k+1). + show AllRowsFillFC.rows_step_post lm r_as_ntt error_1 result_init start k + (.cont (({ start := s_iter, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize), me1, rnew, scratch1, acc2)) + refine ⟨h_lt, rfl, hs_iter_val, ?_⟩ + show (AllRowsFillFC.rows_inv lm r_as_ntt error_1 result_init start + s_iter rnew scratch1 acc2).holds + unfold AllRowsFillFC.rows_inv + show (pure _ : Result Prop).holds + have hs_iter_eq : s_iter.val = k.val + 1 := hs_iter_val + have h_rnew_len : rnew.length = K.val := by + rw [h_rnew_def, Aeneas.Std.Slice.set_length, h_rslice2_def, + Aeneas.Std.Slice.set_length, h_rslice1_def, Aeneas.Std.Slice.set_length] + exact h_result_len + have h_inv_pure : + (∀ r : Nat, start.val ≤ r → r < s_iter.val → + AllRowsFillFC.row_spec lm r_as_ntt error_1 r = .ok (lift_poly (rnew.val[r]!))) + ∧ (∀ r : Nat, r < K.val → (r < start.val ∨ s_iter.val ≤ r) → + rnew.val[r]! = result_init.val[r]!) + ∧ rnew.length = K.val := by + refine ⟨?_, ?_, h_rnew_len⟩ + · -- Completed rows [start, k+1). + intro r hr_ge hr_lt + rw [hs_iter_eq] at hr_lt + rcases Nat.lt_succ_iff_lt_or_eq.mp hr_lt with hr_lt_k | hr_eq_k + · -- r < k: unchanged this iteration; use IH (1). + have hr_ne : r ≠ k.val := by omega + rw [h_rnew_ne r hr_ne] + exact h_inv_done r hr_ge hr_lt_k + · -- r = k: the row written this iteration. + subst hr_eq_k + rw [h_rnew_at] + exact h_row_spec + · -- Unchanged rows. + intro r hr_lt_K hr_cond + rw [hs_iter_eq] at hr_cond + have hr_ne : r ≠ k.val := by omega + rw [h_rnew_ne r hr_ne] + exact h_inv_undone r hr_lt_K (by omega) + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch (k = K): loop ends, done. + have hk_eq : k.val = K.val := le_antisymm h_le (Nat.not_lt.mp h_lt) + have h_iter_none : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize) + ⦃ ⇓ r => ⌜ r = (none, + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝ ⦄ := + libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize k K + (fun hlt => absurd hlt (Nat.not_lt.mpr (Nat.le_of_eq hk_eq.symm))) + (fun _ => by dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v_iter, hv_iter_eq, hv_iter_post⟩ := triple_exists_ok_fc h_iter_none + have h_body : + libcrux_iot_ml_kem.matrix.compute_vector_u_loop1.body + K (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst + i_zero seed r_as_ntt error_1 cache { start := k, «end» := K } matrix_entry result scratch + accumulator + = .ok (ControlFlow.done (matrix_entry, result, scratch, accumulator)) := by + unfold libcrux_iot_ml_kem.matrix.compute_vector_u_loop1.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := K } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [hv_iter_post] at hv_iter_eq + rw [hv_iter_eq] + rfl + apply triple_of_ok_fc h_body + show AllRowsFillFC.rows_step_post lm r_as_ntt error_1 result_init start k + (.done (matrix_entry, result, scratch, accumulator)) + show (AllRowsFillFC.rows_inv lm r_as_ntt error_1 result_init start K result scratch + accumulator).holds + unfold AllRowsFillFC.rows_inv + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ r : Nat, start.val ≤ r → r < K.val → + AllRowsFillFC.row_spec lm r_as_ntt error_1 r = .ok (lift_poly (result.val[r]!))) + ∧ (∀ r : Nat, r < K.val → (r < start.val ∨ K.val ≤ r) → + result.val[r]! = result_init.val[r]!) + ∧ result.length = K.val := by + refine ⟨?_, ?_, h_result_len⟩ + · intro r hr_ge hr_lt + exact h_inv_done r hr_ge (by rw [hk_eq]; exact hr_lt) + · intro r hr_lt_K hr_cond + exact h_inv_undone r hr_lt_K (by + rcases hr_cond with h | h + · exact Or.inl h + · exact Or.inr (by rw [hk_eq]; exact h)) + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 1600000 in +/-- **L7.2 Stage 3 — outer rows loop FC** (`compute_vector_u_loop1`). Mirrors + `compute_As_plus_e_loop1_fc`: the rows loop `[start, K)`, + each row re-zeroing the accumulator, running the use-cache inner column loop, + and finalizing (reducing → invert → add_error). POST is the resolved + `AllRowsFillFC.rows_inv` at `k = K`: every row `r ∈ [start, K)` matches + `row_spec lm r_as_ntt error_1 r = .ok (lift_poly result[r])`, and rows + outside `[start, K)` are unchanged from the input. -/ +theorem compute_vector_u_loop1_fc {K : Std.Usize} {Hasher : Type} + (hash_functionsHashInst : libcrux_iot_ml_kem.hash_functions.Hash Hasher) + (i_zero : Std.I32) + (matrix_entry : AllRowsFillFC.Poly) (seed : Slice Std.U8) + (r_as_ntt error_1 result : Slice AllRowsFillFC.Poly) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (cache : Slice AllRowsFillFC.Poly) + (accumulator : AllRowsFillFC.Acc) + (r_arr : Std.Array AllRowsFillFC.Poly K) + (start : Std.Usize) + (hK : K.val ≤ 4) (h_start : 1 ≤ start.val) (h_start_le : start.val ≤ K.val) + (h_seed_len : seed.length = 32) (h_r_len : r_as_ntt.length = K.val) + (h_cache_len : cache.length = K.val) (h_result_len : result.length = K.val) + (h_err_len : error_1.length = K.val) + (h_i_zero : i_zero.val = 0) + (h_r_arr : ∀ c : Nat, c < K.val → r_arr.val[c]! = r_as_ntt.val[c]!) + (h_r_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((r_as_ntt.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 3328) + (h_err_bnd : ∀ c : Nat, c < K.val → ∀ a : Fin 16, ∀ b : Fin 16, + ((error_1.val[c]!.coefficients.val[a.val]!).elements.val[b.val]!).val.natAbs ≤ 29439) + (h_cache : ∀ c : Nat, c < K.val → + accumulating_ntt_multiply_poly_cache_post (r_as_ntt.val[c]!) (cache.val[c]!)) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_vector_u_loop1 + K (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst + i_zero { start := start, «end» := K } matrix_entry seed r_as_ntt error_1 result scratch cache + accumulator + ⦃ ⇓ p => ⌜ (AllRowsFillFC.rows_inv (lift_matrix_from_seed seed K) r_as_ntt error_1 result start + K p.2.1 p.2.2.1 p.2.2.2).holds ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.matrix.compute_vector_u_loop1 + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, p) => + libcrux_iot_ml_kem.matrix.compute_vector_u_loop1.body + K (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst + i_zero seed r_as_ntt error_1 cache iter1 p.1 p.2.1 p.2.2.1 p.2.2.2) + (β := AllRowsFillFC.Poly × Slice AllRowsFillFC.Poly × AllRowsFillFC.Scratch × AllRowsFillFC.Acc) + (matrix_entry, result, scratch, accumulator) + start K + (fun k p => AllRowsFillFC.rows_inv (lift_matrix_from_seed seed K) r_as_ntt error_1 result start + k p.2.1 p.2.2.1 p.2.2.2) + h_start_le + (by + -- Base case at k = start: rows_inv holds trivially. + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, h_result_len⟩ + · intro r hr_ge hr_lt; omega + · intro r _ _; trivial) + ?_) + · -- Post entailment: at k = K, rows_inv holds. + rw [PostCond.entails_noThrow] + intro p hh + have h_inv_holds : (AllRowsFillFC.rows_inv (lift_matrix_from_seed seed K) r_as_ntt error_1 result + start K p.2.1 p.2.2.1 p.2.2.2).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + exact h_inv_holds + · -- Step entailment. + intro p k h_ge h_le hinv + have h_step := compute_vector_u_loop1_step_lemma_fc + hash_functionsHashInst i_zero seed r_as_ntt error_1 cache r_arr result start hK + h_seed_len h_r_len h_cache_len h_err_len h_i_zero h_r_arr h_r_bnd h_err_bnd h_cache + p.1 p.2.1 p.2.2.1 p.2.2.2 k h_ge h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', rest⟩ | y + · have hP : AllRowsFillFC.rows_step_post (lift_matrix_from_seed seed K) r_as_ntt error_1 + result start k (.cont (iter', rest.1, rest.2.1, rest.2.2.1, rest.2.2.2)) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [AllRowsFillFC.rows_step_post] using hP + · have hP : AllRowsFillFC.rows_step_post (lift_matrix_from_seed seed K) r_as_ntt error_1 + result start k (.done (y.1, y.2.1, y.2.2.1, y.2.2.2)) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [AllRowsFillFC.rows_step_post] using hP + +end libcrux_iot_ml_kem.Matrix.ComputeVectorU.Impl \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Ntt.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Ntt.lean new file mode 100644 index 00000000..e1e3e499 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Ntt.lean @@ -0,0 +1,4478 @@ +/- + # `Ntt.lean` — extracted from `FCTargets.lean` §ntt. +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett + +set_option mvcgen.warning false +set_option linter.unusedVariables false +set_option linter.unusedSectionVars false + +namespace libcrux_iot_ml_kem.Ntt +open libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec + +/-! ## §L3 — NTT driver loops (5 theorems). -/ + +/-! ### L3.0 — Helpers for the layer-N driver loops. + + `ZETAS_bound`, `polynomial.zeta_eq_ok_fc`, `polynomial.zeta_fc` + expose the impl `polynomial.zeta` as a deterministic `.ok` value + + Mont-domain bound. The chunk/flatten identities + (`Spec.chunk_at_lift_poly`, `Spec.flatten_chunks_eq_lift_poly`) + bridge `lift_poly` ↔ `Spec.chunk_at`/`Spec.flatten_chunks`. -/ + +unseal libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R in +theorem ZETAS_bound : + ∀ i : Nat, i < 128 → + ((libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R).val[i]!).val.natAbs + ≤ 1664 := by + intro i hi + interval_cases i <;> decide + +/-- Pure-projection: `polynomial.zeta i` reduces to the array lookup. -/ +theorem polynomial.zeta_eq_ok_fc + (i : Std.Usize) (hi : i.val < 128) : + libcrux_iot_ml_kem.polynomial.zeta i + = .ok ((libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R).val[i.val]!) := by + unfold libcrux_iot_ml_kem.polynomial.zeta + exact libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R i (by + rw [show libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.length = 128 + from Std.Array.length_eq _]; exact hi) + +/-- FC-style Triple for `polynomial.zeta`: returns the exact lookup + value, with a Mont-domain absolute-value bound and the canonical-domain + `lift_fe_mont` lift equal to `Spec.zeta_at i.val`. -/ +@[spec high] +theorem polynomial.zeta_fc (i : Std.Usize) (hi : i.val < 128) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.zeta i + ⦃ ⇓ r => ⌜ r = (libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R).val[i.val]! + ∧ r.val.natAbs ≤ 1664 + ∧ lift_fe_mont r = Spec.zeta_at i.val ⌝ ⦄ := by + apply triple_of_ok_fc (polynomial.zeta_eq_ok_fc i hi) + refine ⟨rfl, ZETAS_bound i.val hi, ?_⟩ + unfold Spec.zeta_at + rfl + + +/-! ### L3.1.A — Loop scaffolding for `ntt_at_layer_1_portable_fc`. + + Strengthened FC invariant for the 16-iter driver loop. Each iteration: + (1) advances `zeta_i` by 4 (4 zeta lookups per chunk: positions + `zeta_i + 4k + {1..4}`), + (2) records the FC equation `lift_chunk acc.2[j] = Spec.chunk_ntt_layer_1_step_pure + (lift_chunk re.coefs[j]) (Spec.zeta_at (zeta_i + 4j + ⋅))` + for `j < k.val` (chunks already processed), + (3) preserves `acc.2.coefficients[j] = re.coefficients[j]` for `j ≥ k.val` + (chunks not yet processed). + + The step lemma chains the body's 9 sub-ops (zeta_i+1, index_mut, 4× zeta, + 3× usize_add, ntt_layer_1_step) using `polynomial.zeta_fc` and + `ntt_layer_1_step_fc` (both `@[spec]`-tagged). -/ + +namespace Layer1FC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Local `usize_add_ok_eq` helper (mirrors `Equivalence/`). -/ +theorem usize_add_ok_eq (x y : Std.Usize) + (h_max : x.val + y.val ≤ Std.Usize.max) : + ∃ z : Std.Usize, (x + y : Result Std.Usize) = .ok z ∧ z.val = x.val + y.val := by + have hT := Std.Usize.add_spec h_max + obtain ⟨z, h_eq, h_v⟩ := Std.WP.spec_imp_exists hT + exact ⟨z, h_eq, h_v⟩ + +/-- Step-local accumulator. -/ +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC loop invariant for `ntt_at_layer_1_portable_fc`. -/ +def inv + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + acc.1.val = zeta_i_0.val + 4 * k.val + ∧ (∀ j : Nat, j < k.val → + lift_chunk (acc.2.coefficients.val[j]!) + = Spec.chunk_ntt_layer_1_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val + 4 * j + 1)) + (Spec.zeta_at (zeta_i_0.val + 4 * j + 2)) + (Spec.zeta_at (zeta_i_0.val + 4 * j + 3)) + (Spec.zeta_at (zeta_i_0.val + 4 * j + 4))) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!)) + +/-- Step-post for `loop_range_spec_usize`. -/ +def step_post + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv zeta_i_0 re iter'.start acc').holds + | .done y => (inv zeta_i_0 re 16#usize y).holds + +end Layer1FC + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma. Given a valid loop state `(acc, k)` with + `k.val < 16`, advances `zeta_i` by 4 and records the FC equation for + chunk `k.val`, leaving chunks `> k.val` unchanged. -/ +theorem ntt_at_layer_1_step_lemma_fc + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((re.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439) + (h_zeta_bnd : zeta_i_0.val + 64 ≤ 127) + (acc : Layer1FC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (Layer1FC.inv zeta_i_0 re k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer1FC.step_post zeta_i_0 re k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.2.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_zeta_acc, h_acc_done, h_acc_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some round = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + -- (1) `zeta_i + 1`. Bound: acc.1.val ≤ zeta_i_0.val + 64-4 = zeta_i_0+60 ≤ 124. + have h_acc1_lt : acc.1.val + 4 ≤ zeta_i_0.val + 64 := by + rw [h_zeta_acc] + have h_k_le : 4 * k.val ≤ 60 := by omega + omega + have h_um : (1#usize : Std.Usize).val = 1 := rfl + have h_um2 : (2#usize : Std.Usize).val = 2 := rfl + have h_um3 : (3#usize : Std.Usize).val = 3 := rfl + have h_z_max : acc.1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := + Layer1FC.usize_add_ok_eq acc.1 1#usize h_z_max + have h_zi1_val_arith : zi1.val = acc.1.val + 1 := by rw [h_zi1_val, h_um] + have h_zi1_lt : zi1.val < 128 := by + rw [h_zi1_val_arith, h_zeta_acc]; omega + -- (2) `index_mut_usize re.coefficients k`. + have h_idx : + Aeneas.Std.Array.index_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.2.coefficients k (by rw [h_coef_len]; exact hk_16) + have h_imt_ok : + Aeneas.Std.Array.index_mut_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!, acc.2.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx]; rfl + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.2.coefficients.val[k.val]! with ht_def + -- (3) `polynomial.zeta zi1`. + obtain ⟨z1, h_z1_eq, h_z1_v, h_z1_bd, h_z1_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zi1 h_zi1_lt) + -- (4) `zi1 + 1`. + have h_zi3_max : zi1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi3, h_zi3_eq, h_zi3_val⟩ := + Layer1FC.usize_add_ok_eq zi1 1#usize h_zi3_max + have h_zi3_val_arith : zi3.val = acc.1.val + 2 := by + rw [h_zi3_val, h_um, h_zi1_val_arith] + have h_zi3_lt : zi3.val < 128 := by + rw [h_zi3_val_arith, h_zeta_acc]; omega + -- (5) `polynomial.zeta zi3`. + obtain ⟨z2, h_z2_eq, h_z2_v, h_z2_bd, h_z2_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zi3 h_zi3_lt) + -- (6) `zi1 + 2`. + have h_zi5_max : zi1.val + (2#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um2]; scalar_tac + obtain ⟨zi5, h_zi5_eq, h_zi5_val⟩ := + Layer1FC.usize_add_ok_eq zi1 2#usize h_zi5_max + have h_zi5_val_arith : zi5.val = acc.1.val + 3 := by + rw [h_zi5_val, h_um2, h_zi1_val_arith] + have h_zi5_lt : zi5.val < 128 := by + rw [h_zi5_val_arith, h_zeta_acc]; omega + -- (7) `polynomial.zeta zi5`. + obtain ⟨z3, h_z3_eq, h_z3_v, h_z3_bd, h_z3_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zi5 h_zi5_lt) + -- (8) `zi1 + 3`. + have h_zi7_max : zi1.val + (3#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um3]; scalar_tac + obtain ⟨zi7, h_zi7_eq, h_zi7_val⟩ := + Layer1FC.usize_add_ok_eq zi1 3#usize h_zi7_max + have h_zi7_val_arith : zi7.val = acc.1.val + 4 := by + rw [h_zi7_val, h_um3, h_zi1_val_arith] + have h_zi7_lt : zi7.val < 128 := by + rw [h_zi7_val_arith, h_zeta_acc]; omega + -- (9) `polynomial.zeta zi7`. + obtain ⟨z4, h_z4_eq, h_z4_v, h_z4_bd, h_z4_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zi7 h_zi7_lt) + -- (10) `ntt_layer_1_step t z1 z2 z3 z4`. Pre: t's lanes ≤ 29439 (via h_pre + undone). + have h_t_eq : t = re.coefficients.val[k.val]! := by + show acc.2.coefficients.val[k.val]! = re.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 29439 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_pre k.val hk_16 ℓ hℓ + -- @[reducible] portable_ops_inst forwards to vector.portable.ntt.ntt_layer_1_step. + -- ntt_layer_1_step_fc consumes (vec, z0..z3, hz, hvec). + obtain ⟨t1, h_t1_eq, h_t1_lift⟩ := + triple_exists_ok_fc (ntt_layer_1_step_fc t z1 z2 z3 z4 + ⟨h_z1_bd, h_z2_bd, h_z3_bd, h_z4_bd⟩ h_t_bd) + -- Compose entire body. + set acc' : Layer1FC.Acc := (zi7, { coefficients := acc.2.coefficients.set k t1 }) + with hacc'_def + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp [Aeneas.Std.bind_tc_ok, h_zi1_eq, h_imt_ok, h_z1_eq, h_zi3_eq, + h_z2_eq, h_zi5_eq, h_z3_eq, h_zi7_eq, h_z4_eq] + show (do + let t1' ← + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_1_step t z1 z2 z3 z4 + Result.ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + zi7, + ({ coefficients := acc.2.coefficients.set k t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_t1_eq]; rfl + apply triple_of_ok_fc h_body + show Layer1FC.step_post zeta_i_0 re k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold Layer1FC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + -- Invariant at (s, acc'). + show (Layer1FC.inv zeta_i_0 re s acc').holds + have h_inv_pure : + acc'.1.val = zeta_i_0.val + 4 * s.val + ∧ (∀ j : Nat, j < s.val → + lift_chunk (acc'.2.coefficients.val[j]!) + = Spec.chunk_ntt_layer_1_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val + 4 * j + 1)) + (Spec.zeta_at (zeta_i_0.val + 4 * j + 2)) + (Spec.zeta_at (zeta_i_0.val + 4 * j + 3)) + (Spec.zeta_at (zeta_i_0.val + 4 * j + 4))) + ∧ (∀ j : Nat, s.val ≤ j → j < 16 → + acc'.2.coefficients.val[j]! = re.coefficients.val[j]!) := by + refine ⟨?_, ?_, ?_⟩ + · -- acc'.1 = zi7, zi7.val = acc.1.val + 4 = zeta_i_0.val + 4 * (k.val + 1). + show zi7.val = zeta_i_0.val + 4 * s.val + rw [h_zi7_val_arith, h_zeta_acc, hs_val]; ring + · -- All j < s.val are FC-equal. + intro j hj + rw [hs_val] at hj + -- acc'.2.coefficients = acc.2.coefficients.set k t1. + show lift_chunk ((acc.2.coefficients.set k t1).val[j]!) = _ + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: unchanged by set; use h_acc_done. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + rw [h_set_ne_val] + exact h_acc_done j hj_lt_k + · -- j = k.val: it's t1; use h_t1_lift + h_t_eq + zeta_lift identities. + subst hj_eq_k + have h_set_eq_val : + (acc.2.coefficients.set k t1).val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + rw [h_set_eq_val, h_t1_lift, h_t_eq] + -- Need: Spec.chunk_ntt_layer_1_step_pure (lift_chunk re.coefficients[k]) (lift_fe_mont z1..z4) + -- = Spec.chunk_ntt_layer_1_step_pure (lift_chunk re.coefficients[k]) + -- (Spec.zeta_at (zeta_i_0 + 4*k + 1..4)). + -- Use h_z1_lift..h_z4_lift to rewrite lift_fe_mont zi → Spec.zeta_at zi.val. + have h_zi1_z : zi1.val = zeta_i_0.val + 4 * k.val + 1 := by + rw [h_zi1_val_arith, h_zeta_acc] + have h_zi3_z : zi3.val = zeta_i_0.val + 4 * k.val + 2 := by + rw [h_zi3_val_arith, h_zeta_acc] + have h_zi5_z : zi5.val = zeta_i_0.val + 4 * k.val + 3 := by + rw [h_zi5_val_arith, h_zeta_acc] + have h_zi7_z : zi7.val = zeta_i_0.val + 4 * k.val + 4 := by + rw [h_zi7_val_arith, h_zeta_acc] + rw [show lift_fe_mont z1 = Spec.zeta_at (zeta_i_0.val + 4 * k.val + 1) + from by rw [← h_zi1_z]; exact h_z1_lift] + rw [show lift_fe_mont z2 = Spec.zeta_at (zeta_i_0.val + 4 * k.val + 2) + from by rw [← h_zi3_z]; exact h_z2_lift] + rw [show lift_fe_mont z3 = Spec.zeta_at (zeta_i_0.val + 4 * k.val + 3) + from by rw [← h_zi5_z]; exact h_z3_lift] + rw [show lift_fe_mont z4 = Spec.zeta_at (zeta_i_0.val + 4 * k.val + 4) + from by rw [← h_zi7_z]; exact h_z4_lift] + · -- All j ≥ s.val are unchanged. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + show (acc.2.coefficients.set k t1).val[j]! = re.coefficients.val[j]! + rw [h_set_ne_val] + exact h_acc_undone j h_ge' hj_lt + -- inv .. = pure (P) with .holds reducing to P. + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (ControlFlow.done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_fc h_body + show Layer1FC.step_post zeta_i_0 re k (.done acc) + unfold Layer1FC.step_post + show (Layer1FC.inv zeta_i_0 re 16#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + acc.1.val = zeta_i_0.val + 4 * (16#usize : Std.Usize).val + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (acc.2.coefficients.val[j]!) + = Spec.chunk_ntt_layer_1_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val + 4 * j + 1)) + (Spec.zeta_at (zeta_i_0.val + 4 * j + 2)) + (Spec.zeta_at (zeta_i_0.val + 4 * j + 3)) + (Spec.zeta_at (zeta_i_0.val + 4 * j + 4))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) := by + refine ⟨?_, ?_, ?_⟩ + · rw [h_zeta_acc, hk_eq, h16] + · intro j hj; rw [h16] at hj + apply h_acc_done j; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +/-- L3.1' — `ntt_at_layer_1` PortableVector-specialised FC equation. + The impl returns `(zeta_i_after, re_after)`; we project on `re_after`. + + **Preconditions** (load-bearing, beyond the locked True-pre form): + - `h_bnd` : per-lane input bound 29439 across all 16 chunks × 16 lanes. + - `h_zeta : zeta_i.val + 64 ≤ 127` — strengthened from original `≤ 128` + to ensure all zeta indices `zeta_i+1 .. zeta_i+64` are < 128 (OOB + check). Original `≤ 128` permitted `zeta_i.val = 64`, OOB on last iter. -/ +@[spec high] +theorem ntt_at_layer_1_portable_fc + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (initial_bound : Std.Usize) + (h_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 29439) + (h_zeta : zeta_i.val + 64 ≤ 127) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_1 + (vectortraitsOperationsInst := portable_ops_inst) + zeta_i re initial_bound + ⦃ ⇓ p => ⌜ lift_poly p.2 = Spec.ntt_layer_1_pure (lift_poly re) zeta_i ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1 + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + iter1 acc1.1 acc1.2) + (β := Layer1FC.Acc) + (zeta_i, re) + 0#usize 16#usize + (Layer1FC.inv zeta_i re) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_⟩ + · -- zeta-thread invariant at k=0. + show zeta_i.val = zeta_i.val + 4 * (0#usize : Std.Usize).val + show zeta_i.val = zeta_i.val + 4 * 0 + omega + · -- No chunks done yet. + intro j hj + exact absurd hj (Nat.not_lt_zero j) + · -- All chunks unchanged; goal collapses to True after simp. + intro _ _ _ + trivial) + ?_) + · -- Post entailment: at k=16, the invariant gives all 16 FC equations. + rw [PostCond.entails_noThrow] + intro r hh + -- Manually extract the inv payload (avoid `simp` aggression on `[!]`). + have h_inv_holds : (Layer1FC.inv zeta_i re 16#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + r.1.val = zeta_i.val + 4 * (16#usize : Std.Usize).val + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (r.2.coefficients.val[j]!) + = Spec.chunk_ntt_layer_1_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i.val + 4 * j + 1)) + (Spec.zeta_at (zeta_i.val + 4 * j + 2)) + (Spec.zeta_at (zeta_i.val + 4 * j + 3)) + (Spec.zeta_at (zeta_i.val + 4 * j + 4))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + r.2.coefficients.val[j]! = re.coefficients.val[j]!) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + Layer1FC.inv] using h_inv_holds + obtain ⟨_h_zeta_eq, h_done, _h_undone⟩ := h_inv + have h16 : (16#usize : Std.Usize).val = 16 := rfl + -- `Spec.ntt_layer_1_pure (lift_poly re) zeta_i` unfolds to + -- `flatten_chunks (Array.make 16 ((List.range 16).map (fun k => + -- chunk_ntt_layer_1_step_pure (chunk_at (lift_poly re) k) (zeta_at ...))))`. + -- Show that the chunks array equals the `lift_chunk r.2.coefficients[k]` family + -- via `h_done` + `chunk_at_lift_poly_fc`, then `flatten_chunks_eq_lift_poly_fc`. + unfold Spec.ntt_layer_1_pure + set chunks_arr : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_ntt_layer_1_step_pure (Spec.chunk_at (lift_poly re) k) + (Spec.zeta_at (zeta_i.val + 4 * k + 1)) + (Spec.zeta_at (zeta_i.val + 4 * k + 2)) + (Spec.zeta_at (zeta_i.val + 4 * k + 3)) + (Spec.zeta_at (zeta_i.val + 4 * k + 4)))) + (by simp) with hchunks_def + have h_chunks_len : chunks_arr.val.length = 16 := by + show ((List.range 16).map _).length = 16 + simp + have h_chunks_get : ∀ k : Nat, (hk : k < 16) → + chunks_arr.val[k]'(by rw [h_chunks_len]; exact hk) + = lift_chunk (r.2.coefficients.val[k]!) := by + intro k hk + show ((List.range 16).map (fun k => + Spec.chunk_ntt_layer_1_step_pure (Spec.chunk_at (lift_poly re) k) + (Spec.zeta_at (zeta_i.val + 4 * k + 1)) + (Spec.zeta_at (zeta_i.val + 4 * k + 2)) + (Spec.zeta_at (zeta_i.val + 4 * k + 3)) + (Spec.zeta_at (zeta_i.val + 4 * k + 4))))[k]'_ = _ + rw [List.getElem_map, List.getElem_range] + rw [chunk_at_lift_poly_fc re k hk] + exact (h_done k hk).symm + -- Apply flatten_chunks_eq_lift_poly_fc (with `r.2` as the poly). + have h_final := flatten_chunks_eq_lift_poly_fc r.2 chunks_arr h_chunks_len h_chunks_get + exact h_final.symm + · -- Step lemma application: dispatch ntt_at_layer_1_step_lemma_fc. + intro acc k _h_ge h_le hinv + have h_step := ntt_at_layer_1_step_lemma_fc zeta_i re h_bnd h_zeta acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer1FC.step_post zeta_i re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer1FC.step_post] using hP + · have hP : Layer1FC.step_post zeta_i re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer1FC.step_post] using hP + +/-! ### L3.2.A — Loop scaffolding for `ntt_at_layer_2_portable_fc`. + + Strengthened FC invariant for the 16-iter driver loop. Each iteration: + (1) advances `zeta_i` by 2 (2 zeta lookups per chunk: positions + `zeta_i + 2k + {1, 2}`), + (2) records the FC equation `lift_chunk acc.2[j] = Spec.chunk_ntt_layer_2_step_pure + (lift_chunk re.coefs[j]) (Spec.zeta_at (zeta_i + 2j + ⋅))` + for `j < k.val` (chunks already processed), + (3) preserves `acc.2.coefficients[j] = re.coefficients[j]` for `j ≥ k.val` + (chunks not yet processed). + + The step lemma chains the body's 6 sub-ops (zeta_i+1, index_mut, 2× zeta, + 1× usize_add, ntt_layer_2_step) using `polynomial.zeta_fc` and + `ntt_layer_2_step_fc` (both `@[spec]`-tagged). -/ + +namespace Layer2FC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Local `usize_add_ok_eq` helper (mirrors `Layer1FC.usize_add_ok_eq`). -/ +theorem usize_add_ok_eq (x y : Std.Usize) + (h_max : x.val + y.val ≤ Std.Usize.max) : + ∃ z : Std.Usize, (x + y : Result Std.Usize) = .ok z ∧ z.val = x.val + y.val := by + have hT := Std.Usize.add_spec h_max + obtain ⟨z, h_eq, h_v⟩ := Std.WP.spec_imp_exists hT + exact ⟨z, h_eq, h_v⟩ + +/-- Step-local accumulator. -/ +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC loop invariant for `ntt_at_layer_2_portable_fc`. -/ +def inv + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + acc.1.val = zeta_i_0.val + 2 * k.val + ∧ (∀ j : Nat, j < k.val → + lift_chunk (acc.2.coefficients.val[j]!) + = Spec.chunk_ntt_layer_2_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val + 2 * j + 1)) + (Spec.zeta_at (zeta_i_0.val + 2 * j + 2))) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!)) + +/-- Step-post for `loop_range_spec_usize`. -/ +def step_post + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv zeta_i_0 re iter'.start acc').holds + | .done y => (inv zeta_i_0 re 16#usize y).holds + +end Layer2FC + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for layer 2. Given a valid loop state + `(acc, k)` with `k.val < 16`, advances `zeta_i` by 2 and records the + FC equation for chunk `k.val`, leaving chunks `> k.val` unchanged. -/ +theorem ntt_at_layer_2_step_lemma_fc + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((re.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439) + (h_zeta_bnd : zeta_i_0.val + 32 ≤ 127) + (acc : Layer2FC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (Layer2FC.inv zeta_i_0 re k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer2FC.step_post zeta_i_0 re k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.2.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_zeta_acc, h_acc_done, h_acc_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some round = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + -- (1) `zeta_i + 1`. Bound: acc.1.val ≤ zeta_i_0.val + 32-2 = zeta_i_0+30 ≤ 125. + have h_acc1_lt : acc.1.val + 2 ≤ zeta_i_0.val + 32 := by + rw [h_zeta_acc] + have h_k_le : 2 * k.val ≤ 30 := by omega + omega + have h_um : (1#usize : Std.Usize).val = 1 := rfl + have h_z_max : acc.1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := + Layer2FC.usize_add_ok_eq acc.1 1#usize h_z_max + have h_zi1_val_arith : zi1.val = acc.1.val + 1 := by rw [h_zi1_val, h_um] + have h_zi1_lt : zi1.val < 128 := by + rw [h_zi1_val_arith, h_zeta_acc]; omega + -- (2) `index_mut_usize re.coefficients k`. + have h_idx : + Aeneas.Std.Array.index_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.2.coefficients k (by rw [h_coef_len]; exact hk_16) + have h_imt_ok : + Aeneas.Std.Array.index_mut_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!, acc.2.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx]; rfl + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.2.coefficients.val[k.val]! with ht_def + -- (3) `polynomial.zeta zi1`. + obtain ⟨z1, h_z1_eq, h_z1_v, h_z1_bd, h_z1_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zi1 h_zi1_lt) + -- (4) `zi1 + 1`. + have h_zi3_max : zi1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi3, h_zi3_eq, h_zi3_val⟩ := + Layer2FC.usize_add_ok_eq zi1 1#usize h_zi3_max + have h_zi3_val_arith : zi3.val = acc.1.val + 2 := by + rw [h_zi3_val, h_um, h_zi1_val_arith] + have h_zi3_lt : zi3.val < 128 := by + rw [h_zi3_val_arith, h_zeta_acc]; omega + -- (5) `polynomial.zeta zi3`. + obtain ⟨z2, h_z2_eq, h_z2_v, h_z2_bd, h_z2_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zi3 h_zi3_lt) + -- (6) `ntt_layer_2_step t z1 z2`. Pre: t's lanes ≤ 29439 (via h_pre + undone). + have h_t_eq : t = re.coefficients.val[k.val]! := by + show acc.2.coefficients.val[k.val]! = re.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 29439 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_pre k.val hk_16 ℓ hℓ + -- @[reducible] portable_ops_inst forwards to vector.portable.ntt.ntt_layer_2_step. + -- ntt_layer_2_step_fc consumes (vec, z0, z1, hz, hvec). + obtain ⟨t1, h_t1_eq, h_t1_lift⟩ := + triple_exists_ok_fc (ntt_layer_2_step_fc t z1 z2 + ⟨h_z1_bd, h_z2_bd⟩ h_t_bd) + -- Compose entire body. + set acc' : Layer2FC.Acc := (zi3, { coefficients := acc.2.coefficients.set k t1 }) + with hacc'_def + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp [Aeneas.Std.bind_tc_ok, h_zi1_eq, h_imt_ok, h_z1_eq, h_zi3_eq, h_z2_eq] + show (do + let t1' ← + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_2_step t z1 z2 + Result.ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + zi3, + ({ coefficients := acc.2.coefficients.set k t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_t1_eq]; rfl + apply triple_of_ok_fc h_body + show Layer2FC.step_post zeta_i_0 re k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold Layer2FC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + -- Invariant at (s, acc'). + show (Layer2FC.inv zeta_i_0 re s acc').holds + have h_inv_pure : + acc'.1.val = zeta_i_0.val + 2 * s.val + ∧ (∀ j : Nat, j < s.val → + lift_chunk (acc'.2.coefficients.val[j]!) + = Spec.chunk_ntt_layer_2_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val + 2 * j + 1)) + (Spec.zeta_at (zeta_i_0.val + 2 * j + 2))) + ∧ (∀ j : Nat, s.val ≤ j → j < 16 → + acc'.2.coefficients.val[j]! = re.coefficients.val[j]!) := by + refine ⟨?_, ?_, ?_⟩ + · -- acc'.1 = zi3, zi3.val = acc.1.val + 2 = zeta_i_0.val + 2 * (k.val + 1). + show zi3.val = zeta_i_0.val + 2 * s.val + rw [h_zi3_val_arith, h_zeta_acc, hs_val]; ring + · -- All j < s.val are FC-equal. + intro j hj + rw [hs_val] at hj + -- acc'.2.coefficients = acc.2.coefficients.set k t1. + show lift_chunk ((acc.2.coefficients.set k t1).val[j]!) = _ + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: unchanged by set; use h_acc_done. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + rw [h_set_ne_val] + exact h_acc_done j hj_lt_k + · -- j = k.val: it's t1; use h_t1_lift + h_t_eq + zeta_lift identities. + subst hj_eq_k + have h_set_eq_val : + (acc.2.coefficients.set k t1).val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + rw [h_set_eq_val, h_t1_lift, h_t_eq] + have h_zi1_z : zi1.val = zeta_i_0.val + 2 * k.val + 1 := by + rw [h_zi1_val_arith, h_zeta_acc] + have h_zi3_z : zi3.val = zeta_i_0.val + 2 * k.val + 2 := by + rw [h_zi3_val_arith, h_zeta_acc] + rw [show lift_fe_mont z1 = Spec.zeta_at (zeta_i_0.val + 2 * k.val + 1) + from by rw [← h_zi1_z]; exact h_z1_lift] + rw [show lift_fe_mont z2 = Spec.zeta_at (zeta_i_0.val + 2 * k.val + 2) + from by rw [← h_zi3_z]; exact h_z2_lift] + · -- All j ≥ s.val are unchanged. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + show (acc.2.coefficients.set k t1).val[j]! = re.coefficients.val[j]! + rw [h_set_ne_val] + exact h_acc_undone j h_ge' hj_lt + -- inv .. = pure (P) with .holds reducing to P. + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (ControlFlow.done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_fc h_body + show Layer2FC.step_post zeta_i_0 re k (.done acc) + unfold Layer2FC.step_post + show (Layer2FC.inv zeta_i_0 re 16#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + acc.1.val = zeta_i_0.val + 2 * (16#usize : Std.Usize).val + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (acc.2.coefficients.val[j]!) + = Spec.chunk_ntt_layer_2_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val + 2 * j + 1)) + (Spec.zeta_at (zeta_i_0.val + 2 * j + 2))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) := by + refine ⟨?_, ?_, ?_⟩ + · rw [h_zeta_acc, hk_eq, h16] + · intro j hj; rw [h16] at hj + apply h_acc_done j; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +/-- L3.2 — `ntt_at_layer_2` PortableVector-specialised FC equation. + The impl returns `(zeta_i_after, re_after)`; we project on `re_after`. + + **Preconditions** (load-bearing, beyond the locked True-pre form): + - `h_bnd` : per-lane input bound 29439 across all 16 chunks × 16 lanes. + - `h_zeta : zeta_i.val + 32 ≤ 127` — strengthened from original `≤ 128` + to ensure all zeta indices `zeta_i+1 .. zeta_i+32` are < 128 (OOB + check). Original `≤ 128` permitted `zeta_i.val = 96`, OOB on last + iter (index 128 = ZETAS table length). -/ +@[spec high] +theorem ntt_at_layer_2_portable_fc + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (initial_bound : Std.Usize) + (h_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 29439) + (h_zeta : zeta_i.val + 32 ≤ 127) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_2 + (vectortraitsOperationsInst := portable_ops_inst) zeta_i re initial_bound + ⦃ ⇓ p => ⌜ lift_poly p.2 = Spec.ntt_layer_2_pure (lift_poly re) zeta_i ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2 + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + iter1 acc1.1 acc1.2) + (β := Layer2FC.Acc) + (zeta_i, re) + 0#usize 16#usize + (Layer2FC.inv zeta_i re) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_⟩ + · -- zeta-thread invariant at k=0. + show zeta_i.val = zeta_i.val + 2 * (0#usize : Std.Usize).val + show zeta_i.val = zeta_i.val + 2 * 0 + omega + · -- No chunks done yet. + intro j hj + exact absurd hj (Nat.not_lt_zero j) + · -- All chunks unchanged; goal collapses to True after simp. + intro _ _ _ + trivial) + ?_) + · -- Post entailment: at k=16, the invariant gives all 16 FC equations. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (Layer2FC.inv zeta_i re 16#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + r.1.val = zeta_i.val + 2 * (16#usize : Std.Usize).val + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (r.2.coefficients.val[j]!) + = Spec.chunk_ntt_layer_2_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i.val + 2 * j + 1)) + (Spec.zeta_at (zeta_i.val + 2 * j + 2))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + r.2.coefficients.val[j]! = re.coefficients.val[j]!) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + Layer2FC.inv] using h_inv_holds + obtain ⟨_h_zeta_eq, h_done, _h_undone⟩ := h_inv + have h16 : (16#usize : Std.Usize).val = 16 := rfl + unfold Spec.ntt_layer_2_pure + set chunks_arr : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_ntt_layer_2_step_pure (Spec.chunk_at (lift_poly re) k) + (Spec.zeta_at (zeta_i.val + 2 * k + 1)) + (Spec.zeta_at (zeta_i.val + 2 * k + 2)))) + (by simp) with hchunks_def + have h_chunks_len : chunks_arr.val.length = 16 := by + show ((List.range 16).map _).length = 16 + simp + have h_chunks_get : ∀ k : Nat, (hk : k < 16) → + chunks_arr.val[k]'(by rw [h_chunks_len]; exact hk) + = lift_chunk (r.2.coefficients.val[k]!) := by + intro k hk + show ((List.range 16).map (fun k => + Spec.chunk_ntt_layer_2_step_pure (Spec.chunk_at (lift_poly re) k) + (Spec.zeta_at (zeta_i.val + 2 * k + 1)) + (Spec.zeta_at (zeta_i.val + 2 * k + 2))))[k]'_ = _ + rw [List.getElem_map, List.getElem_range] + rw [chunk_at_lift_poly_fc re k hk] + exact (h_done k hk).symm + -- Apply flatten_chunks_eq_lift_poly_fc (with `r.2` as the poly). + have h_final := flatten_chunks_eq_lift_poly_fc r.2 chunks_arr h_chunks_len h_chunks_get + exact h_final.symm + · -- Step lemma application: dispatch ntt_at_layer_2_step_lemma_fc. + intro acc k _h_ge h_le hinv + have h_step := ntt_at_layer_2_step_lemma_fc zeta_i re h_bnd h_zeta acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer2FC.step_post zeta_i re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer2FC.step_post] using hP + · have hP : Layer2FC.step_post zeta_i re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer2FC.step_post] using hP + +/-! ### L3.3'.A — Loop scaffolding for `ntt_at_layer_3_portable_fc`. + + Strengthened FC invariant for the 16-iter driver loop. Each iteration: + (1) advances `zeta_i` by 1 (1 zeta lookup per chunk: position + `zeta_i + k + 1`), + (2) records the FC equation `lift_chunk acc.2[j] = + Spec.chunk_ntt_layer_3_step_pure (lift_chunk re.coefs[j]) + (Spec.zeta_at (zeta_i + j + 1))` for `j < k.val`, + (3) preserves `acc.2.coefficients[j] = re.coefficients[j]` for `j ≥ k.val`. + + The step lemma chains the body's 4 sub-ops (zeta_i+1, index_mut, 1× zeta, + ntt_layer_3_step) using `polynomial.zeta_fc` and `ntt_layer_3_step_fc`. -/ + +namespace Layer3FC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Local `usize_add_ok_eq` helper (mirrors `Layer2FC.usize_add_ok_eq`). -/ +theorem usize_add_ok_eq (x y : Std.Usize) + (h_max : x.val + y.val ≤ Std.Usize.max) : + ∃ z : Std.Usize, (x + y : Result Std.Usize) = .ok z ∧ z.val = x.val + y.val := by + have hT := Std.Usize.add_spec h_max + obtain ⟨z, h_eq, h_v⟩ := Std.WP.spec_imp_exists hT + exact ⟨z, h_eq, h_v⟩ + +/-- Step-local accumulator. -/ +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC loop invariant for `ntt_at_layer_3_portable_fc`. -/ +def inv + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + acc.1.val = zeta_i_0.val + k.val + ∧ (∀ j : Nat, j < k.val → + lift_chunk (acc.2.coefficients.val[j]!) + = Spec.chunk_ntt_layer_3_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val + j + 1))) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!)) + +/-- Step-post for `loop_range_spec_usize`. -/ +def step_post + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv zeta_i_0 re iter'.start acc').holds + | .done y => (inv zeta_i_0 re 16#usize y).holds + +end Layer3FC + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for layer 3. Given a valid loop state + `(acc, k)` with `k.val < 16`, advances `zeta_i` by 1 and records the + FC equation for chunk `k.val`, leaving chunks `> k.val` unchanged. -/ +theorem ntt_at_layer_3_step_lemma_fc + (zeta_i_0 : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((re.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439) + (h_zeta_bnd : zeta_i_0.val + 16 ≤ 127) + (acc : Layer3FC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (Layer3FC.inv zeta_i_0 re k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer3FC.step_post zeta_i_0 re k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.2.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_zeta_acc, h_acc_done, h_acc_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some round = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + -- (1) `zeta_i + 1`. Bound: acc.1.val ≤ zeta_i_0.val + 16-1 = zeta_i_0+15 ≤ 126. + have h_acc1_lt : acc.1.val + 1 ≤ zeta_i_0.val + 16 := by + rw [h_zeta_acc] + have h_k_le : k.val ≤ 15 := by omega + omega + have h_um : (1#usize : Std.Usize).val = 1 := rfl + have h_z_max : acc.1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := + Layer3FC.usize_add_ok_eq acc.1 1#usize h_z_max + have h_zi1_val_arith : zi1.val = acc.1.val + 1 := by rw [h_zi1_val, h_um] + have h_zi1_lt : zi1.val < 128 := by + rw [h_zi1_val_arith, h_zeta_acc]; omega + -- (2) `index_mut_usize re.coefficients k`. + have h_idx : + Aeneas.Std.Array.index_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.2.coefficients k (by rw [h_coef_len]; exact hk_16) + have h_imt_ok : + Aeneas.Std.Array.index_mut_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!, acc.2.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx]; rfl + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.2.coefficients.val[k.val]! with ht_def + -- (3) `polynomial.zeta zi1`. + obtain ⟨z1, h_z1_eq, h_z1_v, h_z1_bd, h_z1_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zi1 h_zi1_lt) + -- (4) `ntt_layer_3_step t z1`. Pre: t's lanes ≤ 29439 (via h_pre + undone). + have h_t_eq : t = re.coefficients.val[k.val]! := by + show acc.2.coefficients.val[k.val]! = re.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 29439 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_pre k.val hk_16 ℓ hℓ + -- @[reducible] portable_ops_inst forwards to vector.portable.ntt.ntt_layer_3_step. + -- ntt_layer_3_step_fc consumes (vec, z, hz, hvec). + obtain ⟨t1, h_t1_eq, h_t1_lift⟩ := + triple_exists_ok_fc (ntt_layer_3_step_fc t z1 h_z1_bd h_t_bd) + -- Compose entire body. + set acc' : Layer3FC.Acc := (zi1, { coefficients := acc.2.coefficients.set k t1 }) + with hacc'_def + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp [Aeneas.Std.bind_tc_ok, h_zi1_eq, h_imt_ok, h_z1_eq] + show (do + let t1' ← + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_3_step t z1 + Result.ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + zi1, + ({ coefficients := acc.2.coefficients.set k t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_t1_eq]; rfl + apply triple_of_ok_fc h_body + show Layer3FC.step_post zeta_i_0 re k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold Layer3FC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + -- Invariant at (s, acc'). + show (Layer3FC.inv zeta_i_0 re s acc').holds + have h_inv_pure : + acc'.1.val = zeta_i_0.val + s.val + ∧ (∀ j : Nat, j < s.val → + lift_chunk (acc'.2.coefficients.val[j]!) + = Spec.chunk_ntt_layer_3_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val + j + 1))) + ∧ (∀ j : Nat, s.val ≤ j → j < 16 → + acc'.2.coefficients.val[j]! = re.coefficients.val[j]!) := by + refine ⟨?_, ?_, ?_⟩ + · -- acc'.1 = zi1, zi1.val = acc.1.val + 1 = zeta_i_0.val + (k.val + 1). + show zi1.val = zeta_i_0.val + s.val + rw [h_zi1_val_arith, h_zeta_acc, hs_val]; ring + · -- All j < s.val are FC-equal. + intro j hj + rw [hs_val] at hj + -- acc'.2.coefficients = acc.2.coefficients.set k t1. + show lift_chunk ((acc.2.coefficients.set k t1).val[j]!) = _ + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: unchanged by set; use h_acc_done. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + rw [h_set_ne_val] + exact h_acc_done j hj_lt_k + · -- j = k.val: it's t1; use h_t1_lift + h_t_eq + zeta_lift identities. + subst hj_eq_k + have h_set_eq_val : + (acc.2.coefficients.set k t1).val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + rw [h_set_eq_val, h_t1_lift, h_t_eq] + have h_zi1_z : zi1.val = zeta_i_0.val + k.val + 1 := by + rw [h_zi1_val_arith, h_zeta_acc] + rw [show lift_fe_mont z1 = Spec.zeta_at (zeta_i_0.val + k.val + 1) + from by rw [← h_zi1_z]; exact h_z1_lift] + · -- All j ≥ s.val are unchanged. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + show (acc.2.coefficients.set k t1).val[j]! = re.coefficients.val[j]! + rw [h_set_ne_val] + exact h_acc_undone j h_ge' hj_lt + -- inv .. = pure (P) with .holds reducing to P. + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (ControlFlow.done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_fc h_body + show Layer3FC.step_post zeta_i_0 re k (.done acc) + unfold Layer3FC.step_post + show (Layer3FC.inv zeta_i_0 re 16#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + acc.1.val = zeta_i_0.val + (16#usize : Std.Usize).val + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (acc.2.coefficients.val[j]!) + = Spec.chunk_ntt_layer_3_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i_0.val + j + 1))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) := by + refine ⟨?_, ?_, ?_⟩ + · rw [h_zeta_acc, hk_eq, h16] + · intro j hj; rw [h16] at hj + apply h_acc_done j; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +/-- L3.3' — `ntt_at_layer_3` PortableVector-specialised FC equation. + The impl returns `(zeta_i_after, re_after)`; we project on `re_after`. + + **Preconditions** (load-bearing, beyond the locked True-pre form): + - `h_bnd` : per-lane input bound 29439 across all 16 chunks × 16 lanes. + - `h_zeta : zeta_i.val + 16 ≤ 127` — ensures all zeta indices + `zeta_i+1 .. zeta_i+16` are < 128 (OOB check on ZETAS table). -/ +@[spec high] +theorem ntt_at_layer_3_portable_fc + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (initial_bound : Std.Usize) + (h_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 29439) + (h_zeta : zeta_i.val + 16 ≤ 127) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_3 + (vectortraitsOperationsInst := portable_ops_inst) zeta_i re initial_bound + ⦃ ⇓ p => ⌜ lift_poly p.2 = Spec.ntt_layer_3_pure (lift_poly re) zeta_i ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3 + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + iter1 acc1.1 acc1.2) + (β := Layer3FC.Acc) + (zeta_i, re) + 0#usize 16#usize + (Layer3FC.inv zeta_i re) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_⟩ + · -- zeta-thread invariant at k=0. + show zeta_i.val = zeta_i.val + (0#usize : Std.Usize).val + show zeta_i.val = zeta_i.val + 0 + omega + · -- No chunks done yet. + intro j hj + exact absurd hj (Nat.not_lt_zero j) + · -- All chunks unchanged; goal collapses to True after simp. + intro _ _ _ + trivial) + ?_) + · -- Post entailment: at k=16, the invariant gives all 16 FC equations. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (Layer3FC.inv zeta_i re 16#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + r.1.val = zeta_i.val + (16#usize : Std.Usize).val + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (r.2.coefficients.val[j]!) + = Spec.chunk_ntt_layer_3_step_pure + (lift_chunk (re.coefficients.val[j]!)) + (Spec.zeta_at (zeta_i.val + j + 1))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + r.2.coefficients.val[j]! = re.coefficients.val[j]!) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + Layer3FC.inv] using h_inv_holds + obtain ⟨_h_zeta_eq, h_done, _h_undone⟩ := h_inv + have h16 : (16#usize : Std.Usize).val = 16 := rfl + unfold Spec.ntt_layer_3_pure + set chunks_arr : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_ntt_layer_3_step_pure (Spec.chunk_at (lift_poly re) k) + (Spec.zeta_at (zeta_i.val + k + 1)))) + (by simp) with hchunks_def + have h_chunks_len : chunks_arr.val.length = 16 := by + show ((List.range 16).map _).length = 16 + simp + have h_chunks_get : ∀ k : Nat, (hk : k < 16) → + chunks_arr.val[k]'(by rw [h_chunks_len]; exact hk) + = lift_chunk (r.2.coefficients.val[k]!) := by + intro k hk + show ((List.range 16).map (fun k => + Spec.chunk_ntt_layer_3_step_pure (Spec.chunk_at (lift_poly re) k) + (Spec.zeta_at (zeta_i.val + k + 1))))[k]'_ = _ + rw [List.getElem_map, List.getElem_range] + rw [chunk_at_lift_poly_fc re k hk] + exact (h_done k hk).symm + -- Apply flatten_chunks_eq_lift_poly_fc (with `r.2` as the poly). + have h_final := flatten_chunks_eq_lift_poly_fc r.2 chunks_arr h_chunks_len h_chunks_get + exact h_final.symm + · -- Step lemma application: dispatch ntt_at_layer_3_step_lemma_fc. + intro acc k _h_ge h_le hinv + have h_step := ntt_at_layer_3_step_lemma_fc zeta_i re h_bnd h_zeta acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer3FC.step_post zeta_i re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer3FC.step_post] using hP + · have hP : Layer3FC.step_post zeta_i re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer3FC.step_post] using hP + +/-! ### L3.7.A — Loop scaffolding for `ntt_at_layer_7_portable_fc`. + + Strengthened FC invariant for the 8-iter chunk-pair butterfly loop. + Each iteration j ∈ 0..8 butterflies chunks at positions `(j, j+8)` + with the constant Mont-form zeta `-1600`. Body sub-ops (10 total): + load `re[j+8]`, multiply by `-1600`, load `re[j]`, write t to slot + `j+8`, index_mut at j, add, index_mut at j+8, sub. + + The invariant tracks four clauses on the accumulator + `(re_acc, scratch_acc)`: + (a) chunks `j < k`: a-side butterflied (chunk_pair_butterfly_a_pure + of (re[j], re[j+8]) at zeta `Spec.zeta_layer_7`). + (b) chunks `j+8` for `j < k`: b-side butterflied + (chunk_pair_butterfly_b_pure of (re[j], re[j+8])). + (c) chunks `j` for `k ≤ j < 8`: unchanged. + (d) chunks `j+8` for `k ≤ j < 8`: unchanged. -/ + +namespace Layer7FC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Local `usize_add_ok_eq` helper (mirrors `Layer2FC.usize_add_ok_eq`). -/ +theorem usize_add_ok_eq (x y : Std.Usize) + (h_max : x.val + y.val ≤ Std.Usize.max) : + ∃ z : Std.Usize, (x + y : Result Std.Usize) = .ok z ∧ z.val = x.val + y.val := by + have hT := Std.Usize.add_spec h_max + obtain ⟨z, h_eq, h_v⟩ := Std.WP.spec_imp_exists hT + exact ⟨z, h_eq, h_v⟩ + +/-- `IteratorRange.next` with end-bound 8 — `Some` branch. -/ +theorem iter_next8_some_eq (i : Std.Usize) + (h_lt : i.val < (8#usize : Std.Usize).val) : + ∃ s : Std.Usize, s.val = i.val + 1 ∧ + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := i, «end» := 8#usize } : CoreModels.core.ops.range.Range Std.Usize) + = .ok (some i, + ({ start := s, «end» := 8#usize } : CoreModels.core.ops.range.Range Std.Usize)) := by + have hT := libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize i 8#usize + (Q := PostCond.noThrow fun (oi : Option Std.Usize × _) => ⌜ + ∃ s : Std.Usize, s.val = i.val + 1 + ∧ oi = (some i, + ({ start := s, «end» := 8#usize } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝) + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + have hex : ∃ v, core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := i, «end» := 8#usize } : CoreModels.core.ops.range.Range Std.Usize) = .ok v + ∧ (∃ s : Std.Usize, s.val = i.val + 1 + ∧ v = (some i, + ({ start := s, «end» := 8#usize } + : CoreModels.core.ops.range.Range Std.Usize))) := by + generalize hx : core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := i, «end» := 8#usize } : CoreModels.core.ops.range.Range Std.Usize) = X at hT + match X, hT with + | .ok v, hT => exact ⟨v, rfl, by simpa [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply] using hT⟩ + | .fail _, hT => exact absurd hT (by simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div, hT => exact absurd hT (by simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply]) + obtain ⟨v, hveq, s, hs_val, hpair⟩ := hex + exact ⟨s, hs_val, by rw [hveq, hpair]⟩ + +/-- `IteratorRange.next` with end-bound 8 — `None` branch. -/ +theorem iter_next8_none_eq (i : Std.Usize) + (h_ge : i.val ≥ (8#usize : Std.Usize).val) : + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := i, «end» := 8#usize } : CoreModels.core.ops.range.Range Std.Usize) + = .ok ((none : Option Std.Usize), + ({ start := i, «end» := 8#usize } + : CoreModels.core.ops.range.Range Std.Usize)) := by + have hT := libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize i 8#usize + (Q := PostCond.noThrow fun (oi : Option Std.Usize × _) => ⌜ + oi = ((none : Option Std.Usize), + ({ start := i, «end» := 8#usize } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝) + (fun hlt => absurd hlt (Nat.not_lt.mpr h_ge)) + (fun _ => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + generalize hx : core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := i, «end» := 8#usize } : CoreModels.core.ops.range.Range Std.Usize) = X at hT + match X, hT with + | .ok v, hT => + have hP : v = ((none : Option Std.Usize), + ({ start := i, «end» := 8#usize } + : CoreModels.core.ops.range.Range Std.Usize)) := by + simpa [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply] using hT + rw [hP] + | .fail _, hT => exact absurd hT (by simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div, hT => exact absurd hT (by simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply]) + +/-- Step-local accumulator: `(re_acc, scratch_acc)`. -/ +abbrev Acc := + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector × + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC loop invariant for `ntt_at_layer_7_portable_fc`. -/ +def inv + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + -- (a) chunks j < k: a-side butterfly result. + (∀ j : Nat, j < k.val → + lift_chunk (acc.1.coefficients.val[j]!) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (re.coefficients.val[j]!)) + (lift_chunk (re.coefficients.val[j + 8]!)) + Spec.zeta_layer_7) + -- (b) chunks j+8 for j < k: b-side butterfly result. + ∧ (∀ j : Nat, j < k.val → + lift_chunk (acc.1.coefficients.val[j + 8]!) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (re.coefficients.val[j]!)) + (lift_chunk (re.coefficients.val[j + 8]!)) + Spec.zeta_layer_7) + -- (c) chunks j for k ≤ j < 8: unchanged. + ∧ (∀ j : Nat, k.val ≤ j → j < 8 → + acc.1.coefficients.val[j]! = re.coefficients.val[j]!) + -- (d) chunks j+8 for k ≤ j < 8: unchanged. + ∧ (∀ j : Nat, k.val ≤ j → j < 8 → + acc.1.coefficients.val[j + 8]! = re.coefficients.val[j + 8]!)) + +/-- Step-post for `loop_range_spec_usize`. -/ +def step_post + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (8#usize : Std.Usize).val ∧ iter'.«end» = 8#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv re iter'.start acc').holds + | .done y => (inv re 8#usize y).holds + +end Layer7FC + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for layer 7. Given a valid loop state + `(acc, k)` with `k.val < 8`, butterflies chunks `(k, k+8)` with the + constant Mont-form zeta `-1600`, leaving other chunks unchanged. -/ +theorem ntt_at_layer_7_step_lemma_fc + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((re.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 20) + (acc : Layer7FC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (8#usize : Std.Usize).val) + (h_inv : (Layer7FC.inv re k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop.body + (vectortraitsOperationsInst := portable_ops_inst) 8#usize + { start := k, «end» := 8#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer7FC.step_post re k r ⌝ ⦄ := by + have h8 : (8#usize : Std.Usize).val = 8 := rfl + have h_coef_len : acc.1.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_acc_done_a, h_acc_done_b, h_acc_undone_a, h_acc_undone_b⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop.body + by_cases h_lt : k.val < (8#usize : Std.Usize).val + · -- `Some j = k` branch. + have hk_8 : k.val < 8 := by rw [h8] at h_lt; exact h_lt + have hk_lt_16 : k.val < 16 := by omega + have hk8_lt_16 : k.val + 8 < 16 := by omega + obtain ⟨s, hs_val, h_iter_some⟩ := Layer7FC.iter_next8_some_eq k h_lt + -- (1) `j + 8 = i`. Bound: k.val + 8 ≤ 15. + have h_um8 : (8#usize : Std.Usize).val = 8 := rfl + have h_i_max : k.val + (8#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um8]; scalar_tac + obtain ⟨i_idx, h_i_eq, h_i_val⟩ := + Layer7FC.usize_add_ok_eq k 8#usize h_i_max + have h_i_val_arith : i_idx.val = k.val + 8 := by rw [h_i_val, h_um8] + -- (2) `index_usize re.coefficients i` → `scratch1 = acc.1.coefs[k+8]`. + set scratch1 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.1.coefficients.val[i_idx.val]! with hscratch1_def + have h_idx_i : Aeneas.Std.Array.index_usize acc.1.coefficients i_idx + = .ok scratch1 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.1.coefficients i_idx + (by rw [h_coef_len, h_i_val_arith]; exact hk8_lt_16) + -- The actual scratch1 equals re[k+8] (via h_acc_undone_b at j=k). + have h_scratch1_eq : scratch1 = re.coefficients.val[k.val + 8]! := by + show acc.1.coefficients.val[i_idx.val]! = re.coefficients.val[k.val + 8]! + rw [h_i_val_arith] + exact h_acc_undone_b k.val (Nat.le_refl _) hk_8 + -- (3) `multiply_by_constant scratch1 (-1600)#i16`. + -- Use BOTH the FC lift and the legacy bound. + have h_s1_bnd : ∀ ℓ : Nat, ℓ < 16 → + (scratch1.elements.val[ℓ]!).val.natAbs ≤ 20 := by + intro ℓ hℓ + rw [h_scratch1_eq] + exact h_pre (k.val + 8) hk8_lt_16 ℓ hℓ + have h_s1_bnd_32767 : ∀ ℓ : Nat, ℓ < 16 → + (scratch1.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ + have := h_s1_bnd ℓ hℓ; omega + have h_cm1600_bnd : ((-1600)#i16 : Std.I16).val.natAbs ≤ 1664 := by decide + have h_prod_bnd : ∀ ℓ : Nat, ℓ < 16 → + ((scratch1.elements.val[ℓ]!).val * ((-1600)#i16 : Std.I16).val : Int).natAbs ≤ 2^15 - 1 := by + intro ℓ hℓ + have hb := h_s1_bnd ℓ hℓ + have h_cm : ((-1600)#i16 : Std.I16).val = -1600 := by decide + have h_abs : ((scratch1.elements.val[ℓ]!).val + * ((-1600)#i16 : Std.I16).val : Int).natAbs + ≤ ((scratch1.elements.val[ℓ]!).val : Int).natAbs * 1600 := by + rw [h_cm] + rw [Int.natAbs_mul] + simp [Int.natAbs_neg] + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2] + have h_step : ((scratch1.elements.val[ℓ]!).val : Int).natAbs * 1600 ≤ 20 * 1600 := + Nat.mul_le_mul_right _ hb + omega + obtain ⟨scratch2, h_s2_eq, h_s2_lift⟩ := + triple_exists_ok_fc (multiply_by_constant_fc scratch1 (-1600)#i16 + h_s1_bnd_32767 h_cm1600_bnd h_prod_bnd) + -- Also extract the per-elem value-and-bound via legacy `_spec`. + have h_s2_spec := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.multiply_by_constant_spec scratch1 (-1600)#i16 h_prod_bnd + obtain ⟨scratch2', h_s2_eq', h_s2_per⟩ := triple_exists_ok_fc h_s2_spec + have h_s2_same : scratch2 = scratch2' := by + have := h_s2_eq.symm.trans h_s2_eq' + cases this; rfl + subst h_s2_same + -- (4) `index_usize re.coefficients j` → `t = acc.1.coefs[k]`. + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.1.coefficients.val[k.val]! with ht_def + have h_idx_j : Aeneas.Std.Array.index_usize acc.1.coefficients k + = .ok t := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.1.coefficients k + (by rw [h_coef_len]; exact hk_lt_16) + have h_t_eq : t = re.coefficients.val[k.val]! := by + show acc.1.coefficients.val[k.val]! = re.coefficients.val[k.val]! + exact h_acc_undone_a k.val (Nat.le_refl _) hk_8 + have h_t_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 20 := by + intro ℓ hℓ + rw [h_t_eq] + exact h_pre k.val hk_lt_16 ℓ hℓ + -- (5) `Array.update acc.1.coefficients i t` → `a = acc.1.coefs.set i t`. + set a : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.1.coefficients.set i_idx t with ha_def + have h_upd_a : Aeneas.Std.Array.update acc.1.coefficients i_idx t + = .ok a := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_update_ok_eq acc.1.coefficients i_idx t + (by rw [h_coef_len, h_i_val_arith]; exact hk8_lt_16) + -- (6) `index_mut_usize a j` → `(t1, set_back) = (a.val[k]!, a.set k)`. + have h_a_len : a.length = 16 := by + simp [ha_def, h_coef_len] + -- Need to know a.val[k.val]! = acc.1.coefficients.val[k.val]! = t (k ≠ i.val). + have hki_ne : k.val ≠ i_idx.val := by rw [h_i_val_arith]; omega + have h_a_k : a.val[k.val]! = acc.1.coefficients.val[k.val]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.1.coefficients i_idx k.val t + (fun h => hki_ne h.symm) + have h_imt_j : Aeneas.Std.Array.index_mut_usize a k + = .ok (t, a.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + have h_idx : Aeneas.Std.Array.index_usize a k = .ok (a.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a k (by rw [h_a_len]; exact hk_lt_16) + rw [h_idx] + have h_aval_eq : a.val[k.val]! = t := by rw [h_a_k] + rw [h_aval_eq]; rfl + -- (7) `vec.add t scratch2`. Pre: |t[ℓ] + scratch2[ℓ]| ≤ 32767. + have h_s2_bnd : ∀ ℓ : Nat, ℓ < 16 → + (scratch2.elements.val[ℓ]!).val.natAbs ≤ 2^15 - 1 := by + intro ℓ hℓ; exact (h_s2_per ℓ hℓ).2 + have h_s2_val : ∀ ℓ : Nat, ℓ < 16 → + (scratch2.elements.val[ℓ]!).val + = (scratch1.elements.val[ℓ]!).val * ((-1600)#i16 : Std.I16).val := by + intro ℓ hℓ; exact (h_s2_per ℓ hℓ).1 + have h_t_s2_add_bnd : ∀ ℓ : Nat, ℓ < 16 → + ((t.elements.val[ℓ]!).val + (scratch2.elements.val[ℓ]!).val : Int).natAbs ≤ 2^15 - 1 := by + intro ℓ hℓ + have hb_t := h_t_bnd ℓ hℓ + rw [h_s2_val ℓ hℓ] + have h_s1_b := h_s1_bnd ℓ hℓ + have h_cm : ((-1600)#i16 : Std.I16).val = -1600 := by decide + rw [h_cm] + have h_prod_le : ((scratch1.elements.val[ℓ]!).val * (-1600) : Int).natAbs + ≤ ((scratch1.elements.val[ℓ]!).val : Int).natAbs * 1600 := by + rw [Int.natAbs_mul]; simp [Int.natAbs_neg] + have h_prod_b : ((scratch1.elements.val[ℓ]!).val * (-1600) : Int).natAbs ≤ 32000 := by + have := Nat.mul_le_mul_right 1600 h_s1_b + omega + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2] + -- |a + b| ≤ |a| + |b| ≤ 20 + 32000 = 32020. + have h_abs_add : ((t.elements.val[ℓ]!).val + + (scratch1.elements.val[ℓ]!).val * (-1600) : Int).natAbs + ≤ ((t.elements.val[ℓ]!).val : Int).natAbs + + ((scratch1.elements.val[ℓ]!).val * (-1600) : Int).natAbs := + Int.natAbs_add_le _ _ + omega + obtain ⟨t2, h_t2_eq, h_t2_lift⟩ := + triple_exists_ok_fc (add_fc t scratch2 h_t_s2_add_bnd) + -- (8) `set_back t2` = `a.set k t2`. + set a1 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + a.set k t2 with ha1_def + -- (9) `index_mut_usize a1 i`. Need a1.val[i.val]!. a1 = (acc.1.coefs.set i t).set k t2. + -- Since k ≠ i, a1.val[i.val]! = (acc.1.coefs.set i t).val[i.val]! = t. + have h_a1_len : a1.length = 16 := by + simp [ha1_def, h_a_len] + have h_a1_i : a1.val[i_idx.val]! = t := by + have h_ne : k.val ≠ i_idx.val := hki_ne + have h_step1 : a1.val[i_idx.val]! = a.val[i_idx.val]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne a k i_idx.val t2 h_ne + have h_step2 : a.val[i_idx.val]! = t := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.1.coefficients i_idx i_idx.val t + ⟨rfl, by rw [h_coef_len, h_i_val_arith]; exact hk8_lt_16⟩ + rw [h_step1, h_step2] + have h_imt_i : Aeneas.Std.Array.index_mut_usize a1 i_idx + = .ok (t, a1.set i_idx) := by + unfold Aeneas.Std.Array.index_mut_usize + have h_idx : Aeneas.Std.Array.index_usize a1 i_idx = .ok (a1.val[i_idx.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a1 i_idx + (by rw [h_a1_len, h_i_val_arith]; exact hk8_lt_16) + rw [h_idx, h_a1_i]; rfl + -- (10) `vec.sub t scratch2`. Pre: |t[ℓ] - scratch2[ℓ]| ≤ 32767. + have h_t_s2_sub_bnd : ∀ ℓ : Nat, ℓ < 16 → + ((t.elements.val[ℓ]!).val - (scratch2.elements.val[ℓ]!).val : Int).natAbs ≤ 2^15 - 1 := by + intro ℓ hℓ + have hb_t := h_t_bnd ℓ hℓ + rw [h_s2_val ℓ hℓ] + have h_s1_b := h_s1_bnd ℓ hℓ + have h_cm : ((-1600)#i16 : Std.I16).val = -1600 := by decide + rw [h_cm] + have h_prod_b : ((scratch1.elements.val[ℓ]!).val * (-1600) : Int).natAbs ≤ 32000 := by + have h_prod_le : ((scratch1.elements.val[ℓ]!).val * (-1600) : Int).natAbs + = ((scratch1.elements.val[ℓ]!).val : Int).natAbs * 1600 := by + rw [Int.natAbs_mul]; simp [Int.natAbs_neg] + have := Nat.mul_le_mul_right 1600 h_s1_b + omega + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2] + have h_abs_sub : ((t.elements.val[ℓ]!).val + - (scratch1.elements.val[ℓ]!).val * (-1600) : Int).natAbs + ≤ ((t.elements.val[ℓ]!).val : Int).natAbs + + ((scratch1.elements.val[ℓ]!).val * (-1600) : Int).natAbs := by + have := Int.natAbs_sub_le (t.elements.val[ℓ]!).val + ((scratch1.elements.val[ℓ]!).val * (-1600)) + exact this + omega + obtain ⟨t4, h_t4_eq, h_t4_lift⟩ := + triple_exists_ok_fc (sub_fc t scratch2 h_t_s2_sub_bnd) + -- Compose acc'. + set a2 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + a1.set i_idx t4 with ha2_def + set acc' : Layer7FC.Acc := (({ coefficients := a2 } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector), + scratch2) with hacc'_def + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop.body + (vectortraitsOperationsInst := portable_ops_inst) 8#usize + { start := k, «end» := 8#usize } acc.1 acc.2 + = .ok (ControlFlow.cont (({ start := s, «end» := 8#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 8#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 8#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [Aeneas.Std.bind_tc_ok] + -- Unfold the `@[reducible]` inst-forwards for multiply_by_constant, add, sub + -- to the arithmetic-form names used in `h_s2_eq`, `h_t2_eq`, `h_t4_eq`. + show (do + let i ← k + 8#usize + let scratch1' ← Aeneas.Std.Array.index_usize acc.1.coefficients i + let scratch2' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.multiply_by_constant + scratch1' (-1600)#i16 + let t' ← Aeneas.Std.Array.index_usize acc.1.coefficients k + let a' ← Aeneas.Std.Array.update acc.1.coefficients i t' + let (t1, index_mut_back) ← Aeneas.Std.Array.index_mut_usize a' k + let t2' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.add + t1 scratch2' + let (t3, index_mut_back1) ← Aeneas.Std.Array.index_mut_usize + (index_mut_back t2') i + let t4' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.sub + t3 scratch2' + .ok (ControlFlow.cont (({ start := s, «end» := 8#usize } + : CoreModels.core.ops.range.Range Std.Usize), + ({ coefficients := index_mut_back1 t4' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector), + scratch2'))) = _ + rw [h_i_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_i]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_s2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_j]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_upd_a]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_j]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_i]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t4_eq]; simp only [Aeneas.Std.bind_tc_ok] + rfl + apply triple_of_ok_fc h_body + show Layer7FC.step_post re k + (.cont (({ start := s, «end» := 8#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold Layer7FC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (Layer7FC.inv re s acc').holds + -- Now the invariant at (s, acc'). Note acc'.1.coefficients = a2 = (a.set k t2).set i_idx t4. + -- a.set k t2: at i=k.val it is t2, elsewhere it is a = acc.1.coefs.set i_idx t. + -- .set i_idx t4: at i=i_idx.val it is t4, elsewhere unchanged. + have h_inv_pure : + (∀ j : Nat, j < s.val → + lift_chunk (acc'.1.coefficients.val[j]!) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (re.coefficients.val[j]!)) + (lift_chunk (re.coefficients.val[j + 8]!)) + Spec.zeta_layer_7) + ∧ (∀ j : Nat, j < s.val → + lift_chunk (acc'.1.coefficients.val[j + 8]!) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (re.coefficients.val[j]!)) + (lift_chunk (re.coefficients.val[j + 8]!)) + Spec.zeta_layer_7) + ∧ (∀ j : Nat, s.val ≤ j → j < 8 → + acc'.1.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ j : Nat, s.val ≤ j → j < 8 → + acc'.1.coefficients.val[j + 8]! = re.coefficients.val[j + 8]!) := by + refine ⟨?_, ?_, ?_, ?_⟩ + · -- (a) j < s.val → a-side butterfly. + intro j hj + rw [hs_val] at hj + show lift_chunk ((a1.set i_idx t4).val[j]!) = _ + -- j < k.val + 1 → j < k.val OR j = k.val. + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: unchanged in a2 (since j ≠ k.val ∧ j ≠ k.val+8). + have h_ne_i : i_idx.val ≠ j := by rw [h_i_val_arith]; omega + have h_ne_k : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_step1 : (a1.set i_idx t4).val[j]! = a1.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne a1 i_idx j t4 h_ne_i + have h_step2 : a1.val[j]! = a.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne a k j t2 h_ne_k + have h_step3 : a.val[j]! = acc.1.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.1.coefficients i_idx j t + (by rw [h_i_val_arith]; omega) + rw [h_step1, h_step2, h_step3] + exact h_acc_done_a j hj_lt_k + · -- j = k.val: a2[k.val] = t2 (because k.val ≠ i_idx.val). + subst hj_eq_k + have h_ne_i : i_idx.val ≠ k.val := fun h => hki_ne h.symm + have h_step1 : (a1.set i_idx t4).val[k.val]! = a1.val[k.val]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne a1 i_idx k.val t4 h_ne_i + have h_step2 : a1.val[k.val]! = t2 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq a k k.val t2 + ⟨rfl, by rw [h_a_len]; exact hk_lt_16⟩ + rw [h_step1, h_step2] + -- Goal: lift_chunk t2 = chunk_pair_butterfly_a_pure (lift_chunk re[k]) (lift_chunk re[k+8]) z_layer_7. + rw [h_t2_lift] + show Spec.chunk_add_pure (lift_chunk t) (lift_chunk scratch2) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (re.coefficients.val[k.val]!)) + (lift_chunk (re.coefficients.val[k.val + 8]!)) + Spec.zeta_layer_7 + rw [h_s2_lift, ← h_scratch1_eq, ← h_t_eq] + -- Goal: chunk_add_pure (lift_chunk t) (chunk_mul (lift_chunk scratch1) (lift_fe -1600)) + -- = chunk_pair_butterfly_a_pure (lift_chunk t) (lift_chunk scratch1) Spec.zeta_layer_7. + unfold Spec.chunk_add_pure Spec.chunk_multiply_by_constant_pure + Spec.chunk_pair_butterfly_a_pure Spec.zeta_layer_7 + apply Subtype.ext + -- Goal now: .val of LHS = .val of RHS. The .val of Std.Array.make is the list. + change (List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((lift_chunk t).val[i]!) + ((Std.Array.make 16#usize ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk scratch1).val[i]!) (lift_fe ((-1600)#i16)))) (by simp)).val[i]!)) + = (List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((lift_chunk t).val[ℓ]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk scratch1).val[ℓ]!) (lift_fe ((-1600)#i16)))) + apply List.ext_getElem + · simp + · intro ℓ hℓ1 _hℓ2 + have hℓ : ℓ < 16 := by + have : ℓ < (List.range 16).length := by simpa using hℓ1 + simpa using this + rw [List.getElem_map, List.getElem_range, + List.getElem_map, List.getElem_range] + congr 1 + -- Goal: ((Std.Array.make 16 (range.map ...)).val[ℓ]!) = mul_pure ((lift_chunk scratch1).val[ℓ]!) (lift_fe ...). + show ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk scratch1).val[i]!) (lift_fe ((-1600)#i16))))[ℓ]! = _ + have h_len : ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk scratch1).val[i]!) (lift_fe ((-1600)#i16)))).length = 16 := by simp + have h_pos : ℓ < ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk scratch1).val[i]!) (lift_fe ((-1600)#i16)))).length := by + rw [h_len]; exact hℓ + rw [getElem!_pos _ ℓ h_pos] + rw [List.getElem_map, List.getElem_range] + · -- (b) j < s.val → b-side butterfly at j+8. + intro j hj + rw [hs_val] at hj + show lift_chunk ((a1.set i_idx t4).val[j + 8]!) = _ + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: chunk j+8 unchanged in a2. + have h_jp8_ne_i : j + 8 ≠ i_idx.val := by rw [h_i_val_arith]; omega + have h_jp8_ne_k : j + 8 ≠ k.val := by omega + have h_step1 : (a1.set i_idx t4).val[j + 8]! = a1.val[j + 8]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne a1 i_idx (j + 8) t4 + (fun h => h_jp8_ne_i h.symm) + have h_step2 : a1.val[j + 8]! = a.val[j + 8]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne a k (j + 8) t2 + (fun h => h_jp8_ne_k h.symm) + have h_step3 : a.val[j + 8]! = acc.1.coefficients.val[j + 8]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.1.coefficients i_idx (j + 8) t + (by rw [h_i_val_arith]; omega) + rw [h_step1, h_step2, h_step3] + exact h_acc_done_b j hj_lt_k + · -- j = k.val: a2[k.val + 8] = t4 (because i_idx.val = k.val+8). + subst hj_eq_k + have h_i_eq_kp8 : i_idx.val = k.val + 8 := h_i_val_arith + have h_step1 : (a1.set i_idx t4).val[k.val + 8]! = t4 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq a1 i_idx (k.val + 8) t4 + ⟨h_i_eq_kp8, by rw [h_a1_len]; exact hk8_lt_16⟩ + rw [h_step1] + rw [h_t4_lift] + show Spec.chunk_sub_pure (lift_chunk t) (lift_chunk scratch2) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (re.coefficients.val[k.val]!)) + (lift_chunk (re.coefficients.val[k.val + 8]!)) + Spec.zeta_layer_7 + rw [h_s2_lift, ← h_scratch1_eq, ← h_t_eq] + unfold Spec.chunk_sub_pure Spec.chunk_multiply_by_constant_pure + Spec.chunk_pair_butterfly_b_pure Spec.zeta_layer_7 + apply Subtype.ext + change (List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + ((lift_chunk t).val[i]!) + ((Std.Array.make 16#usize ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk scratch1).val[i]!) (lift_fe ((-1600)#i16)))) (by simp)).val[i]!)) + = (List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + ((lift_chunk t).val[ℓ]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk scratch1).val[ℓ]!) (lift_fe ((-1600)#i16)))) + apply List.ext_getElem + · simp + · intro ℓ hℓ1 _hℓ2 + have hℓ : ℓ < 16 := by + have : ℓ < (List.range 16).length := by simpa using hℓ1 + simpa using this + rw [List.getElem_map, List.getElem_range, + List.getElem_map, List.getElem_range] + congr 1 + -- Goal: ((Std.Array.make 16 (range.map ...)).val[ℓ]!) = mul_pure ((lift_chunk scratch1).val[ℓ]!) (lift_fe ...). + show ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk scratch1).val[i]!) (lift_fe ((-1600)#i16))))[ℓ]! = _ + have h_len : ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk scratch1).val[i]!) (lift_fe ((-1600)#i16)))).length = 16 := by simp + have h_pos : ℓ < ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk scratch1).val[i]!) (lift_fe ((-1600)#i16)))).length := by + rw [h_len]; exact hℓ + rw [getElem!_pos _ ℓ h_pos] + rw [List.getElem_map, List.getElem_range] + · -- (c) s.val ≤ j < 8 → acc'[j] = re[j]. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + show (a1.set i_idx t4).val[j]! = re.coefficients.val[j]! + have h_j_ne_i : i_idx.val ≠ j := by rw [h_i_val_arith]; omega + have h_j_ne_k : k.val ≠ j := by omega + have h_step1 : (a1.set i_idx t4).val[j]! = a1.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne a1 i_idx j t4 h_j_ne_i + have h_step2 : a1.val[j]! = a.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne a k j t2 h_j_ne_k + have h_step3 : a.val[j]! = acc.1.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.1.coefficients i_idx j t + (by rw [h_i_val_arith]; omega) + rw [h_step1, h_step2, h_step3] + have hk_le_j : k.val ≤ j := by omega + exact h_acc_undone_a j hk_le_j hj_lt + · -- (d) s.val ≤ j < 8 → acc'[j+8] = re[j+8]. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + show (a1.set i_idx t4).val[j + 8]! = re.coefficients.val[j + 8]! + have h_jp8_ne_i : j + 8 ≠ i_idx.val := by rw [h_i_val_arith]; omega + have h_jp8_ne_k : j + 8 ≠ k.val := by omega + have h_step1 : (a1.set i_idx t4).val[j + 8]! = a1.val[j + 8]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne a1 i_idx (j + 8) t4 + (fun h => h_jp8_ne_i h.symm) + have h_step2 : a1.val[j + 8]! = a.val[j + 8]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne a k (j + 8) t2 + (fun h => h_jp8_ne_k h.symm) + have h_step3 : a.val[j + 8]! = acc.1.coefficients.val[j + 8]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.1.coefficients i_idx (j + 8) t + (by rw [h_i_val_arith]; omega) + rw [h_step1, h_step2, h_step3] + have hk_le_j : k.val ≤ j := by omega + exact h_acc_undone_b j hk_le_j hj_lt + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 8, done. + have hk_ge : k.val ≥ (8#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 8 := by rw [h8] at hk_ge; omega + have h_iter_none := Layer7FC.iter_next8_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop.body + (vectortraitsOperationsInst := portable_ops_inst) 8#usize + { start := k, «end» := 8#usize } acc.1 acc.2 + = .ok (ControlFlow.done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 8#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 8#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_fc h_body + show Layer7FC.step_post re k (.done acc) + unfold Layer7FC.step_post + show (Layer7FC.inv re 8#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j : Nat, j < (8#usize : Std.Usize).val → + lift_chunk (acc.1.coefficients.val[j]!) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (re.coefficients.val[j]!)) + (lift_chunk (re.coefficients.val[j + 8]!)) + Spec.zeta_layer_7) + ∧ (∀ j : Nat, j < (8#usize : Std.Usize).val → + lift_chunk (acc.1.coefficients.val[j + 8]!) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (re.coefficients.val[j]!)) + (lift_chunk (re.coefficients.val[j + 8]!)) + Spec.zeta_layer_7) + ∧ (∀ j : Nat, (8#usize : Std.Usize).val ≤ j → j < 8 → + acc.1.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ j : Nat, (8#usize : Std.Usize).val ≤ j → j < 8 → + acc.1.coefficients.val[j + 8]! = re.coefficients.val[j + 8]!) := by + refine ⟨?_, ?_, ?_, ?_⟩ + · intro j hj; rw [h8] at hj + apply h_acc_done_a j; rw [hk_eq]; exact hj + · intro j hj; rw [h8] at hj + apply h_acc_done_b j; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt + rw [h8] at hj_ge + apply h_acc_undone_a j _ hj_lt; rw [hk_eq]; exact hj_ge + · intro j hj_ge hj_lt + rw [h8] at hj_ge + apply h_acc_undone_b j _ hj_lt; rw [hk_eq]; exact hj_ge + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +/-- L3.7' — `ntt_at_layer_7` PortableVector-specialised FC equation. + Single chunk-stride-8 butterfly layer with constant zeta -1600. + + The impl iterates `j ∈ 0..8` and butterflies chunks `(j, j+8)` with + constant Mont-form zeta `-1600`; the lifted constant is + `Spec.zeta_layer_7 = lift_fe (-1600)#i16`. + + **Preconditions** (load-bearing): + - `h_bnd` : per-lane input bound 20 across all 16 chunks × 16 lanes + — strengthened from the original 29439 to admit + `multiply_by_constant_fc`'s `hpre_prod` obligation with the + constant `c = -1600`: `|20 * 1600| = 32000 ≤ 2^15 - 1 = 32767`. + Callers are `ntt_binomially_sampled_ring_element` (binomial-sampled + input range `{-3..3}`) and equivalents — comfortably below 20. -/ +@[spec high] +theorem ntt_at_layer_7_portable_fc + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 20) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_7 + (vectortraitsOperationsInst := portable_ops_inst) re scratch + ⦃ ⇓ p => ⌜ lift_poly p.1 = Spec.ntt_at_layer_7_pure (lift_poly re) ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_7 + -- The driver: `i ← VECTORS_IN_RING_ELEMENT (=16), step ← i/2 (=8); loop on (0..8)`. + have h_vre : libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + = .ok (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + rfl + have h_div : ((16#usize : Std.Usize) / (2#usize : Std.Usize) : Result Std.Usize) + = .ok (8#usize : Std.Usize) := by + have h_max : ((2#usize : Std.Usize).val : Nat) ≠ 0 := by decide + obtain ⟨z, hz_eq, hz_v⟩ := Aeneas.Std.UScalar.div_spec (16#usize : Std.Usize) h_max + have hz_val : (↑z : Nat) = 8 := by + rw [hz_v]; decide + have hz_eq8 : z = (8#usize : Std.Usize) := by + apply Aeneas.Std.UScalar.eq_of_val_eq + show z.val = (8#usize : Std.Usize).val + rw [hz_val]; decide + rw [hz_eq, hz_eq8] + rw [h_vre] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_div] + simp only [Aeneas.Std.bind_tc_ok] + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop.body + (vectortraitsOperationsInst := portable_ops_inst) 8#usize + iter1 acc1.1 acc1.2) + (β := Layer7FC.Acc) + (re, scratch) + 0#usize 8#usize + (Layer7FC.inv re) + (by decide : (0#usize : Std.Usize).val ≤ (8#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_, ?_⟩ + · intro j hj; exact absurd hj (Nat.not_lt_zero j) + · intro j hj; exact absurd hj (Nat.not_lt_zero j) + · intro _ _ _; trivial + · intro _ _ _; trivial) + ?_) + · -- Post entailment: at k=8, the invariant gives all FC equations. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (Layer7FC.inv re 8#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + (∀ j : Nat, j < (8#usize : Std.Usize).val → + lift_chunk (r.1.coefficients.val[j]!) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (re.coefficients.val[j]!)) + (lift_chunk (re.coefficients.val[j + 8]!)) + Spec.zeta_layer_7) + ∧ (∀ j : Nat, j < (8#usize : Std.Usize).val → + lift_chunk (r.1.coefficients.val[j + 8]!) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (re.coefficients.val[j]!)) + (lift_chunk (re.coefficients.val[j + 8]!)) + Spec.zeta_layer_7) + ∧ (∀ j : Nat, (8#usize : Std.Usize).val ≤ j → j < 8 → + r.1.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ j : Nat, (8#usize : Std.Usize).val ≤ j → j < 8 → + r.1.coefficients.val[j + 8]! = re.coefficients.val[j + 8]!) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + Layer7FC.inv] using h_inv_holds + obtain ⟨h_done_a, h_done_b, _h_undone_a, _h_undone_b⟩ := h_inv + have h8 : (8#usize : Std.Usize).val = 8 := rfl + unfold Spec.ntt_at_layer_7_pure + -- Build chunks_arr matching the Spec layout: per chunk c ∈ 0..16: + -- chunk_at_layer_4_plus_pure chunks0 7#usize (fun _ => zeta_layer_7) c. + -- For layer 7: step_vec = 128/16 = 8; group = c/16 = 0; offset = c%16 = c. + -- c < 8: a-side with partner c+8. + -- c ≥ 8: b-side, partner c-8 (i.e. chunk = chunks0[c-8] - chunks0[c] * z). + set chunks_arr : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun c => + Spec.chunk_at_layer_4_plus_pure + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at (lift_poly re))) + (by simp)) + 7#usize + (fun _ => Spec.zeta_layer_7) + c)) + (by simp) with hchunks_def + have h_chunks_len : chunks_arr.val.length = 16 := by + show ((List.range 16).map _).length = 16; simp + have h_chunks_get : ∀ c : Nat, (hc : c < 16) → + chunks_arr.val[c]'(by rw [h_chunks_len]; exact hc) + = lift_chunk (r.1.coefficients.val[c]!) := by + intro c hc + show ((List.range 16).map (fun c => + Spec.chunk_at_layer_4_plus_pure + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at (lift_poly re))) + (by simp)) + 7#usize + (fun _ => Spec.zeta_layer_7) + c))[c]'_ = _ + rw [List.getElem_map, List.getElem_range] + -- Unfold Spec.chunk_at_layer_4_plus_pure. + unfold Spec.chunk_at_layer_4_plus_pure + have h7 : (7#usize : Std.Usize).val = 7 := rfl + have h_sv : (1 <<< (7#usize : Std.Usize).val) / 16 = 8 := by rw [h7]; decide + have h_off : c % (2 * 8) = c := Nat.mod_eq_of_lt (by omega) + simp only [h_sv, h_off] + -- Now: if c < 8 then a-side else b-side. + -- Inner chunks0 lookup at index c, c+8, c-8 reduces via chunk_at_lift_poly_fc. + have h_chunks0_at : ∀ k : Nat, k < 16 → + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at (lift_poly re))) + (by simp)).val[k]! + = lift_chunk (re.coefficients.val[k]!) := by + intro k hk + have h_len_map : ((List.range 16).map (Spec.chunk_at (lift_poly re))).length = 16 := by simp + show ((List.range 16).map (Spec.chunk_at (lift_poly re)))[k]! = _ + rw [getElem!_pos _ k (by rw [h_len_map]; exact hk)] + rw [List.getElem_map, List.getElem_range] + exact chunk_at_lift_poly_fc re k hk + by_cases h_c_lt8 : c < 8 + · -- a-side: c < 8, partner c+8. + simp only [if_pos h_c_lt8] + rw [h_chunks0_at c (by omega), h_chunks0_at (c + 8) (by omega)] + exact (h_done_a c h_c_lt8).symm + · -- b-side: c ≥ 8, partner c-8. + simp only [if_neg h_c_lt8] + have h_c_lt_16 : c < 16 := hc + have h_cm8 : c - 8 < 16 := by omega + rw [h_chunks0_at (c - 8) h_cm8, h_chunks0_at c h_c_lt_16] + -- Goal: chunk_pair_butterfly_b_pure (lift_chunk re[c-8]) (lift_chunk re[c]) Spec.zeta_layer_7 + -- = lift_chunk r.1.coefs[c] + -- We have h_done_b (c-8) : lift_chunk r.1.coefs[(c-8) + 8] + -- = chunk_pair_butterfly_b_pure (lift_chunk re[c-8]) (lift_chunk re[c-8+8]) z_layer_7. + have h_cm8_lt8 : c - 8 < 8 := by omega + have h_simp : c - 8 + 8 = c := by omega + have := h_done_b (c - 8) h_cm8_lt8 + rw [h_simp] at this + exact this.symm + -- Apply flatten_chunks_eq_lift_poly_fc. + have h_final := flatten_chunks_eq_lift_poly_fc r.1 chunks_arr h_chunks_len h_chunks_get + exact h_final.symm + · -- Step lemma dispatch. + intro acc k _h_ge h_le hinv + have h_step := ntt_at_layer_7_step_lemma_fc re h_bnd acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer7FC.step_post re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer7FC.step_post] using hP + · have hP : Layer7FC.step_post re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer7FC.step_post] using hP + +/-! ### L3.4_plus' — Helper namespace and lemmas. + + Layer 4-6 NTT driver: nested loop applying a single chunk-pair + butterfly per inner iter at chunks `(a_offset+j, b_offset+j)` with + one zeta per outer round. The keystone `ntt_layer_int_vec_step_fc` + encapsulates the impl's `ntt.ntt_layer_int_vec_step` body, which + performs `(coefs[a], coefs[b]) := (coefs[a] + scratch2, coefs[a] - scratch2)` + where `scratch2 := mont_mult(coefs[b], zeta_r)`. -/ + +namespace Layer4PlusFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Local `usize_add_ok_eq` helper. -/ +theorem usize_add_ok_eq (x y : Std.Usize) + (h_max : x.val + y.val ≤ Std.Usize.max) : + ∃ z : Std.Usize, (x + y : Result Std.Usize) = .ok z ∧ z.val = x.val + y.val := by + have hT := Std.Usize.add_spec h_max + obtain ⟨z, h_eq, h_v⟩ := Std.WP.spec_imp_exists hT + exact ⟨z, h_eq, h_v⟩ + +/-- Local `usize_mul_ok_eq` helper. -/ +theorem usize_mul_ok_eq (x y : Std.Usize) + (h_max : x.val * y.val ≤ Std.Usize.max) : + ∃ z : Std.Usize, (x * y : Result Std.Usize) = .ok z ∧ z.val = x.val * y.val := by + have hT := Std.Usize.mul_spec h_max + obtain ⟨z, h_eq, h_v⟩ := Std.WP.spec_imp_exists hT + exact ⟨z, h_eq, h_v⟩ + +/-- Generic `IteratorRange.next` Some branch for arbitrary end. -/ +theorem iter_next_some_eq_gen (i e : Std.Usize) + (h_lt : i.val < e.val) : + ∃ s : Std.Usize, s.val = i.val + 1 ∧ + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := i, «end» := e } : CoreModels.core.ops.range.Range Std.Usize) + = .ok (some i, + ({ start := s, «end» := e } : CoreModels.core.ops.range.Range Std.Usize)) := by + have hT := libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize i e + (Q := PostCond.noThrow fun (oi : Option Std.Usize × _) => ⌜ + ∃ s : Std.Usize, s.val = i.val + 1 + ∧ oi = (some i, + ({ start := s, «end» := e } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝) + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + have hex : ∃ v, core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := i, «end» := e } : CoreModels.core.ops.range.Range Std.Usize) = .ok v + ∧ (∃ s : Std.Usize, s.val = i.val + 1 + ∧ v = (some i, + ({ start := s, «end» := e } + : CoreModels.core.ops.range.Range Std.Usize))) := by + generalize hx : core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := i, «end» := e } : CoreModels.core.ops.range.Range Std.Usize) = X at hT + match X, hT with + | .ok v, hT => exact ⟨v, rfl, by simpa [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply] using hT⟩ + | .fail _, hT => exact absurd hT (by simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div, hT => exact absurd hT (by simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply]) + obtain ⟨v, hveq, s, hs_val, hpair⟩ := hex + exact ⟨s, hs_val, by rw [hveq, hpair]⟩ + +/-- Generic `IteratorRange.next` None branch for arbitrary end. -/ +theorem iter_next_none_eq_gen (i e : Std.Usize) + (h_ge : i.val ≥ e.val) : + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := i, «end» := e } : CoreModels.core.ops.range.Range Std.Usize) + = .ok ((none : Option Std.Usize), + ({ start := i, «end» := e } + : CoreModels.core.ops.range.Range Std.Usize)) := by + have hT := libcrux_iot_ml_kem.Util.LoopSpecs.IteratorRange_next_spec_usize i e + (Q := PostCond.noThrow fun (oi : Option Std.Usize × _) => ⌜ + oi = ((none : Option Std.Usize), + ({ start := i, «end» := e } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝) + (fun hlt => absurd hlt (Nat.not_lt.mpr h_ge)) + (fun _ => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + generalize hx : core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := i, «end» := e } : CoreModels.core.ops.range.Range Std.Usize) = X at hT + match X, hT with + | .ok v, hT => + have hP : v = ((none : Option Std.Usize), + ({ start := i, «end» := e } + : CoreModels.core.ops.range.Range Std.Usize)) := by + simpa [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply] using hT + rw [hP] + | .fail _, hT => exact absurd hT (by simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div, hT => exact absurd hT (by simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply]) + +end Layer4PlusFC + +set_option maxHeartbeats 16000000 in +/-- L3.4_plus' keystone: full FC equation for `ntt.ntt_layer_int_vec_step` + at the chunk-pair level. Performs the impl's compound body: + `scratch2 = mont_mult(coefs[b], zeta_r)`, + `coefs[a] := coefs[a] + scratch2`, + `coefs[b] := coefs[a]_orig - scratch2`. -/ +theorem ntt_layer_int_vec_step_fc + (coefficients : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize) + (a b : Std.Usize) (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta_r : Std.I16) + (ha : a.val < 16) (hb : b.val < 16) (hab : a.val ≠ b.val) + (hzeta : zeta_r.val.natAbs ≤ 1664) + (h_bnd_a : ∀ ℓ : Nat, ℓ < 16 → + ((coefficients.val[a.val]!).elements.val[ℓ]!).val.natAbs ≤ 29439) + (h_bnd_b : ∀ ℓ : Nat, ℓ < 16 → + ((coefficients.val[b.val]!).elements.val[ℓ]!).val.natAbs ≤ 29439) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_layer_int_vec_step + (vectortraitsOperationsInst := portable_ops_inst) + coefficients a b scratch zeta_r + ⦃ ⇓ r => ⌜ + lift_chunk (r.1.val[a.val]!) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (coefficients.val[a.val]!)) + (lift_chunk (coefficients.val[b.val]!)) + (lift_fe_mont zeta_r) + ∧ lift_chunk (r.1.val[b.val]!) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (coefficients.val[a.val]!)) + (lift_chunk (coefficients.val[b.val]!)) + (lift_fe_mont zeta_r) + ∧ (∀ k : Nat, k < 16 → k ≠ a.val → k ≠ b.val → + r.1.val[k]! = coefficients.val[k]!) + ⌝ ⦄ := by + have h_coef_len : coefficients.length = 16 := Std.Array.length_eq _ + unfold libcrux_iot_ml_kem.ntt.ntt_layer_int_vec_step + -- (1) index_usize coefficients b → scratch1 = coefs[b] + set scratch1 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + coefficients.val[b.val]! with hscratch1_def + have h_idx_b : Aeneas.Std.Array.index_usize coefficients b = .ok scratch1 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq coefficients b + (by rw [h_coef_len]; exact hb) + -- (2) montgomery_multiply_fe scratch1 zeta_r → scratch2 via inst forwarder. + -- The inst forwarder does: classify zeta_r ; arithmetic.montgomery_multiply_by_constant. + have h_classify_zeta : libcrux_secrets.traits.Classify.Blanket.classify zeta_r = .ok zeta_r := + ntt_step_fc.classify_ok_eq zeta_r + have h_scratch1_bnd : ∀ ℓ : Nat, ℓ < 16 → + (scratch1.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ + have h_v : (scratch1.elements.val[ℓ]!).val.natAbs + = ((coefficients.val[b.val]!).elements.val[ℓ]!).val.natAbs := by + rw [hscratch1_def] + rw [h_v] + have := h_bnd_b ℓ hℓ + omega + -- Per-element legacy spec to get the modq form. + have h_s2_legacy := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.montgomery_multiply_by_constant_spec scratch1 zeta_r hzeta + obtain ⟨scratch2, h_s2_eq, h_s2_per⟩ := triple_exists_ok_fc h_s2_legacy + -- Bound the output: |scratch2| ≤ 3328. + have h_s2_bnd_3328 : ∀ ℓ : Nat, ℓ < 16 → + (scratch2.elements.val[ℓ]!).val.natAbs ≤ 3328 := fun ℓ hℓ => (h_s2_per ℓ hℓ).1 + -- (3) index_usize coefficients a → t = coefs[a] + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + coefficients.val[a.val]! with ht_def + have h_idx_a : Aeneas.Std.Array.index_usize coefficients a = .ok t := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq coefficients a + (by rw [h_coef_len]; exact ha) + have h_t_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 29439 := fun ℓ hℓ => h_bnd_a ℓ hℓ + -- (4) update coefficients b t → c1 = coefficients.set b t + set c1 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + coefficients.set b t with hc1_def + have h_upd_b : Aeneas.Std.Array.update coefficients b t = .ok c1 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_update_ok_eq coefficients b t + (by rw [h_coef_len]; exact hb) + have h_c1_len : c1.length = 16 := by simp [hc1_def, h_coef_len] + -- (5) index_mut_usize c1 a → (t1, set_back) = (c1[a], c1.set a). + have h_c1_a : c1.val[a.val]! = t := by + show (coefficients.set b t).val[a.val]! = t + -- a ≠ b, so c1[a] = coefficients[a] = t. + have h_ne : b.val ≠ a.val := fun h => hab h.symm + have h_step : (coefficients.set b t).val[a.val]! = coefficients.val[a.val]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne coefficients b a.val t h_ne + rw [h_step] + have h_imt_a : Aeneas.Std.Array.index_mut_usize c1 a = .ok (t, c1.set a) := by + unfold Aeneas.Std.Array.index_mut_usize + have h_idx : Aeneas.Std.Array.index_usize c1 a = .ok (c1.val[a.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq c1 a (by rw [h_c1_len]; exact ha) + rw [h_idx, h_c1_a]; rfl + -- (6) add t scratch2 → t2 = t + scratch2 (chunk-level). Pre: |t[ℓ] + scratch2[ℓ]| ≤ 32767. + have h_t_s2_add_bnd : ∀ ℓ : Nat, ℓ < 16 → + ((t.elements.val[ℓ]!).val + (scratch2.elements.val[ℓ]!).val : Int).natAbs ≤ 2^15 - 1 := by + intro ℓ hℓ + have hbt := h_t_bnd ℓ hℓ + have hbs2 := h_s2_bnd_3328 ℓ hℓ + have h_abs_add : ((t.elements.val[ℓ]!).val + (scratch2.elements.val[ℓ]!).val : Int).natAbs + ≤ ((t.elements.val[ℓ]!).val : Int).natAbs + + ((scratch2.elements.val[ℓ]!).val : Int).natAbs := + Int.natAbs_add_le _ _ + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2]; omega + obtain ⟨t2, h_t2_eq, h_t2_lift⟩ := + triple_exists_ok_fc (add_fc t scratch2 h_t_s2_add_bnd) + -- For per-element value-form (needed downstream): use legacy add_spec. + have h_t2_legacy := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.add_spec t scratch2 h_t_s2_add_bnd + obtain ⟨t2', h_t2_eq', h_t2_per⟩ := triple_exists_ok_fc h_t2_legacy + have h_t2_same : t2 = t2' := by + have := h_t2_eq.symm.trans h_t2_eq'; cases this; rfl + subst h_t2_same + -- (7) set_back t2 = c1.set a t2 → c2. + set c2 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + c1.set a t2 with hc2_def + have h_c2_len : c2.length = 16 := by simp [hc2_def, h_c1_len] + -- (8) index_mut_usize c2 b → (t3, set_back') = (c2[b], c2.set b). + -- c2 = (coefficients.set b t).set a t2. At index b, since a ≠ b, c2[b] = c1[b] = t. + have h_c2_b : c2.val[b.val]! = t := by + show (c1.set a t2).val[b.val]! = t + have h_ne : a.val ≠ b.val := hab + have h_step1 : (c1.set a t2).val[b.val]! = c1.val[b.val]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne c1 a b.val t2 h_ne + have h_step2 : c1.val[b.val]! = t := by + show (coefficients.set b t).val[b.val]! = t + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq coefficients b b.val t + ⟨rfl, by rw [h_coef_len]; exact hb⟩ + rw [h_step1, h_step2] + have h_imt_b : Aeneas.Std.Array.index_mut_usize c2 b = .ok (t, c2.set b) := by + unfold Aeneas.Std.Array.index_mut_usize + have h_idx : Aeneas.Std.Array.index_usize c2 b = .ok (c2.val[b.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq c2 b (by rw [h_c2_len]; exact hb) + rw [h_idx, h_c2_b]; rfl + -- (9) sub t scratch2 → t4 = t - scratch2 (chunk-level). Pre similar. + have h_t_s2_sub_bnd : ∀ ℓ : Nat, ℓ < 16 → + ((t.elements.val[ℓ]!).val - (scratch2.elements.val[ℓ]!).val : Int).natAbs ≤ 2^15 - 1 := by + intro ℓ hℓ + have hbt := h_t_bnd ℓ hℓ + have hbs2 := h_s2_bnd_3328 ℓ hℓ + have h_abs_sub : + ((t.elements.val[ℓ]!).val - (scratch2.elements.val[ℓ]!).val : Int).natAbs + ≤ ((t.elements.val[ℓ]!).val : Int).natAbs + + ((scratch2.elements.val[ℓ]!).val : Int).natAbs := + Int.natAbs_sub_le _ _ + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2]; omega + obtain ⟨t4, h_t4_eq, h_t4_lift⟩ := + triple_exists_ok_fc (sub_fc t scratch2 h_t_s2_sub_bnd) + have h_t4_legacy := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.sub_spec t scratch2 h_t_s2_sub_bnd + obtain ⟨t4', h_t4_eq', h_t4_per⟩ := triple_exists_ok_fc h_t4_legacy + have h_t4_same : t4 = t4' := by + have := h_t4_eq.symm.trans h_t4_eq'; cases this; rfl + subst h_t4_same + -- (10) set_back' t4 = c2.set b t4 → c3 (final). + set c3 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + c2.set b t4 with hc3_def + -- Compose body equality. + have h_body : + libcrux_iot_ml_kem.ntt.ntt_layer_int_vec_step + (vectortraitsOperationsInst := portable_ops_inst) + coefficients a b scratch zeta_r + = .ok (c3, scratch2) := by + unfold libcrux_iot_ml_kem.ntt.ntt_layer_int_vec_step + show (do + let scratch1' ← Aeneas.Std.Array.index_usize coefficients b + let scratch2' ← + libcrux_iot_ml_kem.vector.traits.montgomery_multiply_fe + portable_ops_inst scratch1' zeta_r + let t' ← Aeneas.Std.Array.index_usize coefficients a + let c1' ← Aeneas.Std.Array.update coefficients b t' + let (t1, index_mut_back) ← Aeneas.Std.Array.index_mut_usize c1' a + let t2' ← portable_ops_inst.add t1 scratch2' + let (t3, index_mut_back1) ← + Aeneas.Std.Array.index_mut_usize (index_mut_back t2') b + let t4' ← portable_ops_inst.sub t3 scratch2' + .ok (index_mut_back1 t4', scratch2')) = _ + -- The instance methods reduce to vector.portable.arithmetic.{add,sub,montgomery_multiply_by_constant}. + show (do + let scratch1' ← Aeneas.Std.Array.index_usize coefficients b + let scratch2' ← do + let i ← libcrux_secrets.traits.Classify.Blanket.classify zeta_r + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_by_constant scratch1' i + let t' ← Aeneas.Std.Array.index_usize coefficients a + let c1' ← Aeneas.Std.Array.update coefficients b t' + let (t1, index_mut_back) ← Aeneas.Std.Array.index_mut_usize c1' a + let t2' ← libcrux_iot_ml_kem.vector.portable.arithmetic.add t1 scratch2' + let (t3, index_mut_back1) ← + Aeneas.Std.Array.index_mut_usize (index_mut_back t2') b + let t4' ← libcrux_iot_ml_kem.vector.portable.arithmetic.sub t3 scratch2' + .ok (index_mut_back1 t4', scratch2')) = _ + rw [h_idx_b]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_classify_zeta]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_s2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_a]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_upd_b]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_a]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_b]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t4_eq]; simp only [Aeneas.Std.bind_tc_ok] + rfl + apply triple_of_ok_fc h_body + -- Now prove the post FC equations. + refine ⟨?_, ?_, ?_⟩ + · -- (a) c3[a] = c2[a] (since c3 = c2.set b t4, and a ≠ b). + -- c2[a] = t2 (by set_eq with set a t2). + -- lift_chunk t2 = chunk_add_pure (lift_chunk t) (lift_chunk scratch2) by h_t2_lift. + -- Need: = chunk_pair_butterfly_a_pure (lift_chunk t) (lift_chunk scratch1) (lift_fe_mont zeta_r). + show lift_chunk (c3.val[a.val]!) = _ + have h_ne : b.val ≠ a.val := fun h => hab h.symm + have h_c3_a : c3.val[a.val]! = c2.val[a.val]! := by + show (c2.set b t4).val[a.val]! = c2.val[a.val]! + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne c2 b a.val t4 h_ne + have h_c2_a : c2.val[a.val]! = t2 := by + show (c1.set a t2).val[a.val]! = t2 + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq c1 a a.val t2 + ⟨rfl, by rw [h_c1_len]; exact ha⟩ + rw [h_c3_a, h_c2_a] + -- Goal: lift_chunk t2 = chunk_pair_butterfly_a_pure (lift_chunk t) (lift_chunk scratch1) (lift_fe_mont zeta_r). + -- We have h_t2_lift : lift_chunk t2 = chunk_add_pure (lift_chunk t) (lift_chunk scratch2). + -- Need to show: chunk_add_pure (lift_chunk t) (lift_chunk scratch2) + -- = chunk_pair_butterfly_a_pure (lift_chunk t) (lift_chunk scratch1) (lift_fe_mont zeta_r). + rw [h_t2_lift] + -- t = coefficients[a], scratch1 = coefficients[b]. + show Spec.chunk_add_pure (lift_chunk t) (lift_chunk scratch2) + = Spec.chunk_pair_butterfly_a_pure (lift_chunk t) (lift_chunk scratch1) (lift_fe_mont zeta_r) + -- Per-lane: lift_fe scratch2[ℓ] = mul_pure (lift_fe scratch1[ℓ]) (lift_fe_mont zeta_r). + unfold Spec.chunk_add_pure Spec.chunk_pair_butterfly_a_pure lift_chunk + apply Subtype.ext + show (List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Std.Array.make 16#usize (t.elements.val.map lift_fe) (by simp)).val[i]!) + ((Std.Array.make 16#usize (scratch2.elements.val.map lift_fe) (by simp)).val[i]!)) + = (List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Std.Array.make 16#usize (t.elements.val.map lift_fe) (by simp)).val[ℓ]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((Std.Array.make 16#usize (scratch1.elements.val.map lift_fe) (by simp)).val[ℓ]!) + (lift_fe_mont zeta_r))) + apply List.ext_getElem + · simp + · intro ℓ hℓ1 _ + have hℓ : ℓ < 16 := by + have : ℓ < (List.range 16).length := hℓ1 + simpa using this + rw [List.getElem_map, List.getElem_range, + List.getElem_map, List.getElem_range] + congr 1 + -- Goal: (Std.Array.make 16 (s2.map lift_fe)).val[ℓ]! + -- = mul_pure ((Std.Array.make 16 (s1.map lift_fe)).val[ℓ]!) (lift_fe_mont zeta_r). + have h_s1_len : scratch1.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length scratch1 + have h_s2_len : scratch2.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length scratch2 + have h_s1_idx : (Std.Array.make 16#usize (scratch1.elements.val.map lift_fe) + (by simp)).val[ℓ]! = lift_fe (scratch1.elements.val[ℓ]!) := by + show (scratch1.elements.val.map lift_fe)[ℓ]! = _ + have h_lhs_len : (scratch1.elements.val.map lift_fe).length = 16 := by + simp [List.length_map, h_s1_len] + rw [getElem!_pos _ ℓ (by rw [h_lhs_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos scratch1.elements.val ℓ (by rw [h_s1_len]; exact hℓ)] + have h_s2_idx : (Std.Array.make 16#usize (scratch2.elements.val.map lift_fe) + (by simp)).val[ℓ]! = lift_fe (scratch2.elements.val[ℓ]!) := by + show (scratch2.elements.val.map lift_fe)[ℓ]! = _ + have h_lhs_len : (scratch2.elements.val.map lift_fe).length = 16 := by + simp [List.length_map, h_s2_len] + rw [getElem!_pos _ ℓ (by rw [h_lhs_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos scratch2.elements.val ℓ (by rw [h_s2_len]; exact hℓ)] + rw [h_s1_idx, h_s2_idx] + -- Now goal: lift_fe (scratch2.elements.val[ℓ]!) = mul_pure (lift_fe (scratch1.elements.val[ℓ]!)) (lift_fe_mont zeta_r). + -- Use lift_fe_mul_pure_mont_eq with the per-elem modq from h_s2_per. + have h_modq : libcrux_iot_ml_kem.Spec.ModularArith.modq_eq + (scratch2.elements.val[ℓ]!).val + ((scratch1.elements.val[ℓ]!).val * zeta_r.val * 169) 3329 := by + have h_per := (h_s2_per ℓ hℓ).2 + -- h_per : (scratch2.elements.val[ℓ]!).val * 2^16 % 3329 = (scratch1.elements.val[ℓ]!).val * zeta_r.val % 3329 + -- Need to convert to modq_eq form: r ≡ scratch1*zeta_r*169 mod q. + -- This follows from 2^16 * 169 ≡ 1 (mod 3329). + unfold libcrux_iot_ml_kem.Spec.ModularArith.modq_eq + -- (a - b) % q = 0 ↔ a ≡ b mod q. + have h_169 : ((2^16 : Int) * 169) % 3329 = 1 := by decide + -- r * 2^16 ≡ s1 * z (mod q) + -- r * 2^16 * 169 ≡ s1 * z * 169 (mod q) + -- r * 1 ≡ s1 * z * 169 (mod q) (using 2^16 * 169 ≡ 1) + -- r ≡ s1 * z * 169. + have h_rmul : ((scratch2.elements.val[ℓ]!).val * (2^16 : Int) * 169) % 3329 + = ((scratch1.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 := by + have h1 : ((scratch2.elements.val[ℓ]!).val * (2^16 : Int) * 169) % 3329 + = ((scratch2.elements.val[ℓ]!).val * (2^16 : Int)) % 3329 * 169 % 3329 := by + rw [Int.mul_emod] + simp + have h2 : ((scratch1.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 + = ((scratch1.elements.val[ℓ]!).val * zeta_r.val) % 3329 * 169 % 3329 := by + rw [Int.mul_emod] + simp + rw [h1, h2, h_per] + have h_lhs : ((scratch2.elements.val[ℓ]!).val * (2^16 : Int) * 169) % 3329 + = (scratch2.elements.val[ℓ]!).val % 3329 := by + have h_mul_assoc : (scratch2.elements.val[ℓ]!).val * (2^16 : Int) * 169 + = (scratch2.elements.val[ℓ]!).val * ((2^16 : Int) * 169) := by ring + rw [h_mul_assoc] + rw [Int.mul_emod] + rw [h_169] + simp + have h_zsub : + ((scratch2.elements.val[ℓ]!).val + - (scratch1.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 = 0 := by + have h_sub_emod : ((scratch2.elements.val[ℓ]!).val + - (scratch1.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 + = ((scratch2.elements.val[ℓ]!).val % 3329 + - ((scratch1.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329) % 3329 := by + rw [Int.sub_emod] + rw [h_sub_emod] + rw [← h_lhs] + rw [h_rmul] + simp + exact h_zsub + exact lift_fe_mul_pure_mont_eq (scratch1.elements.val[ℓ]!) zeta_r + (scratch2.elements.val[ℓ]!) h_modq + · -- (b) c3[b] = c2.set b t4 [b] = t4. Lifts to chunk_pair_butterfly_b_pure. + show lift_chunk (c3.val[b.val]!) = _ + have h_c3_b : c3.val[b.val]! = t4 := by + show (c2.set b t4).val[b.val]! = t4 + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq c2 b b.val t4 + ⟨rfl, by rw [h_c2_len]; exact hb⟩ + rw [h_c3_b, h_t4_lift] + show Spec.chunk_sub_pure (lift_chunk t) (lift_chunk scratch2) + = Spec.chunk_pair_butterfly_b_pure (lift_chunk t) (lift_chunk scratch1) (lift_fe_mont zeta_r) + unfold Spec.chunk_sub_pure Spec.chunk_pair_butterfly_b_pure lift_chunk + apply Subtype.ext + show (List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + ((Std.Array.make 16#usize (t.elements.val.map lift_fe) (by simp)).val[i]!) + ((Std.Array.make 16#usize (scratch2.elements.val.map lift_fe) (by simp)).val[i]!)) + = (List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + ((Std.Array.make 16#usize (t.elements.val.map lift_fe) (by simp)).val[ℓ]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((Std.Array.make 16#usize (scratch1.elements.val.map lift_fe) (by simp)).val[ℓ]!) + (lift_fe_mont zeta_r))) + apply List.ext_getElem + · simp + · intro ℓ hℓ1 _ + have hℓ : ℓ < 16 := by + have : ℓ < (List.range 16).length := hℓ1 + simpa using this + rw [List.getElem_map, List.getElem_range, + List.getElem_map, List.getElem_range] + congr 1 + have h_s1_len : scratch1.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length scratch1 + have h_s2_len : scratch2.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length scratch2 + have h_s1_idx : (Std.Array.make 16#usize (scratch1.elements.val.map lift_fe) + (by simp)).val[ℓ]! = lift_fe (scratch1.elements.val[ℓ]!) := by + show (scratch1.elements.val.map lift_fe)[ℓ]! = _ + have h_lhs_len : (scratch1.elements.val.map lift_fe).length = 16 := by + simp [List.length_map, h_s1_len] + rw [getElem!_pos _ ℓ (by rw [h_lhs_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos scratch1.elements.val ℓ (by rw [h_s1_len]; exact hℓ)] + have h_s2_idx : (Std.Array.make 16#usize (scratch2.elements.val.map lift_fe) + (by simp)).val[ℓ]! = lift_fe (scratch2.elements.val[ℓ]!) := by + show (scratch2.elements.val.map lift_fe)[ℓ]! = _ + have h_lhs_len : (scratch2.elements.val.map lift_fe).length = 16 := by + simp [List.length_map, h_s2_len] + rw [getElem!_pos _ ℓ (by rw [h_lhs_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos scratch2.elements.val ℓ (by rw [h_s2_len]; exact hℓ)] + rw [h_s1_idx, h_s2_idx] + have h_modq : libcrux_iot_ml_kem.Spec.ModularArith.modq_eq + (scratch2.elements.val[ℓ]!).val + ((scratch1.elements.val[ℓ]!).val * zeta_r.val * 169) 3329 := by + have h_per := (h_s2_per ℓ hℓ).2 + unfold libcrux_iot_ml_kem.Spec.ModularArith.modq_eq + have h_169 : ((2^16 : Int) * 169) % 3329 = 1 := by decide + have h_rmul : ((scratch2.elements.val[ℓ]!).val * (2^16 : Int) * 169) % 3329 + = ((scratch1.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 := by + have h1 : ((scratch2.elements.val[ℓ]!).val * (2^16 : Int) * 169) % 3329 + = ((scratch2.elements.val[ℓ]!).val * (2^16 : Int)) % 3329 * 169 % 3329 := by + rw [Int.mul_emod]; simp + have h2 : ((scratch1.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 + = ((scratch1.elements.val[ℓ]!).val * zeta_r.val) % 3329 * 169 % 3329 := by + rw [Int.mul_emod]; simp + rw [h1, h2, h_per] + have h_lhs : ((scratch2.elements.val[ℓ]!).val * (2^16 : Int) * 169) % 3329 + = (scratch2.elements.val[ℓ]!).val % 3329 := by + have h_mul_assoc : (scratch2.elements.val[ℓ]!).val * (2^16 : Int) * 169 + = (scratch2.elements.val[ℓ]!).val * ((2^16 : Int) * 169) := by ring + rw [h_mul_assoc, Int.mul_emod, h_169]; simp + have h_zsub : + ((scratch2.elements.val[ℓ]!).val + - (scratch1.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 = 0 := by + have h_sub_emod : ((scratch2.elements.val[ℓ]!).val + - (scratch1.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329 + = ((scratch2.elements.val[ℓ]!).val % 3329 + - ((scratch1.elements.val[ℓ]!).val * zeta_r.val * 169) % 3329) % 3329 := by + rw [Int.sub_emod] + rw [h_sub_emod, ← h_lhs, h_rmul]; simp + exact h_zsub + exact lift_fe_mul_pure_mont_eq (scratch1.elements.val[ℓ]!) zeta_r + (scratch2.elements.val[ℓ]!) h_modq + · -- (c) For k ≠ a, k ≠ b: c3[k] = coefficients[k]. + intro k hk hka hkb + show c3.val[k]! = coefficients.val[k]! + have h_step1 : c3.val[k]! = c2.val[k]! := by + show (c2.set b t4).val[k]! = c2.val[k]! + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne c2 b k t4 (fun h => hkb h.symm) + have h_step2 : c2.val[k]! = c1.val[k]! := by + show (c1.set a t2).val[k]! = c1.val[k]! + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne c1 a k t2 (fun h => hka h.symm) + have h_step3 : c1.val[k]! = coefficients.val[k]! := by + show (coefficients.set b t).val[k]! = coefficients.val[k]! + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne coefficients b k t (fun h => hkb h.symm) + rw [h_step1, h_step2, h_step3] + +/-! ### L3.4_plus' — Inner loop scaffolding. -/ + +namespace Layer4PlusInnerFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Inner loop accumulator: (re, scratch). -/ +abbrev Acc := + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector × + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC invariant for the inner loop. Parameters: + - `re0` : poly at start of inner loop (input chunks). + - `a_offset b_offset : Std.Usize` : the chunk-base offsets for this outer round. + - `step_vec : Std.Usize` : inner loop end (the # of butterflies in this round). + - `zeta : hacspec FE` : the zeta (canonical) for this round's butterflies. + + The invariant at inner iter `k`: + - For `j' < k`: chunks at `a_offset + j'` and `b_offset + j'` are butterflied. + - For `j' ≥ k`: chunks unchanged. + - All other chunks unchanged. -/ +def inv + (re0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (a_offset b_offset : Std.Usize) + (zeta : hacspec_ml_kem.parameters.FieldElement) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + -- (a) a-side butterflies for j' < k. + (∀ j' : Nat, j' < k.val → + lift_chunk (acc.1.coefficients.val[a_offset.val + j']!) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + zeta) + -- (b) b-side butterflies for j' < k. + ∧ (∀ j' : Nat, j' < k.val → + lift_chunk (acc.1.coefficients.val[b_offset.val + j']!) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + zeta) + -- (c) chunks at positions not yet touched, OR completely outside the a/b range. + ∧ (∀ k' : Nat, k' < 16 → + (∀ j' : Nat, j' < k.val → k' ≠ a_offset.val + j' ∧ k' ≠ b_offset.val + j') → + acc.1.coefficients.val[k']! = re0.coefficients.val[k']!)) + +/-- Step-post for the inner loop. -/ +def step_post + (re0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (a_offset b_offset step_vec : Std.Usize) + (zeta : hacspec_ml_kem.parameters.FieldElement) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < step_vec.val ∧ iter'.«end» = step_vec + ∧ iter'.start.val = k.val + 1 + ∧ (inv re0 a_offset b_offset zeta iter'.start acc').holds + | .done y => (inv re0 a_offset b_offset zeta step_vec y).holds + +end Layer4PlusInnerFC + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for the inner loop. -/ +theorem ntt_at_layer_4_plus_inner_step_lemma_fc + (re0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (a_offset b_offset step_vec : Std.Usize) (zeta_i1 : Std.Usize) + (h_zi1_lt : zeta_i1.val < 128) + (h_step_vec_pos : 1 ≤ step_vec.val) + (h_a_offset_b : a_offset.val + step_vec.val ≤ 16) + (h_b_offset_b : b_offset.val + step_vec.val ≤ 16) + (h_disjoint : a_offset.val + step_vec.val ≤ b_offset.val) + (h_pre_a : ∀ j : Nat, j < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((re0.coefficients.val[a_offset.val + j]!).elements.val[ℓ]!).val.natAbs ≤ 29439) + (h_pre_b : ∀ j : Nat, j < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((re0.coefficients.val[b_offset.val + j]!).elements.val[ℓ]!).val.natAbs ≤ 29439) + (acc : Layer4PlusInnerFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ step_vec.val) + (h_inv : (Layer4PlusInnerFC.inv re0 a_offset b_offset + (Spec.zeta_at zeta_i1.val) k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) + zeta_i1 a_offset b_offset + { start := k, «end» := step_vec } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer4PlusInnerFC.step_post re0 a_offset b_offset step_vec + (Spec.zeta_at zeta_i1.val) k r ⌝ ⦄ := by + have h_coef_len : acc.1.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_acc_a, h_acc_b, h_acc_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0.body + by_cases h_lt : k.val < step_vec.val + · -- Some j = k branch. + obtain ⟨s, hs_val, h_iter_some⟩ := + Layer4PlusFC.iter_next_some_eq_gen k step_vec h_lt + -- (1) i ← a_offset + k. + have h_a_max : a_offset.val + k.val ≤ Std.Usize.max := by + have h_ab_b : a_offset.val + k.val ≤ 16 := by omega + scalar_tac + obtain ⟨i, h_i_eq, h_i_val⟩ := + Layer4PlusFC.usize_add_ok_eq a_offset k h_a_max + -- (2) i1 ← b_offset + k. + have h_b_max : b_offset.val + k.val ≤ Std.Usize.max := by + have h_bb_b : b_offset.val + k.val ≤ 16 := by omega + scalar_tac + obtain ⟨i1, h_i1_eq, h_i1_val⟩ := + Layer4PlusFC.usize_add_ok_eq b_offset k h_b_max + -- (3) zeta lookup. + obtain ⟨z, h_z_eq, h_z_v, h_z_bd, h_z_lift⟩ := + triple_exists_ok_fc (polynomial.zeta_fc zeta_i1 h_zi1_lt) + -- Now we need to derive bounds on acc.1.coefficients[i] and [i1] from h_pre + h_acc_undone. + have h_i_lt_16 : i.val < 16 := by rw [h_i_val]; omega + have h_i1_lt_16 : i1.val < 16 := by rw [h_i1_val]; omega + have h_i_ne_i1 : i.val ≠ i1.val := by + rw [h_i_val, h_i1_val] + have : a_offset.val + k.val < b_offset.val + k.val := by omega + omega + -- Bounds at i and i1 via h_acc_undone. + have h_acc_i_undone : acc.1.coefficients.val[i.val]! = re0.coefficients.val[i.val]! := by + apply h_acc_undone i.val h_i_lt_16 + intro j' hj' + constructor + · -- i.val = a_offset + k ≠ a_offset + j' since j' < k. + rw [h_i_val]; omega + · -- i.val = a_offset + k ≠ b_offset + j' since b_offset ≥ a_offset + step_vec > a_offset + k. + rw [h_i_val]; omega + have h_acc_i1_undone : acc.1.coefficients.val[i1.val]! = re0.coefficients.val[i1.val]! := by + apply h_acc_undone i1.val h_i1_lt_16 + intro j' hj' + constructor + · -- i1.val = b_offset + k ≠ a_offset + j' since a_offset + j' ≤ a_offset + step_vec - 1 < b_offset. + rw [h_i1_val]; omega + · -- i1.val = b_offset + k ≠ b_offset + j'. + rw [h_i1_val]; omega + -- Acc-level bound: each chunk in acc.1.coefficients has lane bounds ≤ 29439 + -- ONLY for chunks at unchanged positions. The butterflied chunks may grow, + -- but we don't need to access them in this iter (i, i1 are pristine). + have h_acc_at_i_bnd : ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[i.val]!).elements.val[ℓ]!).val.natAbs ≤ 29439 := by + intro ℓ hℓ + rw [h_acc_i_undone, h_i_val] + exact h_pre_a k.val h_lt ℓ hℓ + have h_acc_at_i1_bnd : ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[i1.val]!).elements.val[ℓ]!).val.natAbs ≤ 29439 := by + intro ℓ hℓ + rw [h_acc_i1_undone, h_i1_val] + exact h_pre_b k.val h_lt ℓ hℓ + -- Convert h_z_bd to ≤ 1664 form for zeta_r bound. + have h_zeta_bnd : z.val.natAbs ≤ 1664 := h_z_bd + -- Apply the keystone with bounds at i, i1. + obtain ⟨r_pair, h_r_eq, h_r_a, h_r_b, h_r_undone⟩ := + triple_exists_ok_fc (ntt_layer_int_vec_step_fc + acc.1.coefficients i i1 acc.2 z h_i_lt_16 h_i1_lt_16 h_i_ne_i1 + h_zeta_bnd h_acc_at_i_bnd h_acc_at_i1_bnd) + -- r_pair is (new_coefs, scratch2). Build the new accumulator. + set acc' : Layer4PlusInnerFC.Acc := + (({ coefficients := r_pair.1 } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector), + r_pair.2) with hacc'_def + -- Compose body. + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) + zeta_i1 a_offset b_offset + { start := k, «end» := step_vec } acc.1 acc.2 + = .ok (ControlFlow.cont (({ start := s, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := step_vec } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [Aeneas.Std.bind_tc_ok] + show (do + let i ← a_offset + k + let i1 ← b_offset + k + let i2 ← libcrux_iot_ml_kem.polynomial.zeta zeta_i1 + let (a, scratch1) ← + libcrux_iot_ml_kem.ntt.ntt_layer_int_vec_step portable_ops_inst + acc.1.coefficients i i1 acc.2 i2 + Result.ok (ControlFlow.cont (({ start := s, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize), + ({ coefficients := a } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector), + scratch1))) = _ + rw [h_i_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_i1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_z_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_r_eq]; rfl + apply triple_of_ok_fc h_body + show Layer4PlusInnerFC.step_post re0 a_offset b_offset step_vec + (Spec.zeta_at zeta_i1.val) k + (.cont (({ start := s, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold Layer4PlusInnerFC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + -- Show inv at s acc'. + show (Layer4PlusInnerFC.inv re0 a_offset b_offset + (Spec.zeta_at zeta_i1.val) s acc').holds + have h_inv_pure : + (∀ j' : Nat, j' < s.val → + lift_chunk (acc'.1.coefficients.val[a_offset.val + j']!) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + (Spec.zeta_at zeta_i1.val)) + ∧ (∀ j' : Nat, j' < s.val → + lift_chunk (acc'.1.coefficients.val[b_offset.val + j']!) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + (Spec.zeta_at zeta_i1.val)) + ∧ (∀ k' : Nat, k' < 16 → + (∀ j' : Nat, j' < s.val → k' ≠ a_offset.val + j' ∧ k' ≠ b_offset.val + j') → + acc'.1.coefficients.val[k']! = re0.coefficients.val[k']!) := by + refine ⟨?_, ?_, ?_⟩ + · -- (a) a-side butterfly for j' < s.val. + intro j' hj' + rw [hs_val] at hj' + rcases Nat.lt_succ_iff_lt_or_eq.mp hj' with hj'_lt | hj'_eq + · -- j' < k.val: existing butterfly, position untouched in this step. + -- acc'.1.coefs = r_pair.1. New chunk at a_offset+j' is unchanged in r_pair + -- since a_offset+j' ≠ i (=a_offset+k) and ≠ i1 (=b_offset+k). + have h_ne_i : a_offset.val + j' ≠ i.val := by rw [h_i_val]; omega + have h_ne_i1 : a_offset.val + j' ≠ i1.val := by rw [h_i1_val]; omega + have h_pos : a_offset.val + j' < 16 := by omega + have h_unchanged : r_pair.1.val[a_offset.val + j']! + = acc.1.coefficients.val[a_offset.val + j']! := + h_r_undone (a_offset.val + j') h_pos h_ne_i h_ne_i1 + show lift_chunk (acc'.1.coefficients.val[a_offset.val + j']!) = _ + show lift_chunk (r_pair.1.val[a_offset.val + j']!) = _ + rw [h_unchanged] + exact h_acc_a j' hj'_lt + · -- j' = k.val: new butterfly at i = a_offset + k. + subst hj'_eq + -- acc'.1.coefs[a_offset + k.val] = r_pair.1[a_offset + k.val] = r_pair.1[i.val] (since i.val = a_offset+k). + show lift_chunk (acc'.1.coefficients.val[a_offset.val + k.val]!) = _ + show lift_chunk (r_pair.1.val[a_offset.val + k.val]!) = _ + have h_eq_i : a_offset.val + k.val = i.val := by rw [h_i_val] + rw [h_eq_i] + -- h_r_a : lift_chunk r_pair.1[i] = chunk_pair_butterfly_a_pure (lift_chunk acc.1[i]) (lift_chunk acc.1[i1]) (lift_fe_mont z). + rw [h_r_a] + rw [h_acc_i_undone, h_acc_i1_undone] + rw [h_i_val, h_i1_val] + -- Need lift_fe_mont z = Spec.zeta_at zeta_i1.val. + rw [h_z_lift] + · -- (b) b-side butterfly for j' < s.val. + intro j' hj' + rw [hs_val] at hj' + rcases Nat.lt_succ_iff_lt_or_eq.mp hj' with hj'_lt | hj'_eq + · have h_ne_i : b_offset.val + j' ≠ i.val := by rw [h_i_val]; omega + have h_ne_i1 : b_offset.val + j' ≠ i1.val := by rw [h_i1_val]; omega + have h_pos : b_offset.val + j' < 16 := by omega + have h_unchanged : r_pair.1.val[b_offset.val + j']! + = acc.1.coefficients.val[b_offset.val + j']! := + h_r_undone (b_offset.val + j') h_pos h_ne_i h_ne_i1 + show lift_chunk (acc'.1.coefficients.val[b_offset.val + j']!) = _ + show lift_chunk (r_pair.1.val[b_offset.val + j']!) = _ + rw [h_unchanged] + exact h_acc_b j' hj'_lt + · subst hj'_eq + show lift_chunk (acc'.1.coefficients.val[b_offset.val + k.val]!) = _ + show lift_chunk (r_pair.1.val[b_offset.val + k.val]!) = _ + have h_eq_i1 : b_offset.val + k.val = i1.val := by rw [h_i1_val] + rw [h_eq_i1] + rw [h_r_b] + rw [h_acc_i_undone, h_acc_i1_undone] + rw [h_i_val, h_i1_val] + rw [h_z_lift] + · -- (c) Other positions unchanged from re0. + intro k' hk' h_not_touched + have hk'_ne_i : k' ≠ i.val := by + have h_at_k : k.val < s.val := by rw [hs_val]; omega + have := (h_not_touched k.val h_at_k).1 + rw [h_i_val]; exact this + have hk'_ne_i1 : k' ≠ i1.val := by + have h_at_k : k.val < s.val := by rw [hs_val]; omega + have := (h_not_touched k.val h_at_k).2 + rw [h_i1_val]; exact this + -- acc'.1.coefs[k'] = r_pair.1[k']. r_pair.1[k'] = acc.1[k'] (k' ≠ i, i1). + show acc'.1.coefficients.val[k']! = re0.coefficients.val[k']! + show r_pair.1.val[k']! = re0.coefficients.val[k']! + have h_unchanged := h_r_undone k' hk' hk'_ne_i hk'_ne_i1 + rw [h_unchanged] + -- Now acc.1[k'] = re0[k'] from h_acc_undone (k' is not touched at any j' < k.val). + apply h_acc_undone k' hk' + intro j' hj' + -- j' < k.val. We have h_not_touched at j' (since j' < k.val < s.val). + have h_at_j' : j' < s.val := by rw [hs_val]; omega + exact h_not_touched j' h_at_j' + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- None branch: k ≥ step_vec, done. + have hk_ge : k.val ≥ step_vec.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = step_vec.val := by omega + have h_iter_none := Layer4PlusFC.iter_next_none_eq_gen k step_vec hk_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) + zeta_i1 a_offset b_offset + { start := k, «end» := step_vec } acc.1 acc.2 + = .ok (ControlFlow.done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := step_vec } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_fc h_body + show Layer4PlusInnerFC.step_post re0 a_offset b_offset step_vec + (Spec.zeta_at zeta_i1.val) k (.done acc) + unfold Layer4PlusInnerFC.step_post + show (Layer4PlusInnerFC.inv re0 a_offset b_offset + (Spec.zeta_at zeta_i1.val) step_vec acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc.1.coefficients.val[a_offset.val + j']!) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + (Spec.zeta_at zeta_i1.val)) + ∧ (∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc.1.coefficients.val[b_offset.val + j']!) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + (Spec.zeta_at zeta_i1.val)) + ∧ (∀ k' : Nat, k' < 16 → + (∀ j' : Nat, j' < step_vec.val → k' ≠ a_offset.val + j' ∧ k' ≠ b_offset.val + j') → + acc.1.coefficients.val[k']! = re0.coefficients.val[k']!) := by + refine ⟨?_, ?_, ?_⟩ + · intro j' hj'; rw [← hk_eq] at hj'; exact h_acc_a j' hj' + · intro j' hj'; rw [← hk_eq] at hj'; exact h_acc_b j' hj' + · intro k' hk' h_not_touched + apply h_acc_undone k' hk' + intro j' hj' + have h_at_j' : j' < step_vec.val := by rw [← hk_eq]; exact hj' + exact h_not_touched j' h_at_j' + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +/-! ### L3.4_plus' — Outer loop scaffolding. -/ + +namespace Layer4PlusOuterFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Outer loop accumulator: (zeta_i, re, scratch). -/ +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector × + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC invariant for the outer loop. Parameters: + - `re0` : original poly. + - `zeta_i_0` : zeta_i at start of outer loop. + - `step_vec` : (1 << layer) / 16. + + The invariant at outer iter `k`: + - `acc.1.val = zeta_i_0.val + k.val` (zeta thread). + - For `round' < k.val`: both a-side and b-side chunks are butterflied. + - Chunks not in any touched pair are unchanged from re0. -/ +def inv + (re0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta_i_0 step_vec : Std.Usize) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + acc.1.val = zeta_i_0.val + k.val + -- (a) For each completed round, chunks at a-side positions are butterflied. + ∧ (∀ round' : Nat, round' < k.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc.2.1.coefficients.val[2 * round' * step_vec.val + j']!) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!)) + (Spec.zeta_at (zeta_i_0.val + round' + 1))) + -- (b) For each completed round, chunks at b-side positions are butterflied. + ∧ (∀ round' : Nat, round' < k.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc.2.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!)) + (Spec.zeta_at (zeta_i_0.val + round' + 1))) + -- (c) Chunks not touched in any round' < k.val are unchanged. + ∧ (∀ c : Nat, c < 16 → + (∀ round' : Nat, round' < k.val → + ∀ j' : Nat, j' < step_vec.val → + c ≠ 2 * round' * step_vec.val + j' + ∧ c ≠ 2 * round' * step_vec.val + step_vec.val + j') → + acc.2.1.coefficients.val[c]! = re0.coefficients.val[c]!)) + +/-- Step-post for the outer loop. -/ +def step_post + (re0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta_i_0 step_vec i_end : Std.Usize) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < i_end.val ∧ iter'.«end» = i_end + ∧ iter'.start.val = k.val + 1 + ∧ (inv re0 zeta_i_0 step_vec iter'.start acc').holds + | .done y => (inv re0 zeta_i_0 step_vec i_end y).holds + +end Layer4PlusOuterFC + +/-- Helper: chunks lifted via `re0` at index `2*round'*step_vec + j'` + (a-side) are exactly the original re0 chunks (since these positions + have not yet been touched by outer iter `round'`). -/ +theorem outer_acc_a_chunk_eq_re0 + (re0 acc : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) (zeta_i_0 step_vec : Std.Usize) + (h_undone : ∀ c : Nat, c < 16 → + (∀ round' : Nat, round' < k.val → + ∀ j' : Nat, j' < step_vec.val → + c ≠ 2 * round' * step_vec.val + j' + ∧ c ≠ 2 * round' * step_vec.val + step_vec.val + j') → + acc.coefficients.val[c]! = re0.coefficients.val[c]!) + (h_kbound : 2 * k.val * step_vec.val + 2 * step_vec.val ≤ 16) + (j' : Nat) (hj' : j' < step_vec.val) : + acc.coefficients.val[2 * k.val * step_vec.val + j']! = re0.coefficients.val[2 * k.val * step_vec.val + j']! := by + apply h_undone (2 * k.val * step_vec.val + j') (by omega) + intro round' hround' j'' hj'' + -- round' < k, so 2*round'*step_vec + j'' < 2*k*step_vec (since 2*round'*step_vec + step_vec ≤ 2*k*step_vec). + -- Hence 2*k*step_vec + j' > 2*round'*step_vec + j'' if step_vec ≥ 1. + -- For the second leg: 2*round'*step_vec + step_vec + j'' < 2*round'*step_vec + 2*step_vec ≤ 2*k*step_vec. + -- So 2*k*step_vec + j' > all these. + constructor + · -- 2*k*step_vec + j' ≠ 2*round'*step_vec + j''. + have h1 : 2 * round' * step_vec.val + j'' < 2 * k.val * step_vec.val := by + have h_lt : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + omega + · -- 2*k*step_vec + j' ≠ 2*round'*step_vec + step_vec + j''. + have h1 : 2 * round' * step_vec.val + step_vec.val + j'' < 2 * k.val * step_vec.val := by + have : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + omega + +/-- b-side variant of the above helper. -/ +theorem outer_acc_b_chunk_eq_re0 + (re0 acc : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) (zeta_i_0 step_vec : Std.Usize) + (h_undone : ∀ c : Nat, c < 16 → + (∀ round' : Nat, round' < k.val → + ∀ j' : Nat, j' < step_vec.val → + c ≠ 2 * round' * step_vec.val + j' + ∧ c ≠ 2 * round' * step_vec.val + step_vec.val + j') → + acc.coefficients.val[c]! = re0.coefficients.val[c]!) + (h_kbound : 2 * k.val * step_vec.val + 2 * step_vec.val ≤ 16) + (h_step_vec_pos : 1 ≤ step_vec.val) + (j' : Nat) (hj' : j' < step_vec.val) : + acc.coefficients.val[2 * k.val * step_vec.val + step_vec.val + j']! + = re0.coefficients.val[2 * k.val * step_vec.val + step_vec.val + j']! := by + apply h_undone (2 * k.val * step_vec.val + step_vec.val + j') (by omega) + intro round' hround' j'' hj'' + constructor + · have h1 : 2 * round' * step_vec.val + j'' < 2 * k.val * step_vec.val := by + have : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + omega + · have h1 : 2 * round' * step_vec.val + step_vec.val + j'' < 2 * k.val * step_vec.val := by + have : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + omega + +set_option maxHeartbeats 16000000 in +/-- Inner loop spec wrapper: dispatches `loop_range_spec_usize` for the + inner loop, returning the final FC equations on the post poly. -/ +theorem ntt_at_layer_4_plus_inner_loop_fc + (re0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (a_offset b_offset step_vec : Std.Usize) (zeta_i1 : Std.Usize) + (h_zi1_lt : zeta_i1.val < 128) + (h_step_vec_pos : 1 ≤ step_vec.val) + (h_a_offset_b : a_offset.val + step_vec.val ≤ 16) + (h_b_offset_b : b_offset.val + step_vec.val ≤ 16) + (h_disjoint : a_offset.val + step_vec.val ≤ b_offset.val) + (h_pre_a : ∀ j : Nat, j < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((re0.coefficients.val[a_offset.val + j]!).elements.val[ℓ]!).val.natAbs ≤ 29439) + (h_pre_b : ∀ j : Nat, j < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((re0.coefficients.val[b_offset.val + j]!).elements.val[ℓ]!).val.natAbs ≤ 29439) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0 + (vectortraitsOperationsInst := portable_ops_inst) + { start := 0#usize, «end» := step_vec } zeta_i1 re0 scratch a_offset b_offset + ⦃ ⇓ r => ⌜ + (∀ j' : Nat, j' < step_vec.val → + lift_chunk (r.1.coefficients.val[a_offset.val + j']!) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + (Spec.zeta_at zeta_i1.val)) + ∧ (∀ j' : Nat, j' < step_vec.val → + lift_chunk (r.1.coefficients.val[b_offset.val + j']!) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + (Spec.zeta_at zeta_i1.val)) + ∧ (∀ k' : Nat, k' < 16 → + (∀ j' : Nat, j' < step_vec.val → k' ≠ a_offset.val + j' ∧ k' ≠ b_offset.val + j') → + r.1.coefficients.val[k']! = re0.coefficients.val[k']!) + ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0 + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) + zeta_i1 a_offset b_offset iter1 acc1.1 acc1.2) + (β := Layer4PlusInnerFC.Acc) + (re0, scratch) + 0#usize step_vec + (Layer4PlusInnerFC.inv re0 a_offset b_offset (Spec.zeta_at zeta_i1.val)) + (by + -- 0 ≤ step_vec. + have : (0#usize : Std.Usize).val = 0 := rfl + omega) + (by + -- Initial inv at k=0. + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_⟩ + · -- No a-side touched yet. + intro j' hj'; exact absurd hj' (Nat.not_lt_zero j') + · intro j' hj'; exact absurd hj' (Nat.not_lt_zero j') + · -- Acc = (re0, scratch); acc.1 = re0; all chunks unchanged trivially. + intro k' _ _; trivial) + ?_) + · -- Post entailment. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : + (Layer4PlusInnerFC.inv re0 a_offset b_offset + (Spec.zeta_at zeta_i1.val) step_vec r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + (∀ j' : Nat, j' < step_vec.val → + lift_chunk (r.1.coefficients.val[a_offset.val + j']!) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + (Spec.zeta_at zeta_i1.val)) + ∧ (∀ j' : Nat, j' < step_vec.val → + lift_chunk (r.1.coefficients.val[b_offset.val + j']!) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[a_offset.val + j']!)) + (lift_chunk (re0.coefficients.val[b_offset.val + j']!)) + (Spec.zeta_at zeta_i1.val)) + ∧ (∀ k' : Nat, k' < 16 → + (∀ j' : Nat, j' < step_vec.val → k' ≠ a_offset.val + j' ∧ k' ≠ b_offset.val + j') → + r.1.coefficients.val[k']! = re0.coefficients.val[k']!) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + Layer4PlusInnerFC.inv] using h_inv_holds + exact h_inv + · -- Step lemma dispatch. + intro acc k _h_ge h_le hinv + have h_step := ntt_at_layer_4_plus_inner_step_lemma_fc re0 a_offset b_offset step_vec + zeta_i1 h_zi1_lt h_step_vec_pos h_a_offset_b h_b_offset_b h_disjoint h_pre_a h_pre_b + acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer4PlusInnerFC.step_post re0 a_offset b_offset step_vec + (Spec.zeta_at zeta_i1.val) k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer4PlusInnerFC.step_post] using hP + · have hP : Layer4PlusInnerFC.step_post re0 a_offset b_offset step_vec + (Spec.zeta_at zeta_i1.val) k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer4PlusInnerFC.step_post] using hP + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for the outer loop. Dispatches the + inner loop wrapper. -/ +theorem ntt_at_layer_4_plus_outer_step_lemma_fc + (re0 : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta_i_0 step_vec i_end : Std.Usize) + (h_pre : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((re0.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439) + (h_step_vec_pos : 1 ≤ step_vec.val) + (h_step_vec_dvd : 2 * i_end.val * step_vec.val = 16) + (h_zeta_bnd : zeta_i_0.val + i_end.val ≤ 127) + (acc : Layer4PlusOuterFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ i_end.val) + (h_inv : (Layer4PlusOuterFC.inv re0 zeta_i_0 step_vec k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) + step_vec { start := k, «end» := i_end } acc.1 acc.2.1 acc.2.2 + ⦃ ⇓ r => ⌜ Layer4PlusOuterFC.step_post re0 zeta_i_0 step_vec i_end k r ⌝ ⦄ := by + obtain ⟨h_zeta_acc, h_acc_a, h_acc_b, h_acc_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0.body + by_cases h_lt : k.val < i_end.val + · -- Some round = k branch. + obtain ⟨s, hs_val, h_iter_some⟩ := + Layer4PlusFC.iter_next_some_eq_gen k i_end h_lt + -- (1) zeta_i1 ← zeta_i + 1. + have h_um : (1#usize : Std.Usize).val = 1 := rfl + have h_z_max : acc.1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um, h_zeta_acc]; scalar_tac + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := + Layer4PlusFC.usize_add_ok_eq acc.1 1#usize h_z_max + have h_zi1_arith : zi1.val = zeta_i_0.val + k.val + 1 := by + rw [h_zi1_val, h_um, h_zeta_acc] + have h_zi1_lt_128 : zi1.val < 128 := by + rw [h_zi1_arith]; omega + -- (2) i ← round * 2. + have h_um2 : (2#usize : Std.Usize).val = 2 := rfl + have h_i_max : k.val * (2#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um2] + have h_k_b : k.val * 2 ≤ 16 := by + have : k.val ≤ 8 := by + -- k < i_end and i_end * step_vec * 2 = 16, so k ≤ 8. + have : i_end.val ≤ 8 := by + have : i_end.val * step_vec.val * 2 = 16 := by rw [Nat.mul_assoc] at h_step_vec_dvd; nlinarith + nlinarith + omega + omega + scalar_tac + obtain ⟨ii, h_ii_eq, h_ii_val⟩ := + Layer4PlusFC.usize_mul_ok_eq k 2#usize h_i_max + have h_ii_arith : ii.val = 2 * k.val := by rw [h_ii_val, h_um2, Nat.mul_comm] + -- (3) a_offset ← ii * step_vec. + have h_a_max : ii.val * step_vec.val ≤ Std.Usize.max := by + rw [h_ii_arith] + have h_b : 2 * k.val * step_vec.val ≤ 16 := by + have : (k.val + 1) * (2 * step_vec.val) ≤ i_end.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + scalar_tac + obtain ⟨ao, h_ao_eq, h_ao_val⟩ := + Layer4PlusFC.usize_mul_ok_eq ii step_vec h_a_max + have h_ao_arith : ao.val = 2 * k.val * step_vec.val := by + rw [h_ao_val, h_ii_arith] + -- (4) b_offset ← a_offset + step_vec. + have h_b_max : ao.val + step_vec.val ≤ Std.Usize.max := by + rw [h_ao_arith] + have h_b : 2 * k.val * step_vec.val + step_vec.val ≤ 16 := by + have : (k.val + 1) * (2 * step_vec.val) ≤ i_end.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + scalar_tac + obtain ⟨bo, h_bo_eq, h_bo_val⟩ := + Layer4PlusFC.usize_add_ok_eq ao step_vec h_b_max + have h_bo_arith : bo.val = 2 * k.val * step_vec.val + step_vec.val := by + rw [h_bo_val, h_ao_arith] + -- Now dispatch the inner loop. + have h_a_offset_b : ao.val + step_vec.val ≤ 16 := by + rw [h_ao_arith] + have : (k.val + 1) * (2 * step_vec.val) ≤ i_end.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + have h_b_offset_b : bo.val + step_vec.val ≤ 16 := by + rw [h_bo_arith] + have : (k.val + 1) * (2 * step_vec.val) ≤ i_end.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + have h_disjoint : ao.val + step_vec.val ≤ bo.val := by + rw [h_ao_arith, h_bo_arith] + -- Localized chunk bounds for inner loop input: acc.2.1. + have h_2kstep_bnd : 2 * k.val * step_vec.val + 2 * step_vec.val ≤ 16 := by + have : (k.val + 1) * (2 * step_vec.val) ≤ i_end.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + have h_acc_a_eq : ∀ j : Nat, j < step_vec.val → + acc.2.1.coefficients.val[ao.val + j]! = re0.coefficients.val[ao.val + j]! := by + intro j hj + rw [h_ao_arith] + exact outer_acc_a_chunk_eq_re0 re0 acc.2.1 k zeta_i_0 step_vec h_acc_undone + h_2kstep_bnd j hj + have h_acc_b_eq : ∀ j : Nat, j < step_vec.val → + acc.2.1.coefficients.val[bo.val + j]! = re0.coefficients.val[bo.val + j]! := by + intro j hj + rw [h_bo_arith] + exact outer_acc_b_chunk_eq_re0 re0 acc.2.1 k zeta_i_0 step_vec h_acc_undone + h_2kstep_bnd h_step_vec_pos j hj + have h_pre_a : ∀ j : Nat, j < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.1.coefficients.val[ao.val + j]!).elements.val[ℓ]!).val.natAbs ≤ 29439 := by + intro j hj ℓ hℓ + rw [h_acc_a_eq j hj] + apply h_pre _ _ ℓ hℓ + rw [h_ao_arith]; omega + have h_pre_b : ∀ j : Nat, j < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.1.coefficients.val[bo.val + j]!).elements.val[ℓ]!).val.natAbs ≤ 29439 := by + intro j hj ℓ hℓ + rw [h_acc_b_eq j hj] + apply h_pre _ _ ℓ hℓ + rw [h_bo_arith]; omega + -- Dispatch inner loop. + have h_inner := ntt_at_layer_4_plus_inner_loop_fc acc.2.1 acc.2.2 ao bo step_vec zi1 + h_zi1_lt_128 h_step_vec_pos h_a_offset_b h_b_offset_b h_disjoint h_pre_a h_pre_b + obtain ⟨r_pair, h_r_eq, h_r_a, h_r_b, h_r_undone⟩ := + triple_exists_ok_fc h_inner + -- Compose body. + set acc' : Layer4PlusOuterFC.Acc := + (zi1, r_pair.1, r_pair.2) with hacc'_def + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) + step_vec { start := k, «end» := i_end } acc.1 acc.2.1 acc.2.2 + = .ok (ControlFlow.cont (({ start := s, «end» := i_end } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := i_end } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := i_end } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [Aeneas.Std.bind_tc_ok] + show (do + let zi1' ← acc.1 + 1#usize + let ii' ← k * 2#usize + let ao' ← ii' * step_vec + let bo' ← ao' + step_vec + let (re1, scratch1) ← + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0 + (vectortraitsOperationsInst := portable_ops_inst) + { start := 0#usize, «end» := step_vec } zi1' acc.2.1 acc.2.2 ao' bo' + .ok (ControlFlow.cont (({ start := s, «end» := i_end } + : CoreModels.core.ops.range.Range Std.Usize), + zi1', re1, scratch1))) = _ + rw [h_zi1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_ii_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_ao_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_bo_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_r_eq]; rfl + apply triple_of_ok_fc h_body + show Layer4PlusOuterFC.step_post re0 zeta_i_0 step_vec i_end k + (.cont (({ start := s, «end» := i_end } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold Layer4PlusOuterFC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (Layer4PlusOuterFC.inv re0 zeta_i_0 step_vec s acc').holds + have h_inv_pure : + acc'.1.val = zeta_i_0.val + s.val + ∧ (∀ round' : Nat, round' < s.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc'.2.1.coefficients.val[2 * round' * step_vec.val + j']!) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!)) + (Spec.zeta_at (zeta_i_0.val + round' + 1))) + ∧ (∀ round' : Nat, round' < s.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc'.2.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!)) + (Spec.zeta_at (zeta_i_0.val + round' + 1))) + ∧ (∀ c : Nat, c < 16 → + (∀ round' : Nat, round' < s.val → + ∀ j' : Nat, j' < step_vec.val → + c ≠ 2 * round' * step_vec.val + j' + ∧ c ≠ 2 * round' * step_vec.val + step_vec.val + j') → + acc'.2.1.coefficients.val[c]! = re0.coefficients.val[c]!) := by + refine ⟨?_, ?_, ?_, ?_⟩ + · -- zeta thread. + show zi1.val = zeta_i_0.val + s.val + rw [h_zi1_arith, hs_val]; ring + · -- a-side butterflies for round' < s.val. + intro round' hround' j' hj' + rw [hs_val] at hround' + rcases Nat.lt_succ_iff_lt_or_eq.mp hround' with hround'_lt | hround'_eq + · -- round' < k.val: use h_acc_a after observing acc'.2.1.coefs = r_pair.1. + -- r_pair.1[c] = acc.2.1.coefs[c] for c not touched in this inner loop. + have h_pos : 2 * round' * step_vec.val + j' < 16 := by + have h_rb : 2 * round' * step_vec.val + 2 * step_vec.val + ≤ 2 * k.val * step_vec.val := by + have h_pos : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + have h_ne_a : ∀ j : Nat, j < step_vec.val → + 2 * round' * step_vec.val + j' ≠ ao.val + j := by + intro j hj + rw [h_ao_arith] + have h1 : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have h_pos : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + have h_ne_b : ∀ j : Nat, j < step_vec.val → + 2 * round' * step_vec.val + j' ≠ bo.val + j := by + intro j hj + rw [h_bo_arith] + have h1 : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have h_pos : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + have h_step_unc : r_pair.1.coefficients.val[2 * round' * step_vec.val + j']! + = acc.2.1.coefficients.val[2 * round' * step_vec.val + j']! := + h_r_undone (2 * round' * step_vec.val + j') h_pos + (fun j hj => ⟨h_ne_a j hj, h_ne_b j hj⟩) + show lift_chunk (acc'.2.1.coefficients.val[2 * round' * step_vec.val + j']!) = _ + show lift_chunk (r_pair.1.coefficients.val[2 * round' * step_vec.val + j']!) = _ + rw [h_step_unc] + exact h_acc_a round' hround'_lt j' hj' + · -- round' = k.val: new butterfly at (2*k*step_vec + j', 2*k*step_vec + step_vec + j'). + subst hround'_eq + show lift_chunk (acc'.2.1.coefficients.val[2 * k.val * step_vec.val + j']!) = _ + show lift_chunk (r_pair.1.coefficients.val[2 * k.val * step_vec.val + j']!) = _ + rw [show (2 * k.val * step_vec.val + j' : Nat) = ao.val + j' from by rw [h_ao_arith]] + rw [h_r_a j' hj'] + rw [h_acc_a_eq j' hj', h_acc_b_eq j' hj'] + rw [show (ao.val + j' : Nat) = 2 * k.val * step_vec.val + j' from by rw [h_ao_arith]] + rw [show (bo.val + j' : Nat) = 2 * k.val * step_vec.val + step_vec.val + j' from by rw [h_bo_arith]] + rw [show zi1.val = zeta_i_0.val + k.val + 1 from h_zi1_arith] + · -- b-side butterflies for round' < s.val. + intro round' hround' j' hj' + rw [hs_val] at hround' + rcases Nat.lt_succ_iff_lt_or_eq.mp hround' with hround'_lt | hround'_eq + · -- round' < k.val: use h_acc_b. + have h_pos : 2 * round' * step_vec.val + step_vec.val + j' < 16 := by + have h_rb : 2 * round' * step_vec.val + 2 * step_vec.val + ≤ 2 * k.val * step_vec.val := by + have h_pos : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + have h_ne_a : ∀ j : Nat, j < step_vec.val → + 2 * round' * step_vec.val + step_vec.val + j' ≠ ao.val + j := by + intro j hj + rw [h_ao_arith] + have h1 : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have h_pos : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + have h_ne_b : ∀ j : Nat, j < step_vec.val → + 2 * round' * step_vec.val + step_vec.val + j' ≠ bo.val + j := by + intro j hj + rw [h_bo_arith] + have h1 : 2 * round' * step_vec.val + 2 * step_vec.val ≤ 2 * k.val * step_vec.val := by + have h_pos : (round' + 1) * (2 * step_vec.val) ≤ k.val * (2 * step_vec.val) := by + apply Nat.mul_le_mul_right; omega + nlinarith + omega + have h_step_unc : + r_pair.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']! + = acc.2.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']! := + h_r_undone (2 * round' * step_vec.val + step_vec.val + j') h_pos + (fun j hj => ⟨h_ne_a j hj, h_ne_b j hj⟩) + show lift_chunk (acc'.2.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!) = _ + show lift_chunk (r_pair.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!) = _ + rw [h_step_unc] + exact h_acc_b round' hround'_lt j' hj' + · subst hround'_eq + show lift_chunk (acc'.2.1.coefficients.val[2 * k.val * step_vec.val + step_vec.val + j']!) = _ + show lift_chunk (r_pair.1.coefficients.val[2 * k.val * step_vec.val + step_vec.val + j']!) = _ + rw [show (2 * k.val * step_vec.val + step_vec.val + j' : Nat) = bo.val + j' from by rw [h_bo_arith]] + rw [h_r_b j' hj'] + rw [h_acc_a_eq j' hj', h_acc_b_eq j' hj'] + rw [show (ao.val + j' : Nat) = 2 * k.val * step_vec.val + j' from by rw [h_ao_arith]] + rw [show (bo.val + j' : Nat) = 2 * k.val * step_vec.val + step_vec.val + j' from by rw [h_bo_arith]] + rw [show zi1.val = zeta_i_0.val + k.val + 1 from h_zi1_arith] + · -- Untouched chunks. + intro c hc h_not_touched + -- acc'.2.1.coefs[c] = r_pair.1[c] = acc.2.1[c] = re0[c]. + show acc'.2.1.coefficients.val[c]! = re0.coefficients.val[c]! + show r_pair.1.coefficients.val[c]! = re0.coefficients.val[c]! + -- c is not touched at round' = k (since round' < s = k+1 includes k). + have h_at_k : k.val < s.val := by rw [hs_val]; omega + have h_ne_a_k : ∀ j : Nat, j < step_vec.val → c ≠ ao.val + j := by + intro j hj; rw [h_ao_arith] + exact (h_not_touched k.val h_at_k j hj).1 + have h_ne_b_k : ∀ j : Nat, j < step_vec.val → c ≠ bo.val + j := by + intro j hj; rw [h_bo_arith] + exact (h_not_touched k.val h_at_k j hj).2 + have h_step_unc : r_pair.1.coefficients.val[c]! = acc.2.1.coefficients.val[c]! := + h_r_undone c hc (fun j hj => ⟨h_ne_a_k j hj, h_ne_b_k j hj⟩) + rw [h_step_unc] + apply h_acc_undone c hc + intro round' hround' j' hj' + -- round' < k, so this is in the prior touched set; not at this c. + have h_at_r : round' < s.val := by rw [hs_val]; omega + exact h_not_touched round' h_at_r j' hj' + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- None branch: k ≥ i_end, done. + have hk_ge : k.val ≥ i_end.val := Nat.not_lt.mp h_lt + have hk_eq : k.val = i_end.val := by omega + have h_iter_none := Layer4PlusFC.iter_next_none_eq_gen k i_end hk_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) + step_vec { start := k, «end» := i_end } acc.1 acc.2.1 acc.2.2 + = .ok (ControlFlow.done (acc.1, acc.2.1, acc.2.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := i_end } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := i_end } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2.1, acc.2.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_fc h_body + show Layer4PlusOuterFC.step_post re0 zeta_i_0 step_vec i_end k (.done acc) + unfold Layer4PlusOuterFC.step_post + show (Layer4PlusOuterFC.inv re0 zeta_i_0 step_vec i_end acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + acc.1.val = zeta_i_0.val + i_end.val + ∧ (∀ round' : Nat, round' < i_end.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc.2.1.coefficients.val[2 * round' * step_vec.val + j']!) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!)) + (Spec.zeta_at (zeta_i_0.val + round' + 1))) + ∧ (∀ round' : Nat, round' < i_end.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (acc.2.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re0.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!)) + (Spec.zeta_at (zeta_i_0.val + round' + 1))) + ∧ (∀ c : Nat, c < 16 → + (∀ round' : Nat, round' < i_end.val → + ∀ j' : Nat, j' < step_vec.val → + c ≠ 2 * round' * step_vec.val + j' + ∧ c ≠ 2 * round' * step_vec.val + step_vec.val + j') → + acc.2.1.coefficients.val[c]! = re0.coefficients.val[c]!) := by + refine ⟨?_, ?_, ?_, ?_⟩ + · rw [h_zeta_acc, hk_eq] + · intro round' hround'; rw [← hk_eq] at hround'; exact h_acc_a round' hround' + · intro round' hround'; rw [← hk_eq] at hround'; exact h_acc_b round' hround' + · intro c hc h_nt + apply h_acc_undone c hc + intro round' hround' j' hj' + have : round' < i_end.val := by rw [← hk_eq]; exact hround' + exact h_nt round' this j' hj' + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +/-- L3.4_plus' — `ntt_at_layer_4_plus` PortableVector-specialised FC equation, + parameterized over `layer ∈ {4, 5, 6}`. + + Nested-loop pattern: outer over `round ∈ 0..(128 >>> layer)`, inner + over `j ∈ 0..((1 <<< layer) / 16)`. Each inner iter butterflies + chunks at positions `(round*2*step_vec + j, round*2*step_vec + step_vec + j)` + with `Spec.zeta_at (zeta_i + round + 1)`. zeta_i advances by `128 >>> layer` + across the entire call. + + **Preconditions** (load-bearing): + - `h_layer` : layer in 4..7 (validity of the nested-loop shape). For + layer=7, `step_vec=8`, `i_end=1`, so the outer loop runs one iteration + and applies a single chunk-pair butterfly at chunks `(0+j, 8+j)` for + `j ∈ 0..8`, matching the dedicated `ntt_at_layer_7` impl up to the + zeta choice (Mont vs plain). The relaxation lets `ntt_vector_u` (L3.4) + reuse this theorem for its first call. + - `h_bnd` : per-lane input bound 29439. + - `h_zeta` : zeta_i.val + (128 >>> layer) ≤ 127 (zeta indices within + ZETAS table 0..127). + + Proof sketch: + 1. Unfold `ntt_at_layer_4_plus` driver. Resolve the three impl constants + `step ← 1#usize <<< layer`, `step_vec ← step / 16#usize`, + `i_end ← 128#usize >>> layer` to specific Std.Usize values via + `Aeneas.Std.UScalar.ShiftLeft_spec`, `UScalar.div_spec`, + `UScalar.ShiftRight_spec` (all sound for layer ∈ {4,5,6}). + 2. Unfold `ntt_at_layer_4_plus_loop0` and apply + `loop_range_spec_usize` with invariant `Layer4PlusOuterFC.inv` + (zeta-thread + per-round a/b butterflies + untouched). + 3. Initial inv at k=0: zeta-thread trivial, no round' < 0 absurd, + chunks unchanged trivially. + 4. Post entailment at k=i_end: build chunks_arr matching Spec layout + (`Spec.chunk_at_layer_4_plus_pure chunks0 layer zeta_fn c` per c). + For each c < 16, case-split on `c % (2*step_vec) < step_vec`: + - a-side: group := c/(2*step_vec), j' := offset. Use h_acc_a. + - b-side: group := c/(2*step_vec), j' := offset - step_vec. + Use h_acc_b. Reduce inner Spec lookups via chunk_at_lift_poly_fc. + Apply `flatten_chunks_eq_lift_poly_fc`. + 5. Step dispatch: apply `ntt_at_layer_4_plus_outer_step_lemma_fc` and + unwrap step_post via .cont / .done branches. -/ +@[spec high] +theorem ntt_at_layer_4_plus_portable_fc + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (layer : Std.Usize) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (initial_bound : Std.Usize) + (h_layer : 4 ≤ layer.val ∧ layer.val ≤ 7) + (h_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 29439) + (h_zeta : zeta_i.val + (128 >>> layer.val) ≤ 127) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus + (vectortraitsOperationsInst := portable_ops_inst) + zeta_i re layer scratch initial_bound + ⦃ ⇓ p => ⌜ lift_poly p.2.1 = Spec.ntt_at_layer_4_plus_pure (lift_poly re) zeta_i layer ⌝ ⦄ := by + -- Layer bounds. + obtain ⟨h_layer_lo, h_layer_hi⟩ := h_layer + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus + -- Step 1: resolve `step ← 1#usize <<< layer`. + -- For layer ∈ {4,5,6}, layer.val < 64 = Usize.numBits, so shift is OK and + -- the value is (1 <<< layer.val) % 2^64 = 1 <<< layer.val. + have h_usize_bits : (Aeneas.Std.UScalarTy.Usize.numBits : Nat) = System.Platform.numBits := rfl + have h_layer_bits : layer.val < Aeneas.Std.UScalarTy.Usize.numBits := by + have h_p := System.Platform.numBits_eq + rcases h_p with h32 | h64 + · rw [h_usize_bits, h32]; omega + · rw [h_usize_bits, h64]; omega + have h_size_eq : Aeneas.Std.UScalar.size Aeneas.Std.UScalarTy.Usize = 2 ^ System.Platform.numBits := by + simp [Std.Usize.size, Usize.numBits] + -- Extract step via spec_imp_exists; we want step.val = 1 <<< layer.val. + have h_one_shl_pow : ((1#usize : Std.Usize).val <<< layer.val) < 2 ^ System.Platform.numBits := by + have h_one_eq : (1#usize : Std.Usize).val = 1 := rfl + rw [h_one_eq, Nat.shiftLeft_eq, Nat.one_mul] + have h_p := System.Platform.numBits_eq + rcases h_p with h32 | h64 + · rw [h32]; exact Nat.pow_lt_pow_right (by decide) (by omega) + · rw [h64]; exact Nat.pow_lt_pow_right (by decide) (by omega) + have h_step_ex : ∃ step : Std.Usize, + ((1#usize : Std.Usize) <<< layer : Result Std.Usize) = .ok step + ∧ step.val = 1 <<< layer.val := by + have hT := Aeneas.Std.UScalar.ShiftLeft_spec (1#usize : Std.Usize) layer + (Aeneas.Std.UScalar.size Aeneas.Std.UScalarTy.Usize) h_layer_bits rfl + obtain ⟨z, h_eq, h_v_mod, _h_bv⟩ := Std.WP.spec_imp_exists hT + refine ⟨z, h_eq, ?_⟩ + have h_one_eq : (1#usize : Std.Usize).val = 1 := rfl + rw [h_v_mod, h_one_eq, h_size_eq, Nat.mod_eq_of_lt] + rw [h_one_eq] at h_one_shl_pow + exact h_one_shl_pow + obtain ⟨step, h_step_eq, h_step_val⟩ := h_step_ex + rw [h_step_eq] + simp only [Aeneas.Std.bind_tc_ok] + -- Step 2: resolve `step_vec ← step / 16#usize`. + have h_16_nz : ((16#usize : Std.Usize).val : Nat) ≠ 0 := by decide + have h_step_pos : 1 ≤ step.val := by + rw [h_step_val, Nat.shiftLeft_eq, Nat.one_mul] + exact Nat.one_le_pow _ _ (by decide : (0:Nat) < 2) + obtain ⟨step_vec, h_step_vec_eq, h_step_vec_val⟩ := + Aeneas.Std.UScalar.div_spec step h_16_nz + rw [h_step_vec_eq] + simp only [Aeneas.Std.bind_tc_ok] + -- Compute step_vec.val. + have h_step_vec_arith : step_vec.val = (1 <<< layer.val) / 16 := by + have h_16_eq : (16#usize : Std.Usize).val = 16 := rfl + rw [h_step_vec_val, h_step_val, h_16_eq] + -- Step 3: resolve `i ← 128#usize >>> layer`. + obtain ⟨i_end, h_i_end_eq, h_i_end_val, _h_i_end_bv⟩ := + Std.WP.spec_imp_exists (Aeneas.Std.UScalar.ShiftRight_spec (128#usize : Std.Usize) layer + h_layer_bits) + rw [h_i_end_eq] + have h_i_end_arith : i_end.val = 128 >>> layer.val := h_i_end_val + -- Step 4: positivity & dvd facts uniformly across layer ∈ {4,5,6}. + have h_step_vec_pos : 1 ≤ step_vec.val := by + rw [h_step_vec_arith] + interval_cases layer.val <;> decide + have h_step_vec_dvd : 2 * i_end.val * step_vec.val = 16 := by + rw [h_i_end_arith, h_step_vec_arith] + interval_cases layer.val <;> decide + have h_i_end_pos : 1 ≤ i_end.val := by + rw [h_i_end_arith] + interval_cases layer.val <;> decide + have h_zeta_bnd : zeta_i.val + i_end.val ≤ 127 := by + rw [h_i_end_arith]; exact h_zeta + -- Step 5: unfold the outer loop and apply loop_range_spec_usize. + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0 + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0.body + (vectortraitsOperationsInst := portable_ops_inst) step_vec + iter1 acc1.1 acc1.2.1 acc1.2.2) + (β := Layer4PlusOuterFC.Acc) + (zeta_i, re, scratch) + 0#usize i_end + (Layer4PlusOuterFC.inv re zeta_i step_vec) + (by + have h_zero : (0#usize : Std.Usize).val = 0 := rfl + omega) + (by + -- Initial inv at k=0. + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_, ?_⟩ + · -- zeta-thread: zeta_i.val = zeta_i.val + 0. + show zeta_i.val = zeta_i.val + (0#usize : Std.Usize).val + show zeta_i.val = zeta_i.val + 0 + omega + · -- No round' < 0 absurd. + intro round' hround' _ _ + exact absurd hround' (Nat.not_lt_zero round') + · intro round' hround' _ _ + exact absurd hround' (Nat.not_lt_zero round') + · -- All chunks unchanged trivially. + intro _ _ _; trivial) + ?_) + · -- Post entailment: at k = i_end, build chunks_arr matching Spec and apply + -- flatten_chunks_eq_lift_poly_fc. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (Layer4PlusOuterFC.inv re zeta_i step_vec i_end r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + r.1.val = zeta_i.val + i_end.val + ∧ (∀ round' : Nat, round' < i_end.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (r.2.1.coefficients.val[2 * round' * step_vec.val + j']!) + = Spec.chunk_pair_butterfly_a_pure + (lift_chunk (re.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!)) + (Spec.zeta_at (zeta_i.val + round' + 1))) + ∧ (∀ round' : Nat, round' < i_end.val → + ∀ j' : Nat, j' < step_vec.val → + lift_chunk (r.2.1.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!) + = Spec.chunk_pair_butterfly_b_pure + (lift_chunk (re.coefficients.val[2 * round' * step_vec.val + j']!)) + (lift_chunk (re.coefficients.val[2 * round' * step_vec.val + step_vec.val + j']!)) + (Spec.zeta_at (zeta_i.val + round' + 1))) + ∧ (∀ c : Nat, c < 16 → + (∀ round' : Nat, round' < i_end.val → + ∀ j' : Nat, j' < step_vec.val → + c ≠ 2 * round' * step_vec.val + j' + ∧ c ≠ 2 * round' * step_vec.val + step_vec.val + j') → + r.2.1.coefficients.val[c]! = re.coefficients.val[c]!) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + Layer4PlusOuterFC.inv] using h_inv_holds + obtain ⟨_h_zeta_done, h_done_a, h_done_b, _h_done_undone⟩ := h_inv + -- Build chunks_arr matching the Spec layout. + unfold Spec.ntt_at_layer_4_plus_pure + set chunks_arr : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun c => + Spec.chunk_at_layer_4_plus_pure + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at (lift_poly re))) + (by simp)) + layer + (fun group => Spec.zeta_at (zeta_i.val + group + 1)) + c)) + (by simp) with hchunks_def + have h_chunks_len : chunks_arr.val.length = 16 := by + show ((List.range 16).map _).length = 16; simp + -- Inner chunks0 lookup at index k reduces via chunk_at_lift_poly_fc. + have h_chunks0_at : ∀ k : Nat, k < 16 → + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at (lift_poly re))) + (by simp)).val[k]! + = lift_chunk (re.coefficients.val[k]!) := by + intro k hk + have h_len_map : ((List.range 16).map (Spec.chunk_at (lift_poly re))).length = 16 := by simp + show ((List.range 16).map (Spec.chunk_at (lift_poly re)))[k]! = _ + rw [getElem!_pos _ k (by rw [h_len_map]; exact hk)] + rw [List.getElem_map, List.getElem_range] + exact chunk_at_lift_poly_fc re k hk + -- Now prove h_chunks_get pointwise. Two cases: a-side and b-side. + have h_chunks_get : ∀ c : Nat, (hc : c < 16) → + chunks_arr.val[c]'(by rw [h_chunks_len]; exact hc) + = lift_chunk (r.2.1.coefficients.val[c]!) := by + intro c hc + show ((List.range 16).map (fun c => + Spec.chunk_at_layer_4_plus_pure + (Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at (lift_poly re))) + (by simp)) + layer + (fun group => Spec.zeta_at (zeta_i.val + group + 1)) + c))[c]'_ = _ + rw [List.getElem_map, List.getElem_range] + -- Unfold Spec.chunk_at_layer_4_plus_pure. + unfold Spec.chunk_at_layer_4_plus_pure + -- Use abbreviation for step_vec_val. + set sv := (1 <<< layer.val) / 16 with hsv_def + have hsv_eq : sv = step_vec.val := by rw [hsv_def, h_step_vec_arith] + -- Reveal the if condition. + simp only [] + set group := c / (2 * sv) + set offset := c % (2 * sv) + have h_2sv_pos : 0 < 2 * sv := by rw [hsv_eq]; omega + have h_c_eq : 2 * sv * group + offset = c := by + show 2 * sv * (c / (2 * sv)) + c % (2 * sv) = c + exact Nat.div_add_mod c (2 * sv) + have h_off_lt : offset < 2 * sv := Nat.mod_lt _ h_2sv_pos + -- group < i_end.val: from c < 16 = 2 * i_end.val * sv. + have h_16_eq : 2 * i_end.val * sv = 16 := by + rw [hsv_eq]; exact h_step_vec_dvd + have h_group_lt : group < i_end.val := by + by_contra h_ge + push Not at h_ge + have h_ge2 : 2 * sv * i_end.val ≤ 2 * sv * group := Nat.mul_le_mul_left _ h_ge + have : c ≥ 2 * sv * i_end.val := by + have : 2 * sv * group ≤ c := by omega + omega + have h_rw : 2 * sv * i_end.val = 16 := by + have h : 2 * i_end.val * sv = 2 * sv * i_end.val := by ring + omega + omega + by_cases h_off_lt_sv : offset < sv + · -- a-side: c = 2*sv*group + offset = 2*group*sv + offset, partner c+sv. + simp only [if_pos h_off_lt_sv] + -- Need: chunks0[c]! and chunks0[c+sv]! reduce via h_chunks0_at. + have h_c_lt_16 : c < 16 := hc + have h_c_plus_sv_lt_16 : c + sv < 16 := by + have h_succ : 2 * sv * (group + 1) ≤ 2 * sv * i_end.val := Nat.mul_le_mul_left _ h_group_lt + have h_split : 2 * sv * (group + 1) = 2 * sv * group + 2 * sv := by ring + have h_eq_16 : 2 * sv * i_end.val = 16 := by + have : 2 * i_end.val * sv = 2 * sv * i_end.val := by ring + omega + omega + rw [h_chunks0_at c h_c_lt_16, h_chunks0_at (c + sv) h_c_plus_sv_lt_16] + -- Convert to round'=group, j'=offset shape for h_done_a. + have h_c_eq_a : c = 2 * group * step_vec.val + offset := by + rw [← hsv_eq] + calc c = 2 * sv * group + offset := h_c_eq.symm + _ = 2 * group * sv + offset := by ring_nf + have h_csv_eq_a : c + sv = 2 * group * step_vec.val + step_vec.val + offset := by + rw [h_c_eq_a]; rw [hsv_eq]; ring + have h_off_lt_sv' : offset < step_vec.val := by rw [← hsv_eq]; exact h_off_lt_sv + have h_done := h_done_a group h_group_lt offset h_off_lt_sv' + rw [h_csv_eq_a, h_c_eq_a] + exact h_done.symm + · -- b-side: c ≥ sv. Set j' := offset - sv. + simp only [if_neg h_off_lt_sv] + push Not at h_off_lt_sv + set j' := offset - sv with hj'_def + have hj'_lt_sv : j' < sv := by + show offset - sv < sv; omega + have h_off_eq : offset = sv + j' := by + show offset = sv + (offset - sv); omega + have h_c_lt_16 : c < 16 := hc + have h_c_minus_sv_lt_16 : c - sv < 16 := by omega + -- c = 2*sv*group + sv + j', c - sv = 2*sv*group + j'. + have h_c_eq_b : c = 2 * group * step_vec.val + step_vec.val + j' := by + rw [← hsv_eq] + have : c = 2 * sv * group + (sv + j') := by rw [← h_off_eq]; exact h_c_eq.symm + calc c = 2 * sv * group + (sv + j') := this + _ = 2 * group * sv + sv + j' := by ring + have h_cmsv_eq_b : c - sv = 2 * group * step_vec.val + j' := by + have h_sv_le_c : sv ≤ c := by + calc sv ≤ sv + j' := Nat.le_add_right _ _ + _ = offset := h_off_eq.symm + _ ≤ 2 * sv * group + offset := Nat.le_add_left _ _ + _ = c := h_c_eq + rw [← hsv_eq] + have h_full : c - sv = (2 * sv * group + (sv + j')) - sv := by rw [← h_off_eq, h_c_eq] + rw [h_full] + have h_simp : 2 * sv * group + (sv + j') - sv = 2 * sv * group + j' := by omega + rw [h_simp]; ring + rw [h_chunks0_at (c - sv) h_c_minus_sv_lt_16, h_chunks0_at c h_c_lt_16] + have h_j'_lt : j' < step_vec.val := by rw [← hsv_eq]; exact hj'_lt_sv + have h_done := h_done_b group h_group_lt j' h_j'_lt + rw [h_cmsv_eq_b, h_c_eq_b] + exact h_done.symm + -- Apply flatten_chunks_eq_lift_poly_fc. + have h_final := flatten_chunks_eq_lift_poly_fc r.2.1 chunks_arr h_chunks_len h_chunks_get + exact h_final.symm + · -- Step lemma dispatch: apply ntt_at_layer_4_plus_outer_step_lemma_fc. + intro acc k _h_ge h_le hinv + have h_step := ntt_at_layer_4_plus_outer_step_lemma_fc re zeta_i step_vec i_end + h_bnd h_step_vec_pos h_step_vec_dvd h_zeta_bnd acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer4PlusOuterFC.step_post re zeta_i step_vec i_end k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer4PlusOuterFC.step_post] using hP + · have hP : Layer4PlusOuterFC.step_post re zeta_i step_vec i_end k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer4PlusOuterFC.step_post] using hP + +/-! ### Per-layer FC+bound combinators + + Each combinator pairs the FC equation (from FCTargets) with the + per-lane output bound (from legacy `Equivalence.ntt_at_layer_X_spec(_B)`), + so the downstream L3.3/L3.4 composition can chain them without + re-applying two Triples per sub-call. See SKILL §9.11. -/ + +set_option maxHeartbeats 800000 in +/-- L3.3-step-7 — layer-7 dedicated FC + bound combinator. + Input ≤ 3 (binomial-sampled), output ≤ 4803. + Pairs `ntt_at_layer_7_portable_fc` (FC eq) with + `libcrux_iot_ml_kem.Polynomial.NttDrivers.ntt_at_layer_7_spec` (per-lane ≤ 4803). -/ +@[spec high] +theorem ntt_at_layer_7_portable_fc_strong + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_7 + (vectortraitsOperationsInst := portable_ops_inst) re scratch + ⦃ ⇓ p => ⌜ lift_poly p.1 = Spec.ntt_at_layer_7_pure (lift_poly re) + ∧ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.1.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 4803 ⌝ ⦄ := by + have h_fc := ntt_at_layer_7_portable_fc re scratch + (fun chunk hc k hk => by + have := h_pre chunk hc k hk; omega) + have h_bd := libcrux_iot_ml_kem.Polynomial.NttDrivers.ntt_at_layer_7_spec re scratch h_pre + obtain ⟨r, h_eq, h_fc'⟩ := triple_exists_ok_fc h_fc + obtain ⟨r', h_eq', h_bd'⟩ := triple_exists_ok_fc h_bd + have h_rr : r = r' := by + have : (Result.ok r : Result _) = Result.ok r' := by rw [← h_eq, h_eq'] + cases this; rfl + subst h_rr + exact triple_of_ok_fc h_eq ⟨h_fc', h_bd'⟩ + +set_option maxHeartbeats 800000 in +/-- L3.3-step-{4,5,6} + L3.4-step-7 — generic `layer_4_plus` FC + bound combinator. + Parametric in `layer ∈ {4,5,6,7}`. Pairs `ntt_at_layer_4_plus_portable_fc` + (FC eq, h_zeta strict form) with `libcrux_iot_ml_kem.Polynomial.NttDrivers.ntt_at_layer_4_plus_spec` + (output ≤ bnd.val + 3328, zeta-out = zeta_i + 128 >>> layer). -/ +@[spec high] +theorem ntt_at_layer_4_plus_portable_fc_strong + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (layer : Std.Usize) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (bnd : Std.Usize) + (h_layer : 4 ≤ layer.val ∧ layer.val ≤ 7) + (h_bnd : bnd.val ≤ 8 * 3328) + (h_zeta : zeta_i.val = (1 <<< (7 - layer.val)) - 1) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd.val) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus + (vectortraitsOperationsInst := portable_ops_inst) + zeta_i re layer scratch bnd + ⦃ ⇓ p => ⌜ lift_poly p.2.1 = Spec.ntt_at_layer_4_plus_pure (lift_poly re) zeta_i layer + ∧ p.1.val = zeta_i.val + 128 >>> layer.val + ∧ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.1.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd.val + 3328 ⌝ ⦄ := by + obtain ⟨h_layer_lo, h_layer_hi⟩ := h_layer + -- FC theorem h_bnd ceiling = 29439. bnd.val ≤ 8 * 3328 = 26624 ≤ 29439. + have h_pre_29439 : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 29439 := by + intro chunk hc k hk + have := h_pre chunk hc k hk + omega + -- FC theorem h_zeta: zeta_i.val + (128 >>> layer.val) ≤ 127. Derive from + -- h_zeta : zeta_i.val = (1 <<< (7 - layer.val)) - 1 by interval cases on layer. + have h_zeta_fc : zeta_i.val + (128 >>> layer.val) ≤ 127 := by + rw [h_zeta] + interval_cases layer.val <;> decide + have h_fc := ntt_at_layer_4_plus_portable_fc zeta_i re layer scratch bnd + ⟨h_layer_lo, h_layer_hi⟩ h_pre_29439 h_zeta_fc + have h_bd := libcrux_iot_ml_kem.Polynomial.NttDrivers.ntt_at_layer_4_plus_spec + layer zeta_i re scratch bnd ⟨h_layer_lo, h_layer_hi⟩ h_bnd h_zeta h_pre + obtain ⟨r, h_eq, h_fc'⟩ := triple_exists_ok_fc h_fc + obtain ⟨r', h_eq', h_bd'⟩ := triple_exists_ok_fc h_bd + have h_rr : r = r' := by + have : (Result.ok r : Result _) = Result.ok r' := by rw [← h_eq, h_eq'] + cases this; rfl + subst h_rr + exact triple_of_ok_fc h_eq ⟨h_fc', h_bd'.1, h_bd'.2⟩ + +set_option maxHeartbeats 800000 in +/-- L3.3-step-3 / L3.4-step-3 — layer-3 FC + bound combinator. + Pairs `ntt_at_layer_3_portable_fc` (FC eq) with + `libcrux_iot_ml_kem.Polynomial.NttDrivers.ntt_at_layer_3_spec_B` (zeta-out = 31, per-lane ≤ bnd+3328). -/ +@[spec high] +theorem ntt_at_layer_3_portable_fc_strong + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (initial_bound : Std.Usize) + (bnd : Nat) (h_bnd : bnd ≤ 29439) + (h_zeta : zeta_i.val = 15) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_3 + (vectortraitsOperationsInst := portable_ops_inst) zeta_i re initial_bound + ⦃ ⇓ p => ⌜ lift_poly p.2 = Spec.ntt_layer_3_pure (lift_poly re) zeta_i + ∧ p.1.val = 31 + ∧ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd + 3328 ⌝ ⦄ := by + have h_fc := ntt_at_layer_3_portable_fc zeta_i re initial_bound + (fun chunk hc k hk => by have := h_pre chunk hc k hk; omega) + (by rw [h_zeta]; decide) + have h_bd := libcrux_iot_ml_kem.Polynomial.NttDrivers.ntt_at_layer_3_spec_B + zeta_i re initial_bound bnd h_bnd h_zeta h_pre + obtain ⟨r, h_eq, h_fc'⟩ := triple_exists_ok_fc h_fc + obtain ⟨r', h_eq', h_bd'⟩ := triple_exists_ok_fc h_bd + have h_rr : r = r' := by + have : (Result.ok r : Result _) = Result.ok r' := by rw [← h_eq, h_eq'] + cases this; rfl + subst h_rr + exact triple_of_ok_fc h_eq ⟨h_fc', h_bd'.1, h_bd'.2⟩ + +set_option maxHeartbeats 800000 in +/-- L3.3-step-2 / L3.4-step-2 — layer-2 FC + bound combinator. + Pairs `ntt_at_layer_2_portable_fc` (FC eq) with + `libcrux_iot_ml_kem.Polynomial.NttDrivers.ntt_at_layer_2_spec_B` (zeta-out = 63, per-lane ≤ bnd+3328). -/ +@[spec high] +theorem ntt_at_layer_2_portable_fc_strong + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (initial_bound : Std.Usize) + (bnd : Nat) (h_bnd : bnd ≤ 29439) + (h_zeta : zeta_i.val = 31) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_2 + (vectortraitsOperationsInst := portable_ops_inst) zeta_i re initial_bound + ⦃ ⇓ p => ⌜ lift_poly p.2 = Spec.ntt_layer_2_pure (lift_poly re) zeta_i + ∧ p.1.val = 63 + ∧ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd + 3328 ⌝ ⦄ := by + have h_fc := ntt_at_layer_2_portable_fc zeta_i re initial_bound + (fun chunk hc k hk => by have := h_pre chunk hc k hk; omega) + (by rw [h_zeta]; decide) + have h_bd := libcrux_iot_ml_kem.Polynomial.NttDrivers.ntt_at_layer_2_spec_B + zeta_i re initial_bound bnd h_bnd h_zeta h_pre + obtain ⟨r, h_eq, h_fc'⟩ := triple_exists_ok_fc h_fc + obtain ⟨r', h_eq', h_bd'⟩ := triple_exists_ok_fc h_bd + have h_rr : r = r' := by + have : (Result.ok r : Result _) = Result.ok r' := by rw [← h_eq, h_eq'] + cases this; rfl + subst h_rr + exact triple_of_ok_fc h_eq ⟨h_fc', h_bd'.1, h_bd'.2⟩ + +set_option maxHeartbeats 800000 in +/-- L3.3-step-1 / L3.4-step-1 — layer-1 FC + bound combinator. + Pairs `ntt_at_layer_1_portable_fc` (FC eq) with + `libcrux_iot_ml_kem.Polynomial.NttDrivers.ntt_at_layer_1_spec_B` (zeta-out = 127, per-lane ≤ bnd+3328). -/ +@[spec high] +theorem ntt_at_layer_1_portable_fc_strong + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (initial_bound : Std.Usize) + (bnd : Nat) (h_bnd : bnd ≤ 29439) + (h_zeta : zeta_i.val = 63) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_1 + (vectortraitsOperationsInst := portable_ops_inst) zeta_i re initial_bound + ⦃ ⇓ p => ⌜ lift_poly p.2 = Spec.ntt_layer_1_pure (lift_poly re) zeta_i + ∧ p.1.val = 127 + ∧ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd + 3328 ⌝ ⦄ := by + have h_fc := ntt_at_layer_1_portable_fc zeta_i re initial_bound + (fun chunk hc k hk => by have := h_pre chunk hc k hk; omega) + (by omega) + have h_bd := libcrux_iot_ml_kem.Polynomial.NttDrivers.ntt_at_layer_1_spec_B + zeta_i re initial_bound bnd h_bnd h_zeta h_pre + obtain ⟨r, h_eq, h_fc'⟩ := triple_exists_ok_fc h_fc + obtain ⟨r', h_eq', h_bd'⟩ := triple_exists_ok_fc h_bd + have h_rr : r = r' := by + have : (Result.ok r : Result _) = Result.ok r' := by rw [← h_eq, h_eq'] + cases this; rfl + subst h_rr + exact triple_of_ok_fc h_eq ⟨h_fc', h_bd'.1, h_bd'.2⟩ + + +/-- L3.3 — `ntt_binomially_sampled_ring_element` driver (7 layer + composition + barrett reduce). Projects on the poly component. + + Input bound `≤ 3`: from the upstream binomial sampler with η₁=2, + which produces samples in `[-2, 2]`. We use `≤ 3` (one slack) + to match `ntt_at_layer_7_spec`'s legacy bound precondition. + + Implementation chain: dedicated `ntt_at_layer_7` → 3× `ntt_at_layer_4_plus` + (layers 6, 5, 4) → `ntt_at_layer_3` → `ntt_at_layer_2` → + `ntt_at_layer_1` → `poly_barrett_reduce`. Each layer's FC equation + comes from FCTargets `ntt_at_layer_X_portable_fc`; the per-layer + output bound comes from legacy + `libcrux_iot_ml_kem.Equivalence.ntt_at_layer_X_spec(_B)`. -/ +@[spec] +theorem ntt_binomially_sampled_ring_element_fc + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_bnd : ∀ chunk : Nat, chunk < 16 → ∀ k : Nat, k < 16 → + ((re.coefficients.val[chunk]!).elements.val[k]!).val.natAbs ≤ 3) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_binomially_sampled_ring_element + (vectortraitsOperationsInst := portable_ops_inst) re scratch + ⦃ ⇓ p => ⌜ lift_poly p.1 = Spec.ntt_pure (lift_poly re) ⌝ ⦄ := by + -- Strategy: collect all step equations using *_fc_strong combinators + -- and arithmetic helpers, then assemble the full impl-body equation via + -- `unfold ; simp [...]`. Close the Triple with `triple_of_ok_fc h_body` + -- and prove the lift_poly equation by chaining FC equations through + -- `Spec.ntt_pure` plus the barrett bridge. + -- ============================================================= + -- Step 1: layer_7. re scratch → re1 scratch1. ≤ 3 → ≤ 4803. + -- ============================================================= + obtain ⟨⟨re1, scratch1⟩, h1_eq, h1_fc, h1_bnd⟩ := + triple_exists_ok_fc (ntt_at_layer_7_portable_fc_strong re scratch h_bnd) + dsimp only at h1_fc h1_bnd + -- ============================================================= + -- Step 2: layer_4_plus (zeta_i=1, layer=6, bnd=11207). ≤ 4803 → ≤ 14535. + -- ============================================================= + have h_re1_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re1.coefficients.val[i]!).elements.val[j]!).val.natAbs + ≤ (11207#usize : Std.Usize).val := by + intro i hi j hj + have hb := h1_bnd i hi j hj + show _ ≤ 11207 + omega + obtain ⟨⟨zeta_i1, re2, scratch2⟩, h2_eq, h2_fc, h2_zout, h2_bnd⟩ := + triple_exists_ok_fc + (ntt_at_layer_4_plus_portable_fc_strong 1#usize re1 6#usize scratch1 11207#usize + (by decide) (by decide) (by decide) h_re1_loose) + dsimp only at h2_fc h2_zout h2_bnd + have h_zeta_i1 : zeta_i1.val = 3 := by rw [h2_zout]; decide + -- ============================================================= + -- Step 3: usize_add 11207 + 3328 = 14535. + -- ============================================================= + obtain ⟨i14535, hi14535_eq, hi14535_val⟩ := + usize_add_ok_eq_fc 11207#usize 3328#usize (by scalar_tac) + -- ============================================================= + -- Step 4: layer_4_plus (zeta_i1, layer=5, bnd=14535). ≤ 14535 → ≤ 17863. + -- ============================================================= + have h_re2_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re2.coefficients.val[i]!).elements.val[j]!).val.natAbs + ≤ i14535.val := by + intro i hi j hj + have hb := h2_bnd i hi j hj + have h11207 : (11207#usize : Std.Usize).val = 11207 := by decide + have h3328 : (3328#usize : Std.Usize).val = 3328 := by decide + rw [h11207] at hb + rw [hi14535_val, h11207, h3328] + omega + have h_i14535_bnd : i14535.val ≤ 8 * 3328 := by + have h := hi14535_val + have h11207 : (11207#usize : Std.Usize).val = 11207 := by decide + have h3328 : (3328#usize : Std.Usize).val = 3328 := by decide + rw [h11207, h3328] at h + omega + obtain ⟨⟨zeta_i2, re3, scratch3⟩, h4_eq, h4_fc, h4_zout, h4_bnd⟩ := + triple_exists_ok_fc + (ntt_at_layer_4_plus_portable_fc_strong zeta_i1 re2 5#usize scratch2 i14535 + (by decide) h_i14535_bnd + (by rw [h_zeta_i1]; decide) h_re2_loose) + dsimp only at h4_fc h4_zout h4_bnd + have h_zeta_i2 : zeta_i2.val = 7 := by + rw [h4_zout, h_zeta_i1]; decide + -- ============================================================= + -- Step 5: usize_mul 2 * 3328 = 6656. + -- ============================================================= + obtain ⟨i6656, hi6656_eq, hi6656_val⟩ := + usize_mul_ok_eq_fc 2#usize 3328#usize (by scalar_tac) + -- ============================================================= + -- Step 6: usize_add 11207 + 6656 = 17863. + -- ============================================================= + obtain ⟨i17863, hi17863_eq, hi17863_val⟩ := + usize_add_ok_eq_fc 11207#usize i6656 (by have := hi6656_val; scalar_tac) + -- ============================================================= + -- Step 7: layer_4_plus (zeta_i2, layer=4, bnd=17863). ≤ 17863 → ≤ 21191. + -- ============================================================= + have h_re3_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re3.coefficients.val[i]!).elements.val[j]!).val.natAbs + ≤ i17863.val := by + intro i hi j hj + have hb := h4_bnd i hi j hj + have h11207 : (11207#usize : Std.Usize).val = 11207 := by decide + have h3328 : (3328#usize : Std.Usize).val = 3328 := by decide + have h2 : (2#usize : Std.Usize).val = 2 := by decide + rw [hi14535_val, h11207, h3328] at hb + rw [hi17863_val, hi6656_val, h11207, h2, h3328] + omega + have h_i17863_bnd : i17863.val ≤ 8 * 3328 := by + have h := hi17863_val + have hm := hi6656_val + have h11207 : (11207#usize : Std.Usize).val = 11207 := by decide + have h2 : (2#usize : Std.Usize).val = 2 := by decide + have h3328 : (3328#usize : Std.Usize).val = 3328 := by decide + rw [h11207] at h + rw [h2, h3328] at hm + omega + obtain ⟨⟨zeta_i3, re4, scratch4⟩, h7_eq, h7_fc, h7_zout, h7_bnd⟩ := + triple_exists_ok_fc + (ntt_at_layer_4_plus_portable_fc_strong zeta_i2 re3 4#usize scratch3 i17863 + (by decide) h_i17863_bnd + (by rw [h_zeta_i2]; decide) h_re3_loose) + dsimp only at h7_fc h7_zout h7_bnd + have h_zeta_i3 : zeta_i3.val = 15 := by + rw [h7_zout, h_zeta_i2]; decide + -- ============================================================= + -- Step 8: usize_mul 3 * 3328 = 9984. + -- ============================================================= + obtain ⟨i9984, hi9984_eq, hi9984_val⟩ := + usize_mul_ok_eq_fc 3#usize 3328#usize (by scalar_tac) + -- ============================================================= + -- Step 9: usize_add 11207 + 9984 = 21191. + -- ============================================================= + obtain ⟨i21191, hi21191_eq, hi21191_val⟩ := + usize_add_ok_eq_fc 11207#usize i9984 (by have := hi9984_val; scalar_tac) + -- ============================================================= + -- Step 10: layer_3 (zeta_i3=15, bnd=21191 Nat). → ≤ 24519. zeta_out=31. + -- ============================================================= + have h_re4_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re4.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 21191 := by + intro i hi j hj + have hb := h7_bnd i hi j hj + have h11207 : (11207#usize : Std.Usize).val = 11207 := by decide + have h3328 : (3328#usize : Std.Usize).val = 3328 := by decide + have h2 : (2#usize : Std.Usize).val = 2 := by decide + rw [hi17863_val, hi6656_val, h11207, h2, h3328] at hb + omega + obtain ⟨⟨zeta_i4, re5⟩, h10_eq, h10_fc, h10_zout, h10_bnd⟩ := + triple_exists_ok_fc + (ntt_at_layer_3_portable_fc_strong zeta_i3 re4 i21191 21191 + (by decide) h_zeta_i3 h_re4_loose) + dsimp only at h10_fc h10_zout h10_bnd + -- ============================================================= + -- Step 11: usize_mul 4 * 3328 = 13312. + -- ============================================================= + obtain ⟨i13312, hi13312_eq, hi13312_val⟩ := + usize_mul_ok_eq_fc 4#usize 3328#usize (by scalar_tac) + -- ============================================================= + -- Step 12: usize_add 11207 + 13312 = 24519. + -- ============================================================= + obtain ⟨i24519, hi24519_eq, hi24519_val⟩ := + usize_add_ok_eq_fc 11207#usize i13312 (by have := hi13312_val; scalar_tac) + -- ============================================================= + -- Step 13: layer_2 (zeta_i4=31, bnd=24519 Nat). → ≤ 27847. zeta_out=63. + -- ============================================================= + have h_re5_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re5.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 24519 := by + intro i hi j hj + have hb := h10_bnd i hi j hj + omega + obtain ⟨⟨zeta_i5, re6⟩, h13_eq, h13_fc, h13_zout, h13_bnd⟩ := + triple_exists_ok_fc + (ntt_at_layer_2_portable_fc_strong zeta_i4 re5 i24519 24519 + (by decide) h10_zout h_re5_loose) + dsimp only at h13_fc h13_zout h13_bnd + -- ============================================================= + -- Step 14: usize_mul 5 * 3328 = 16640. + -- ============================================================= + obtain ⟨i16640, hi16640_eq, hi16640_val⟩ := + usize_mul_ok_eq_fc 5#usize 3328#usize (by scalar_tac) + -- ============================================================= + -- Step 15: usize_add 11207 + 16640 = 27847. + -- ============================================================= + obtain ⟨i27847, hi27847_eq, hi27847_val⟩ := + usize_add_ok_eq_fc 11207#usize i16640 (by have := hi16640_val; scalar_tac) + -- ============================================================= + -- Step 16: layer_1 (zeta_i5=63, bnd=27847 Nat). → ≤ 31175. zeta_out=127. + -- ============================================================= + have h_re6_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re6.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 27847 := by + intro i hi j hj + have hb := h13_bnd i hi j hj + omega + obtain ⟨⟨_zeta_i6, re7⟩, h16_eq, h16_fc, _h16_zout, h16_bnd⟩ := + triple_exists_ok_fc + (ntt_at_layer_1_portable_fc_strong zeta_i5 re6 i27847 27847 + (by decide) h13_zout h_re6_loose) + dsimp only at h16_fc h16_bnd + -- ============================================================= + -- Step 17: poly_barrett_reduce. ≤ 31175 ≤ 32767 → canonical residue. + -- ============================================================= + have h_re7_loose : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((re7.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro chunk hc ℓ hℓ + have hb := h16_bnd chunk hc ℓ hℓ + omega + obtain ⟨re8, h17_eq, h17_fc⟩ := + triple_exists_ok_fc (poly_barrett_reduce_fc re7 h_re7_loose) + -- ============================================================= + -- Compose: derive the full impl `do`-block equation by simp-folding + -- all step equations into the unfolded body. + -- ============================================================= + have h_body : + libcrux_iot_ml_kem.ntt.ntt_binomially_sampled_ring_element + (vectortraitsOperationsInst := portable_ops_inst) re scratch + = .ok (re8, scratch4) := by + unfold libcrux_iot_ml_kem.ntt.ntt_binomially_sampled_ring_element + simp [h1_eq, h2_eq, h4_eq, h7_eq, h10_eq, h13_eq, h16_eq, h17_eq, + hi14535_eq, hi6656_eq, hi17863_eq, + hi9984_eq, hi21191_eq, + hi13312_eq, hi24519_eq, + hi16640_eq, hi27847_eq] + apply triple_of_ok_fc h_body + -- ============================================================= + -- Prove lift_poly equation by chaining FC equations through Spec.ntt_pure. + -- ============================================================= + show lift_poly re8 = Spec.ntt_pure (lift_poly re) + unfold Spec.ntt_pure + -- Bridge barrett: h17_fc : poly_barrett_reduce (lift_poly re7) = .ok (lift_poly re8). + have hB_bridge : + hacspec_ml_kem.polynomial.poly_barrett_reduce (lift_poly re7) + = .ok (Spec.Pure.polynomial.poly_barrett_reduce_pure (lift_poly re7)) := + Spec.Pure.polynomial.poly_barrett_reduce_eq_ok (lift_poly re7) + rw [hB_bridge] at h17_fc + have h_re8_eq : lift_poly re8 + = Spec.Pure.polynomial.poly_barrett_reduce_pure (lift_poly re7) := by + have h := h17_fc + exact (Aeneas.Std.Result.ok.injEq _ _).mp h.symm + -- zeta_i identifications: substitute zeta values into the spec chain via .val. + have h_zeta_i2 : zeta_i2.val = 7 := by rw [h4_zout, h_zeta_i1]; decide + have h_zeta_i3 : zeta_i3.val = 15 := by rw [h7_zout, h_zeta_i2]; decide + have h_zeta_eq1 : zeta_i1 = 3#usize := by + have := h_zeta_i1; scalar_tac + have h_zeta_eq2 : zeta_i2 = 7#usize := by + have := h_zeta_i2; scalar_tac + have h_zeta_eq3 : zeta_i3 = 15#usize := by + have := h_zeta_i3; scalar_tac + have h_zeta_eq4 : zeta_i4 = 31#usize := by + have := h10_zout; scalar_tac + have h_zeta_eq5 : zeta_i5 = 63#usize := by + have := h13_zout; scalar_tac + rw [h_re8_eq, h16_fc, h13_fc, h10_fc, h7_fc, h4_fc, h2_fc, h1_fc, + h_zeta_eq1, h_zeta_eq2, h_zeta_eq3, h_zeta_eq4, h_zeta_eq5] + + + + +end libcrux_iot_ml_kem.Ntt diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/NttDrivers.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/NttDrivers.lean new file mode 100644 index 00000000..88aef6c0 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/NttDrivers.lean @@ -0,0 +1,4383 @@ +/- + # `Equivalence/L3_NTTDrivers.lean` — Layer 3 NTT driver-loop Triples. + + L3.x Triples for the `ntt_at_layer_N` driver loops in `ntt.rs`: + + - **L3.1 `ntt_at_layer_1_spec`** — innermost layer: a 16-iter loop over + a `PolynomialRingElement`'s 16 PortableVectors, each call dispatched + via the trait's `ntt_layer_1_step` (which forwards to L2.2's + `vector.portable.ntt.ntt_layer_1_step`). Per-coefficient bound goes + `7·3328 → 8·3328`; `zeta_i.val : 63 → 127`. + - **L3.2 `ntt_at_layer_2_spec`** — 2 ζ lookups/iter, dispatches + `ntt_layer_2_step`. Bound `6·3328 → 7·3328`; `zeta_i : 31 → 63`. + - **L3.3 `ntt_at_layer_3_spec`** — 1 ζ lookup/iter, dispatches + `ntt_layer_3_step`. Bound `5·3328 → 6·3328`; `zeta_i : 15 → 31`. + + Specialised to `Vector := PortableVector` with the concrete + `Libcrux_iot_ml_kemVectorTraitsOperations` instance. The instance's + `ntt_layer_N_step` field reduces (via `@[reducible]`) to + `…Operations.ntt_layer_N_step`, which is itself a thin wrapper for + `vector.portable.ntt.ntt_layer_N_step` — L2.2 / L2.3 / L2.4 fire directly. + + 1516-1525) for F*-port references. +-/ +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Polynomial.PolyOps + +set_option mvcgen.warning false +set_option linter.unusedVariables false +set_option linter.unusedSectionVars false + +namespace libcrux_iot_ml_kem.Polynomial.NttDrivers +open libcrux_iot_ml_kem.Polynomial.PolyOps libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt +open CoreModels Aeneas Aeneas.Std Result ControlFlow Std.Do +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper + +/-! ## Inhabited instances — needed for `.val[j]!` projections. + +`Std.Array α n` uses `List.get!` under the hood, which requires +`Inhabited α`. The L2/L1 layers don't trigger this because they only +project into `Array I16 16`. The L3 layer projects into `Array +PortableVector 16` and `Array (PolynomialRingElement PortableVector) K`, +so we register the canonical zero-witness instances locally. -/ + +instance : Inhabited libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + ⟨{ elements := Std.Array.make 16#usize (List.replicate 16 (0#i16 : Std.I16)) + (by simp) }⟩ + +instance {Vector : Type} [Inhabited Vector] {K : Std.Usize} : + Inhabited (libcrux_iot_ml_kem.polynomial.PolynomialRingElement Vector) := + ⟨{ coefficients := Std.Array.make 16#usize (List.replicate 16 default) (by simp) }⟩ + +/-! ## Local helpers — Triple ↔ Result.ok bridges, pure-prop holds. -/ + +private theorem triple_of_ok_l3 {α : Type} {x : Result α} {v : α} + {P : α → Prop} (hx : x = .ok v) (hp : P v) : + ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄ := by + subst hx; simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +private theorem triple_exists_ok_l3 {α : Type} {x : Result α} {P : α → Prop} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) : + ∃ v, x = .ok v ∧ P v := by + match hx : x with + | .ok v => exact ⟨v, rfl, (by subst hx; simpa [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h)⟩ + | .fail _ => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +private theorem pure_prop_holds_l3 {P : Prop} (h : P) : (pure P : Result Prop).holds := by + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, WP.wp]; intro _; exact h + +private theorem of_pure_prop_holds_l3 {P : Prop} + (h : (pure P : Result Prop).holds) : P := by + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, WP.wp] at h; exact h trivial + +/-! ## Small `Usize.add` helper — produces `.val`-form equations. -/ + +private theorem usize_add_ok_eq (x y : Std.Usize) + (h_max : x.val + y.val ≤ Std.Usize.max) : + ∃ z : Std.Usize, (x + y : Result Std.Usize) = .ok z ∧ z.val = x.val + y.val := by + have hT := Std.Usize.add_spec h_max + -- hT : x + y ⦃ z => (↑z : Nat) = ↑x + ↑y ⦄ — this is `WP.spec`, not Triple. + obtain ⟨z, h_eq, h_v⟩ := Std.WP.spec_imp_exists hT + refine ⟨z, h_eq, ?_⟩ + show z.val = x.val + y.val + exact h_v + +/-! ## `polynomial.zeta_spec` helper + +The `ZETAS_TIMES_MONTGOMERY_R` table has 128 `Std.I16` entries. Each is +in absolute value at most 1664 (in fact, all entries here are < 1700; +each fits in a Montgomery-reduced field element). `polynomial.zeta i` +performs a single bounded `Array.index_usize` on that table. + +We expose this through a single decidable bound: the table's underlying +`.val` (a 128-element `List Std.I16`) has every element ≤ 1664 in +absolute value. After unsealing the `@[irreducible]` table this is a +finite check that `decide` discharges. +-/ + +unseal libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R in +private theorem ZETAS_TIMES_MONTGOMERY_R_bound : + ∀ i : Nat, i < 128 → + ((libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R).val[i]!).val.natAbs ≤ 1664 := by + intro i hi + interval_cases i <;> decide + +@[spec] +theorem polynomial.zeta_spec (i : Std.Usize) (hi : i.val < 128) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.zeta i + ⦃ ⇓ r => ⌜ r.val.natAbs ≤ 1664 ⌝ ⦄ := by + -- `polynomial.zeta i = Array.index_usize ZETAS_TIMES_MONTGOMERY_R i`. + have h_len : + (libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R).length = 128 := + Std.Array.length_eq _ + have h_idx : + Aeneas.Std.Array.index_usize + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R i + = .ok ((libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R).val[i.val]!) := + array_index_usize_ok_eq _ i (by rw [h_len]; exact hi) + have h_ok : + libcrux_iot_ml_kem.polynomial.zeta i + = .ok ((libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R).val[i.val]!) := by + unfold libcrux_iot_ml_kem.polynomial.zeta + rw [h_idx] + exact triple_of_ok_l3 h_ok (ZETAS_TIMES_MONTGOMERY_R_bound i.val hi) + +/-! ## L3.1 — `ntt_at_layer_1_spec` + +Driver loop: 16 iterations over `re.coefficients`. Each iteration reads +`re.coefficients[k]` (a `PortableVector`), looks up four ζ-values via +`polynomial.zeta` (indices `zeta_i.val + 1 .. zeta_i.val + 4`), +dispatches `OpsInst.ntt_layer_1_step`, and writes back. `zeta_i.val` +advances by 4; the bound per coefficient goes `7·3328 → 8·3328`. + +We specialise to `Vector := PortableVector` and the concrete trait +instance. The `@[reducible]` instance field reduces +`OpsInst.ntt_layer_1_step a z0 z1 z2 z3` to +`vector.portable.ntt.ntt_layer_1_step a z0 z1 z2 z3` (mod a trivial +`Result.ok` wrap), which is L2.2's target. + +Loop invariant after `k` iterations (`k.val ∈ [0, 16]`), state +`(cur_zeta_i, cur_re)`: + - `cur_zeta_i.val = 63 + 4 * k.val` + - For `j < k.val`, all 16 elements of + `cur_re.coefficients[j]` are bounded by `8 * 3328`. + - For `j ≥ k.val`, `cur_re.coefficients[j] = re.coefficients[j]` + (so per `h_pre`, all 16 elements are bounded by `7 * 3328`). -/ + +namespace Layer1 + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Result ControlFlow + +/-- Step-local accumulator type — explicitly named to keep `loop_range_spec_usize`'s + `β` parameter mounted to a concrete type for inference. -/ +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- Loop invariant. -/ +def inv + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + acc.1.val = 63 + 4 * k.val + ∧ (∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 8 * 3328) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!)) + +/-- Step post (named to keep the `match` constant canonical across sites). -/ +def step_post + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv re iter'.start acc').holds + | .done y => (inv re 16#usize y).holds + +end Layer1 + +/-- Per-iteration step lemma: each body call advances `zeta_i` by 4 and + transforms `re.coefficients[k]` from a `≤ 7·3328` PortableVector to + a `≤ 8·3328` one (preserving all other indices and the bound chain). -/ +private theorem ntt_at_layer_1_step_lemma + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 7 * 3328) + (acc : Layer1.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_zeta_acc : acc.1.val = 63 + 4 * k.val) + (h_acc_done : ∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 8 * 3328) + (h_acc_undone : ∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer1.step_post re k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.2.coefficients.length = 16 := + Std.Array.length_eq _ + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- Some round = k branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := iter_next_some_eq k h_lt + -- 1) `zeta_i + 1`. + -- Bound chain: acc.1.val = 63 + 4*k.val with k.val < 16, so + -- acc.1.val ≤ 123 and acc.1.val + 4 ≤ 127. Each Usize.add stays + -- well within Std.Usize.max (≥ 2^32 - 1). + have h_acc1_lt : acc.1.val ≤ 123 := by rw [h_zeta_acc]; omega + have h_um : (1#usize : Std.Usize).val = 1 := rfl + have h_um2 : (2#usize : Std.Usize).val = 2 := rfl + have h_um3 : (3#usize : Std.Usize).val = 3 := rfl + have h_z_max : acc.1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := + usize_add_ok_eq acc.1 1#usize h_z_max + -- 2) `Array.index_mut_usize re.coefficients k`. + have h_idx : + Aeneas.Std.Array.index_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!) := + array_index_usize_ok_eq acc.2.coefficients k (by rw [h_coef_len]; exact hk_16) + have h_imt_ok : + Aeneas.Std.Array.index_mut_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!, acc.2.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx] + rfl + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.2.coefficients.val[k.val]! with ht_def + -- zi1.val arithmetic: zi1.val = acc.1.val + 1 = 64 + 4*k.val ≤ 124. + have h_zi1_val_arith : zi1.val = acc.1.val + 1 := by rw [h_zi1_val, h_um] + have h_zi1_lt : zi1.val < 128 := by rw [h_zi1_val_arith]; omega + -- 3) `polynomial.zeta zi1`. + obtain ⟨z1, h_z1_eq, h_z1_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zi1 h_zi1_lt) + -- 4) `zi1 + 1`. + have h_zi3_max : zi1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi3, h_zi3_eq, h_zi3_val⟩ := + usize_add_ok_eq zi1 1#usize h_zi3_max + have h_zi3_val_arith : zi3.val = acc.1.val + 2 := by + rw [h_zi3_val, h_um, h_zi1_val_arith] + have h_zi3_lt : zi3.val < 128 := by rw [h_zi3_val_arith]; omega + -- 5) `polynomial.zeta zi3`. + obtain ⟨z2, h_z2_eq, h_z2_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zi3 h_zi3_lt) + -- 6) `zi1 + 2`. + have h_zi5_max : zi1.val + (2#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um2]; scalar_tac + obtain ⟨zi5, h_zi5_eq, h_zi5_val⟩ := + usize_add_ok_eq zi1 2#usize h_zi5_max + have h_zi5_val_arith : zi5.val = acc.1.val + 3 := by + rw [h_zi5_val, h_um2, h_zi1_val_arith] + have h_zi5_lt : zi5.val < 128 := by rw [h_zi5_val_arith]; omega + -- 7) `polynomial.zeta zi5`. + obtain ⟨z3, h_z3_eq, h_z3_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zi5 h_zi5_lt) + -- 8) `zi1 + 3`. + have h_zi7_max : zi1.val + (3#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um3]; scalar_tac + obtain ⟨zi7, h_zi7_eq, h_zi7_val⟩ := + usize_add_ok_eq zi1 3#usize h_zi7_max + have h_zi7_val_arith : zi7.val = acc.1.val + 4 := by + rw [h_zi7_val, h_um3, h_zi1_val_arith] + have h_zi7_lt : zi7.val < 128 := by rw [h_zi7_val_arith]; omega + -- 9) `polynomial.zeta zi7`. + obtain ⟨z4, h_z4_eq, h_z4_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zi7 h_zi7_lt) + -- 10) `OpsInst.ntt_layer_1_step t z1 z2 z3 z4`. Reduces via the + -- @[reducible] instance to `vector.portable.ntt.ntt_layer_1_step`, + -- to which L2.2 applies. Pre: t's elements ≤ 7·3328 (it's + -- `re.coefficients[k]` via h_acc_undone + h_pre). + have h_t_eq : t = re.coefficients.val[k.val]! := by + show acc.2.coefficients.val[k.val]! = re.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 7 * 3328 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_pre k.val hk_16 ℓ hℓ + obtain ⟨t1, h_t1_eq, h_t1_bd⟩ := + triple_exists_ok_l3 (ntt_layer_1_step_spec t z1 z2 z3 z4 + h_z1_bd h_z2_bd h_z3_bd h_z4_bd h_t_bd) + -- Set the next-state values. + set a : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.2.coefficients.set k t1 with ha_def + set acc' : Layer1.Acc := (zi7, { coefficients := a }) with hacc'_def + -- Compose the whole body into one `.ok` equation. + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + -- Force complete let-Prod-match-on-Some normalization via plain + -- `simp` (NOT `simp only`) — this engages β-iota reductions that + -- `simp only` skips. Compose all step hypotheses simultaneously. + simp [bind_tc_ok, h_zi1_eq, h_imt_ok, h_z1_eq, h_zi3_eq, + h_z2_eq, h_zi5_eq, h_z3_eq, h_zi7_eq, h_z4_eq] + -- After simp, only the final `OpsInst.ntt_layer_1_step` remains + -- (the trait instance's outer `do`-wrapper is `@[reducible]` and + -- forwards to `vector.portable.ntt.ntt_layer_1_step`; simp doesn't + -- unfold by default). Unfold the instance projection definitionally, + -- then close via `h_t1_eq`. + show (do + let t1' ← + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_1_step t z1 z2 z3 z4 + ok (cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + zi7, + ({ coefficients := acc.2.coefficients.set k t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_t1_eq] + rfl + apply triple_of_ok_l3 h_body + show Layer1.step_post re k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) + unfold Layer1.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + -- Now: invariant at (s, acc'). + apply pure_prop_holds_l3 + -- Three conjuncts of Layer1.inv at (s, acc'). + refine ⟨?_, ?_, ?_⟩ + · -- acc'.1.val = zi7.val = 63 + 4 * s.val. + show zi7.val = 63 + 4 * s.val + rw [h_zi7_val_arith, h_zeta_acc, hs_val]; ring + · -- All j < s.val are bounded by 8*3328. + intro j hj ℓ hℓ + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: unchanged by the set, use h_acc_done. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne : + (acc.2.coefficients.set k t1)[j]! = (acc.2.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show ((acc.2.coefficients.set k t1).val[j]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_set_ne_val] + exact h_acc_done j hj_lt_k ℓ hℓ + · -- j = k.val: it's t1. + subst hj_eq_k + have h_lt' : k.val < acc.2.coefficients.length := by + rw [h_coef_len]; exact hk_16 + have h_set_eq : + (acc.2.coefficients.set k t1)[k.val]! = t1 := + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, h_lt'⟩ + have h_set_eq_val : + (acc.2.coefficients.set k t1).val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_eq + show ((acc.2.coefficients.set k t1).val[k.val]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_set_eq_val] + exact h_t1_bd ℓ hℓ + · -- All j ≥ s.val are unchanged. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne : + (acc.2.coefficients.set k t1)[j]! = (acc.2.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show (acc.2.coefficients.set k t1).val[j]! = re.coefficients.val[j]! + rw [h_set_ne_val] + exact h_acc_undone j h_ge' hj_lt + · -- None branch (k ≥ 16). + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + -- Need (acc.1, acc.2) = acc as a pair; for a Prod, this is definitional. + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_l3 h_body + show Layer1.step_post re k (.done acc) + unfold Layer1.step_post + show (Layer1.inv re 16#usize acc).holds + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_⟩ + · rw [hk_eq] at h_zeta_acc; rw [show (16#usize : Std.Usize).val = 16 from rfl] + exact h_zeta_acc + · intro j hj ℓ hℓ; rw [h16] at hj + apply h_acc_done j _ ℓ hℓ; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt; rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + +set_option maxHeartbeats 16000000 in +@[spec] +theorem ntt_at_layer_1_spec + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (initial_coefficient_bound : Std.Usize) + (h_zeta : zeta_i.val = 63) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 7 * 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_1 + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + zeta_i re initial_coefficient_bound + ⦃ ⇓ p => ⌜ p.1.val = 127 + ∧ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 8 * 3328 ⌝ ⦄ := by + -- Reduce the top wrapper to the inner loop. + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1 + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop + apply Std.Do.Triple.of_entails_right _ + (loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + iter1 acc1.1 acc1.2) + (β := Layer1.Acc) + (zeta_i, re) + 0#usize 16#usize + (Layer1.inv re) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (pure_prop_holds_l3 ⟨by rw [h_zeta]; rfl, + fun j hj _ _ => absurd hj (Nat.not_lt_zero j), + fun _ _ _ => rfl⟩) + ?_) + · -- Post entailment. + rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_zeta_eq, h_done, _h_undone⟩ := of_pure_prop_holds_l3 h + refine ⟨?_, ?_⟩ + · have h16 : (16#usize : Std.Usize).val = 16 := rfl + rw [h16] at h_zeta_eq; omega + · intro i hi j hj + apply h_done i (by rw [show (16#usize : Std.Usize).val = 16 from rfl]; exact hi) j hj + · -- Step lemma application. + intro acc k h_ge h_le hinv + obtain ⟨h_zeta_acc, h_acc_done, h_acc_undone⟩ := of_pure_prop_holds_l3 hinv + have h_step := ntt_at_layer_1_step_lemma re h_pre acc k h_le h_zeta_acc + h_acc_done h_acc_undone + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer1.step_post re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer1.step_post] using hP + · have hP : Layer1.step_post re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer1.step_post] using hP + +/-! ## L3.1.B — `ntt_at_layer_1_spec_B` + +Nat-`bnd`-parameterised mirror of `ntt_at_layer_1_spec` (L3.1). Same +driver loop (16 iterations) and same ζ-thread (`63 → 127`); the +per-coefficient input bound `7 * 3328` is replaced by a `Nat` +parameter `bnd` and the output bound becomes `bnd + 3328`. The +precondition `bnd ≤ 8 * 3328` matches the upstream +`ntt_layer_1_step_spec_bnd` requirement. + +The existing `ntt_at_layer_1_spec` is the `bnd = 7 * 3328` +instantiation and is left untouched. -/ + +namespace Layer1Bounded + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Result ControlFlow + +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- Loop invariant (Nat-bnd parameterised). Done lanes ≤ `bnd + 3328`; + undone lanes still equal `re.coefficients[j]` (per `h_pre`, these + are ≤ `bnd`). -/ +def inv_B + (bnd : Nat) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + acc.1.val = 63 + 4 * k.val + ∧ (∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ bnd + 3328) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!)) + +def step_post_B + (bnd : Nat) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv_B bnd re iter'.start acc').holds + | .done y => (inv_B bnd re 16#usize y).holds + +end Layer1Bounded + +/-- Per-iteration step lemma (Nat-bnd parameterised). Mirrors + `ntt_at_layer_1_step_lemma` but threads `bnd` and dispatches via + `ntt_layer_1_step_spec_bnd`. -/ +private theorem ntt_at_layer_1_step_lemma_B + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (bnd : Nat) (h_bnd : bnd ≤ 29439) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd) + (acc : Layer1Bounded.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_zeta_acc : acc.1.val = 63 + 4 * k.val) + (h_acc_done : ∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ bnd + 3328) + (h_acc_undone : ∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer1Bounded.step_post_B bnd re k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.2.coefficients.length = 16 := + Std.Array.length_eq _ + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- Some round = k branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := iter_next_some_eq k h_lt + -- 1) `zeta_i + 1`. + have h_acc1_lt : acc.1.val ≤ 123 := by rw [h_zeta_acc]; omega + have h_um : (1#usize : Std.Usize).val = 1 := rfl + have h_um2 : (2#usize : Std.Usize).val = 2 := rfl + have h_um3 : (3#usize : Std.Usize).val = 3 := rfl + have h_z_max : acc.1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := + usize_add_ok_eq acc.1 1#usize h_z_max + -- 2) `Array.index_mut_usize re.coefficients k`. + have h_idx : + Aeneas.Std.Array.index_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!) := + array_index_usize_ok_eq acc.2.coefficients k (by rw [h_coef_len]; exact hk_16) + have h_imt_ok : + Aeneas.Std.Array.index_mut_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!, acc.2.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx] + rfl + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.2.coefficients.val[k.val]! with ht_def + have h_zi1_val_arith : zi1.val = acc.1.val + 1 := by rw [h_zi1_val, h_um] + have h_zi1_lt : zi1.val < 128 := by rw [h_zi1_val_arith]; omega + -- 3) `polynomial.zeta zi1`. + obtain ⟨z1, h_z1_eq, h_z1_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zi1 h_zi1_lt) + -- 4) `zi1 + 1`. + have h_zi3_max : zi1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi3, h_zi3_eq, h_zi3_val⟩ := + usize_add_ok_eq zi1 1#usize h_zi3_max + have h_zi3_val_arith : zi3.val = acc.1.val + 2 := by + rw [h_zi3_val, h_um, h_zi1_val_arith] + have h_zi3_lt : zi3.val < 128 := by rw [h_zi3_val_arith]; omega + -- 5) `polynomial.zeta zi3`. + obtain ⟨z2, h_z2_eq, h_z2_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zi3 h_zi3_lt) + -- 6) `zi1 + 2`. + have h_zi5_max : zi1.val + (2#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um2]; scalar_tac + obtain ⟨zi5, h_zi5_eq, h_zi5_val⟩ := + usize_add_ok_eq zi1 2#usize h_zi5_max + have h_zi5_val_arith : zi5.val = acc.1.val + 3 := by + rw [h_zi5_val, h_um2, h_zi1_val_arith] + have h_zi5_lt : zi5.val < 128 := by rw [h_zi5_val_arith]; omega + -- 7) `polynomial.zeta zi5`. + obtain ⟨z3, h_z3_eq, h_z3_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zi5 h_zi5_lt) + -- 8) `zi1 + 3`. + have h_zi7_max : zi1.val + (3#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um3]; scalar_tac + obtain ⟨zi7, h_zi7_eq, h_zi7_val⟩ := + usize_add_ok_eq zi1 3#usize h_zi7_max + have h_zi7_val_arith : zi7.val = acc.1.val + 4 := by + rw [h_zi7_val, h_um3, h_zi1_val_arith] + have h_zi7_lt : zi7.val < 128 := by rw [h_zi7_val_arith]; omega + -- 9) `polynomial.zeta zi7`. + obtain ⟨z4, h_z4_eq, h_z4_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zi7 h_zi7_lt) + -- 10) `OpsInst.ntt_layer_1_step t z1 z2 z3 z4` — `_bnd` form. + have h_t_eq : t = re.coefficients.val[k.val]! := by + show acc.2.coefficients.val[k.val]! = re.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ bnd := by + intro ℓ hℓ + rw [h_t_eq]; exact h_pre k.val hk_16 ℓ hℓ + obtain ⟨t1, h_t1_eq, h_t1_bd⟩ := + triple_exists_ok_l3 (ntt_layer_1_step_spec_bnd t z1 z2 z3 z4 bnd h_bnd + h_z1_bd h_z2_bd h_z3_bd h_z4_bd h_t_bd) + -- Set the next-state values. + set a : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.2.coefficients.set k t1 with ha_def + set acc' : Layer1Bounded.Acc := (zi7, { coefficients := a }) with hacc'_def + -- Compose the whole body into one `.ok` equation. + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp [bind_tc_ok, h_zi1_eq, h_imt_ok, h_z1_eq, h_zi3_eq, + h_z2_eq, h_zi5_eq, h_z3_eq, h_zi7_eq, h_z4_eq] + show (do + let t1' ← + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_1_step t z1 z2 z3 z4 + ok (cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + zi7, + ({ coefficients := acc.2.coefficients.set k t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_t1_eq] + rfl + apply triple_of_ok_l3 h_body + show Layer1Bounded.step_post_B bnd re k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) + unfold Layer1Bounded.step_post_B + refine ⟨h_lt, rfl, hs_val, ?_⟩ + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_⟩ + · show zi7.val = 63 + 4 * s.val + rw [h_zi7_val_arith, h_zeta_acc, hs_val]; ring + · intro j hj ℓ hℓ + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne : + (acc.2.coefficients.set k t1)[j]! = (acc.2.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show ((acc.2.coefficients.set k t1).val[j]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_set_ne_val] + exact h_acc_done j hj_lt_k ℓ hℓ + · subst hj_eq_k + have h_lt' : k.val < acc.2.coefficients.length := by + rw [h_coef_len]; exact hk_16 + have h_set_eq : + (acc.2.coefficients.set k t1)[k.val]! = t1 := + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, h_lt'⟩ + have h_set_eq_val : + (acc.2.coefficients.set k t1).val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_eq + show ((acc.2.coefficients.set k t1).val[k.val]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_set_eq_val] + exact h_t1_bd ℓ hℓ + · intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne : + (acc.2.coefficients.set k t1)[j]! = (acc.2.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show (acc.2.coefficients.set k t1).val[j]! = re.coefficients.val[j]! + rw [h_set_ne_val] + exact h_acc_undone j h_ge' hj_lt + · -- None branch (k ≥ 16). + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_l3 h_body + show Layer1Bounded.step_post_B bnd re k (.done acc) + unfold Layer1Bounded.step_post_B + show (Layer1Bounded.inv_B bnd re 16#usize acc).holds + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_⟩ + · rw [hk_eq] at h_zeta_acc; rw [show (16#usize : Std.Usize).val = 16 from rfl] + exact h_zeta_acc + · intro j hj ℓ hℓ; rw [h16] at hj + apply h_acc_done j _ ℓ hℓ; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt; rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + +set_option maxHeartbeats 16000000 in +@[spec] +theorem ntt_at_layer_1_spec_B + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (initial_coefficient_bound : Std.Usize) + (bnd : Nat) (h_bnd : bnd ≤ 29439) + (h_zeta : zeta_i.val = 63) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_1 + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + zeta_i re initial_coefficient_bound + ⦃ ⇓ p => ⌜ p.1.val = 127 + ∧ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd + 3328 ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1 + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop + apply Std.Do.Triple.of_entails_right _ + (loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_1_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + iter1 acc1.1 acc1.2) + (β := Layer1Bounded.Acc) + (zeta_i, re) + 0#usize 16#usize + (Layer1Bounded.inv_B bnd re) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (pure_prop_holds_l3 ⟨by rw [h_zeta]; rfl, + fun j hj _ _ => absurd hj (Nat.not_lt_zero j), + fun _ _ _ => rfl⟩) + ?_) + · rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_zeta_eq, h_done, _h_undone⟩ := of_pure_prop_holds_l3 h + refine ⟨?_, ?_⟩ + · have h16 : (16#usize : Std.Usize).val = 16 := rfl + rw [h16] at h_zeta_eq; omega + · intro i hi j hj + apply h_done i (by rw [show (16#usize : Std.Usize).val = 16 from rfl]; exact hi) j hj + · intro acc k h_ge h_le hinv + obtain ⟨h_zeta_acc, h_acc_done, h_acc_undone⟩ := of_pure_prop_holds_l3 hinv + have h_step := ntt_at_layer_1_step_lemma_B re bnd h_bnd h_pre acc k h_le h_zeta_acc + h_acc_done h_acc_undone + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer1Bounded.step_post_B bnd re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer1Bounded.step_post_B] using hP + · have hP : Layer1Bounded.step_post_B bnd re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer1Bounded.step_post_B] using hP + +/-! ## L3.2 — `ntt_at_layer_2_spec` + +Driver loop: 16 iterations over `re.coefficients`. Each iteration reads +`re.coefficients[k]` (a `PortableVector`), looks up two ζ-values via +`polynomial.zeta` (indices `zeta_i.val + 1` and `zeta_i.val + 2`), +dispatches `OpsInst.ntt_layer_2_step`, and writes back. `zeta_i.val` +advances by 2 per iter (state stores `zeta_i1 + 1 = zeta_i + 2`); the +bound per coefficient goes `6·3328 → 7·3328`. -/ + +namespace Layer2 + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Result ControlFlow + +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +def inv + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + acc.1.val = 31 + 2 * k.val + ∧ (∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 7 * 3328) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!)) + +def step_post + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv re iter'.start acc').holds + | .done y => (inv re 16#usize y).holds + +end Layer2 + +private theorem ntt_at_layer_2_step_lemma + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 6 * 3328) + (acc : Layer2.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_zeta_acc : acc.1.val = 31 + 2 * k.val) + (h_acc_done : ∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 7 * 3328) + (h_acc_undone : ∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer2.step_post re k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.2.coefficients.length = 16 := + Std.Array.length_eq _ + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- Some round = k branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := iter_next_some_eq k h_lt + -- 1) `zeta_i + 1`. Bound chain: acc.1.val = 31 + 2*k.val with + -- k.val < 16, so acc.1.val ≤ 61 and acc.1.val + 2 ≤ 63. + have h_acc1_lt : acc.1.val ≤ 61 := by rw [h_zeta_acc]; omega + have h_um : (1#usize : Std.Usize).val = 1 := rfl + have h_um2 : (2#usize : Std.Usize).val = 2 := rfl + have h_z_max : acc.1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := + usize_add_ok_eq acc.1 1#usize h_z_max + -- 2) `Array.index_mut_usize re.coefficients k`. + have h_idx : + Aeneas.Std.Array.index_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!) := + array_index_usize_ok_eq acc.2.coefficients k (by rw [h_coef_len]; exact hk_16) + have h_imt_ok : + Aeneas.Std.Array.index_mut_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!, acc.2.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx] + rfl + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.2.coefficients.val[k.val]! with ht_def + -- zi1.val = acc.1.val + 1 = 32 + 2*k.val ≤ 62 < 128. + have h_zi1_val_arith : zi1.val = acc.1.val + 1 := by rw [h_zi1_val, h_um] + have h_zi1_lt : zi1.val < 128 := by rw [h_zi1_val_arith]; omega + -- 3) `polynomial.zeta zi1`. + obtain ⟨z1, h_z1_eq, h_z1_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zi1 h_zi1_lt) + -- 4) `zi1 + 1`. + have h_zi3_max : zi1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi3, h_zi3_eq, h_zi3_val⟩ := + usize_add_ok_eq zi1 1#usize h_zi3_max + have h_zi3_val_arith : zi3.val = acc.1.val + 2 := by + rw [h_zi3_val, h_um, h_zi1_val_arith] + have h_zi3_lt : zi3.val < 128 := by rw [h_zi3_val_arith]; omega + -- 5) `polynomial.zeta zi3`. + obtain ⟨z2, h_z2_eq, h_z2_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zi3 h_zi3_lt) + -- 6) `OpsInst.ntt_layer_2_step t z1 z2`. L2.3 fires after instance + -- reduces. Pre: t's elements ≤ 6·3328. + have h_t_eq : t = re.coefficients.val[k.val]! := by + show acc.2.coefficients.val[k.val]! = re.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 6 * 3328 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_pre k.val hk_16 ℓ hℓ + obtain ⟨t1, h_t1_eq, h_t1_bd⟩ := + triple_exists_ok_l3 (ntt_layer_2_step_spec t z1 z2 h_z1_bd h_z2_bd h_t_bd) + -- Next-state values: state stores `zi3` (= zeta_i + 2). + set a : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.2.coefficients.set k t1 with ha_def + set acc' : Layer2.Acc := (zi3, { coefficients := a }) with hacc'_def + -- Compose the whole body into one `.ok` equation. + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp [bind_tc_ok, h_zi1_eq, h_imt_ok, h_z1_eq, h_zi3_eq, h_z2_eq] + show (do + let t1' ← + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_2_step t z1 z2 + ok (cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + zi3, + ({ coefficients := acc.2.coefficients.set k t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_t1_eq] + rfl + apply triple_of_ok_l3 h_body + show Layer2.step_post re k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) + unfold Layer2.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_⟩ + · -- acc'.1.val = zi3.val = 31 + 2 * s.val. + show zi3.val = 31 + 2 * s.val + rw [h_zi3_val_arith, h_zeta_acc, hs_val]; ring + · intro j hj ℓ hℓ + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne : + (acc.2.coefficients.set k t1)[j]! = (acc.2.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show ((acc.2.coefficients.set k t1).val[j]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_set_ne_val] + exact h_acc_done j hj_lt_k ℓ hℓ + · subst hj_eq_k + have h_lt' : k.val < acc.2.coefficients.length := by + rw [h_coef_len]; exact hk_16 + have h_set_eq : + (acc.2.coefficients.set k t1)[k.val]! = t1 := + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, h_lt'⟩ + have h_set_eq_val : + (acc.2.coefficients.set k t1).val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_eq + show ((acc.2.coefficients.set k t1).val[k.val]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_set_eq_val] + exact h_t1_bd ℓ hℓ + · intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne : + (acc.2.coefficients.set k t1)[j]! = (acc.2.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show (acc.2.coefficients.set k t1).val[j]! = re.coefficients.val[j]! + rw [h_set_ne_val] + exact h_acc_undone j h_ge' hj_lt + · -- None branch (k ≥ 16). + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_l3 h_body + show Layer2.step_post re k (.done acc) + unfold Layer2.step_post + show (Layer2.inv re 16#usize acc).holds + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_⟩ + · rw [hk_eq] at h_zeta_acc; rw [show (16#usize : Std.Usize).val = 16 from rfl] + exact h_zeta_acc + · intro j hj ℓ hℓ; rw [h16] at hj + apply h_acc_done j _ ℓ hℓ; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt; rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + +set_option maxHeartbeats 16000000 in +@[spec] +theorem ntt_at_layer_2_spec + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (initial_coefficient_bound : Std.Usize) + (h_zeta : zeta_i.val = 31) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 6 * 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_2 + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + zeta_i re initial_coefficient_bound + ⦃ ⇓ p => ⌜ p.1.val = 63 + ∧ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 7 * 3328 ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2 + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop + apply Std.Do.Triple.of_entails_right _ + (loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + iter1 acc1.1 acc1.2) + (β := Layer2.Acc) + (zeta_i, re) + 0#usize 16#usize + (Layer2.inv re) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (pure_prop_holds_l3 ⟨by rw [h_zeta]; rfl, + fun j hj _ _ => absurd hj (Nat.not_lt_zero j), + fun _ _ _ => rfl⟩) + ?_) + · rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_zeta_eq, h_done, _h_undone⟩ := of_pure_prop_holds_l3 h + refine ⟨?_, ?_⟩ + · have h16 : (16#usize : Std.Usize).val = 16 := rfl + rw [h16] at h_zeta_eq; omega + · intro i hi j hj + apply h_done i (by rw [show (16#usize : Std.Usize).val = 16 from rfl]; exact hi) j hj + · intro acc k h_ge h_le hinv + obtain ⟨h_zeta_acc, h_acc_done, h_acc_undone⟩ := of_pure_prop_holds_l3 hinv + have h_step := ntt_at_layer_2_step_lemma re h_pre acc k h_le h_zeta_acc + h_acc_done h_acc_undone + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer2.step_post re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer2.step_post] using hP + · have hP : Layer2.step_post re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer2.step_post] using hP + +/-! ## L3.2.B — `ntt_at_layer_2_spec_B` + +Nat-`bnd`-parameterised mirror of `ntt_at_layer_2_spec` (L3.2). Same +driver loop (16 iterations) and same ζ-thread (`31 → 63`); per-iter +two ζ lookups, dispatches `ntt_layer_2_step_spec_bnd`. Input bound +`6 * 3328` → `bnd`; output bound `7 * 3328` → `bnd + 3328`. -/ + +namespace Layer2Bounded + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Result ControlFlow + +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +def inv_B + (bnd : Nat) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + acc.1.val = 31 + 2 * k.val + ∧ (∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ bnd + 3328) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!)) + +def step_post_B + (bnd : Nat) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv_B bnd re iter'.start acc').holds + | .done y => (inv_B bnd re 16#usize y).holds + +end Layer2Bounded + +private theorem ntt_at_layer_2_step_lemma_B + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (bnd : Nat) (h_bnd : bnd ≤ 29439) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd) + (acc : Layer2Bounded.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_zeta_acc : acc.1.val = 31 + 2 * k.val) + (h_acc_done : ∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ bnd + 3328) + (h_acc_undone : ∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer2Bounded.step_post_B bnd re k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.2.coefficients.length = 16 := + Std.Array.length_eq _ + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- Some round = k branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := iter_next_some_eq k h_lt + have h_acc1_lt : acc.1.val ≤ 61 := by rw [h_zeta_acc]; omega + have h_um : (1#usize : Std.Usize).val = 1 := rfl + have h_um2 : (2#usize : Std.Usize).val = 2 := rfl + have h_z_max : acc.1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := + usize_add_ok_eq acc.1 1#usize h_z_max + have h_idx : + Aeneas.Std.Array.index_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!) := + array_index_usize_ok_eq acc.2.coefficients k (by rw [h_coef_len]; exact hk_16) + have h_imt_ok : + Aeneas.Std.Array.index_mut_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!, acc.2.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx] + rfl + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.2.coefficients.val[k.val]! with ht_def + have h_zi1_val_arith : zi1.val = acc.1.val + 1 := by rw [h_zi1_val, h_um] + have h_zi1_lt : zi1.val < 128 := by rw [h_zi1_val_arith]; omega + obtain ⟨z1, h_z1_eq, h_z1_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zi1 h_zi1_lt) + have h_zi3_max : zi1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi3, h_zi3_eq, h_zi3_val⟩ := + usize_add_ok_eq zi1 1#usize h_zi3_max + have h_zi3_val_arith : zi3.val = acc.1.val + 2 := by + rw [h_zi3_val, h_um, h_zi1_val_arith] + have h_zi3_lt : zi3.val < 128 := by rw [h_zi3_val_arith]; omega + obtain ⟨z2, h_z2_eq, h_z2_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zi3 h_zi3_lt) + -- `OpsInst.ntt_layer_2_step t z1 z2` — `_bnd` form. + have h_t_eq : t = re.coefficients.val[k.val]! := by + show acc.2.coefficients.val[k.val]! = re.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ bnd := by + intro ℓ hℓ + rw [h_t_eq]; exact h_pre k.val hk_16 ℓ hℓ + obtain ⟨t1, h_t1_eq, h_t1_bd⟩ := + triple_exists_ok_l3 (ntt_layer_2_step_spec_bnd t z1 z2 bnd h_bnd + h_z1_bd h_z2_bd h_t_bd) + set a : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.2.coefficients.set k t1 with ha_def + set acc' : Layer2Bounded.Acc := (zi3, { coefficients := a }) with hacc'_def + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp [bind_tc_ok, h_zi1_eq, h_imt_ok, h_z1_eq, h_zi3_eq, h_z2_eq] + show (do + let t1' ← + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_2_step t z1 z2 + ok (cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + zi3, + ({ coefficients := acc.2.coefficients.set k t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_t1_eq] + rfl + apply triple_of_ok_l3 h_body + show Layer2Bounded.step_post_B bnd re k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) + unfold Layer2Bounded.step_post_B + refine ⟨h_lt, rfl, hs_val, ?_⟩ + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_⟩ + · show zi3.val = 31 + 2 * s.val + rw [h_zi3_val_arith, h_zeta_acc, hs_val]; ring + · intro j hj ℓ hℓ + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne : + (acc.2.coefficients.set k t1)[j]! = (acc.2.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show ((acc.2.coefficients.set k t1).val[j]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_set_ne_val] + exact h_acc_done j hj_lt_k ℓ hℓ + · subst hj_eq_k + have h_lt' : k.val < acc.2.coefficients.length := by + rw [h_coef_len]; exact hk_16 + have h_set_eq : + (acc.2.coefficients.set k t1)[k.val]! = t1 := + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, h_lt'⟩ + have h_set_eq_val : + (acc.2.coefficients.set k t1).val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_eq + show ((acc.2.coefficients.set k t1).val[k.val]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_set_eq_val] + exact h_t1_bd ℓ hℓ + · intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne : + (acc.2.coefficients.set k t1)[j]! = (acc.2.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show (acc.2.coefficients.set k t1).val[j]! = re.coefficients.val[j]! + rw [h_set_ne_val] + exact h_acc_undone j h_ge' hj_lt + · -- None branch (k ≥ 16). + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_l3 h_body + show Layer2Bounded.step_post_B bnd re k (.done acc) + unfold Layer2Bounded.step_post_B + show (Layer2Bounded.inv_B bnd re 16#usize acc).holds + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_⟩ + · rw [hk_eq] at h_zeta_acc; rw [show (16#usize : Std.Usize).val = 16 from rfl] + exact h_zeta_acc + · intro j hj ℓ hℓ; rw [h16] at hj + apply h_acc_done j _ ℓ hℓ; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt; rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + +set_option maxHeartbeats 16000000 in +@[spec] +theorem ntt_at_layer_2_spec_B + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (initial_coefficient_bound : Std.Usize) + (bnd : Nat) (h_bnd : bnd ≤ 29439) + (h_zeta : zeta_i.val = 31) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_2 + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + zeta_i re initial_coefficient_bound + ⦃ ⇓ p => ⌜ p.1.val = 63 + ∧ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd + 3328 ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2 + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop + apply Std.Do.Triple.of_entails_right _ + (loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_2_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + iter1 acc1.1 acc1.2) + (β := Layer2Bounded.Acc) + (zeta_i, re) + 0#usize 16#usize + (Layer2Bounded.inv_B bnd re) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (pure_prop_holds_l3 ⟨by rw [h_zeta]; rfl, + fun j hj _ _ => absurd hj (Nat.not_lt_zero j), + fun _ _ _ => rfl⟩) + ?_) + · rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_zeta_eq, h_done, _h_undone⟩ := of_pure_prop_holds_l3 h + refine ⟨?_, ?_⟩ + · have h16 : (16#usize : Std.Usize).val = 16 := rfl + rw [h16] at h_zeta_eq; omega + · intro i hi j hj + apply h_done i (by rw [show (16#usize : Std.Usize).val = 16 from rfl]; exact hi) j hj + · intro acc k h_ge h_le hinv + obtain ⟨h_zeta_acc, h_acc_done, h_acc_undone⟩ := of_pure_prop_holds_l3 hinv + have h_step := ntt_at_layer_2_step_lemma_B re bnd h_bnd h_pre acc k h_le h_zeta_acc + h_acc_done h_acc_undone + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer2Bounded.step_post_B bnd re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer2Bounded.step_post_B] using hP + · have hP : Layer2Bounded.step_post_B bnd re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer2Bounded.step_post_B] using hP + +/-! ## L3.3 — `ntt_at_layer_3_spec` + +Driver loop: 16 iterations over `re.coefficients`. Each iteration reads +`re.coefficients[k]` (a `PortableVector`), looks up one ζ-value via +`polynomial.zeta` (index `zeta_i.val + 1`), dispatches +`OpsInst.ntt_layer_3_step`, and writes back. `zeta_i.val` advances by 1 +per iter (state stores `zeta_i1 = zeta_i + 1`); the bound per +coefficient goes `5·3328 → 6·3328`. -/ + +namespace Layer3 + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Result ControlFlow + +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +def inv + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + acc.1.val = 15 + k.val + ∧ (∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 6 * 3328) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!)) + +def step_post + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv re iter'.start acc').holds + | .done y => (inv re 16#usize y).holds + +end Layer3 + +private theorem ntt_at_layer_3_step_lemma + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 5 * 3328) + (acc : Layer3.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_zeta_acc : acc.1.val = 15 + k.val) + (h_acc_done : ∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 6 * 3328) + (h_acc_undone : ∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer3.step_post re k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.2.coefficients.length = 16 := + Std.Array.length_eq _ + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- Some round = k branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := iter_next_some_eq k h_lt + -- 1) `zeta_i + 1`. Bound chain: acc.1.val = 15 + k.val with + -- k.val < 16, so acc.1.val ≤ 30 and acc.1.val + 1 ≤ 31. + have h_acc1_lt : acc.1.val ≤ 30 := by rw [h_zeta_acc]; omega + have h_um : (1#usize : Std.Usize).val = 1 := rfl + have h_z_max : acc.1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := + usize_add_ok_eq acc.1 1#usize h_z_max + -- 2) `Array.index_mut_usize re.coefficients k`. + have h_idx : + Aeneas.Std.Array.index_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!) := + array_index_usize_ok_eq acc.2.coefficients k (by rw [h_coef_len]; exact hk_16) + have h_imt_ok : + Aeneas.Std.Array.index_mut_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!, acc.2.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx] + rfl + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.2.coefficients.val[k.val]! with ht_def + -- zi1.val = acc.1.val + 1 = 16 + k.val ≤ 31 < 128. + have h_zi1_val_arith : zi1.val = acc.1.val + 1 := by rw [h_zi1_val, h_um] + have h_zi1_lt : zi1.val < 128 := by rw [h_zi1_val_arith]; omega + -- 3) `polynomial.zeta zi1`. + obtain ⟨z1, h_z1_eq, h_z1_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zi1 h_zi1_lt) + -- 4) `OpsInst.ntt_layer_3_step t z1`. L2.4 fires after instance + -- reduces. Pre: t's elements ≤ 5·3328. + have h_t_eq : t = re.coefficients.val[k.val]! := by + show acc.2.coefficients.val[k.val]! = re.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 5 * 3328 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_pre k.val hk_16 ℓ hℓ + obtain ⟨t1, h_t1_eq, h_t1_bd⟩ := + triple_exists_ok_l3 (ntt_layer_3_step_spec t z1 h_z1_bd h_t_bd) + -- Next-state values: state stores `zi1` (= zeta_i + 1). + set a : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.2.coefficients.set k t1 with ha_def + set acc' : Layer3.Acc := (zi1, { coefficients := a }) with hacc'_def + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp [bind_tc_ok, h_zi1_eq, h_imt_ok, h_z1_eq] + show (do + let t1' ← + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_3_step t z1 + ok (cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + zi1, + ({ coefficients := acc.2.coefficients.set k t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_t1_eq] + rfl + apply triple_of_ok_l3 h_body + show Layer3.step_post re k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) + unfold Layer3.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_⟩ + · -- acc'.1.val = zi1.val = 15 + s.val. + show zi1.val = 15 + s.val + rw [h_zi1_val_arith, h_zeta_acc, hs_val]; ring + · intro j hj ℓ hℓ + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne : + (acc.2.coefficients.set k t1)[j]! = (acc.2.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show ((acc.2.coefficients.set k t1).val[j]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_set_ne_val] + exact h_acc_done j hj_lt_k ℓ hℓ + · subst hj_eq_k + have h_lt' : k.val < acc.2.coefficients.length := by + rw [h_coef_len]; exact hk_16 + have h_set_eq : + (acc.2.coefficients.set k t1)[k.val]! = t1 := + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, h_lt'⟩ + have h_set_eq_val : + (acc.2.coefficients.set k t1).val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_eq + show ((acc.2.coefficients.set k t1).val[k.val]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_set_eq_val] + exact h_t1_bd ℓ hℓ + · intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne : + (acc.2.coefficients.set k t1)[j]! = (acc.2.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show (acc.2.coefficients.set k t1).val[j]! = re.coefficients.val[j]! + rw [h_set_ne_val] + exact h_acc_undone j h_ge' hj_lt + · -- None branch (k ≥ 16). + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_l3 h_body + show Layer3.step_post re k (.done acc) + unfold Layer3.step_post + show (Layer3.inv re 16#usize acc).holds + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_⟩ + · rw [hk_eq] at h_zeta_acc; rw [show (16#usize : Std.Usize).val = 16 from rfl] + exact h_zeta_acc + · intro j hj ℓ hℓ; rw [h16] at hj + apply h_acc_done j _ ℓ hℓ; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt; rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + +set_option maxHeartbeats 16000000 in +@[spec] +theorem ntt_at_layer_3_spec + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (initial_coefficient_bound : Std.Usize) + (h_zeta : zeta_i.val = 15) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 5 * 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_3 + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + zeta_i re initial_coefficient_bound + ⦃ ⇓ p => ⌜ p.1.val = 31 + ∧ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 6 * 3328 ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3 + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop + apply Std.Do.Triple.of_entails_right _ + (loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + iter1 acc1.1 acc1.2) + (β := Layer3.Acc) + (zeta_i, re) + 0#usize 16#usize + (Layer3.inv re) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (pure_prop_holds_l3 ⟨by rw [h_zeta]; rfl, + fun j hj _ _ => absurd hj (Nat.not_lt_zero j), + fun _ _ _ => rfl⟩) + ?_) + · rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_zeta_eq, h_done, _h_undone⟩ := of_pure_prop_holds_l3 h + refine ⟨?_, ?_⟩ + · have h16 : (16#usize : Std.Usize).val = 16 := rfl + rw [h16] at h_zeta_eq; omega + · intro i hi j hj + apply h_done i (by rw [show (16#usize : Std.Usize).val = 16 from rfl]; exact hi) j hj + · intro acc k h_ge h_le hinv + obtain ⟨h_zeta_acc, h_acc_done, h_acc_undone⟩ := of_pure_prop_holds_l3 hinv + have h_step := ntt_at_layer_3_step_lemma re h_pre acc k h_le h_zeta_acc + h_acc_done h_acc_undone + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer3.step_post re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer3.step_post] using hP + · have hP : Layer3.step_post re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer3.step_post] using hP + +/-! ## L3.3.B — `ntt_at_layer_3_spec_B` + +Nat-`bnd`-parameterised mirror of `ntt_at_layer_3_spec` (L3.3). Same +driver loop (16 iterations) and same ζ-thread (`15 → 31`); per-iter +single ζ lookup, dispatches `ntt_layer_3_step_spec_bnd`. Input bound +`5 * 3328` → `bnd`; output bound `6 * 3328` → `bnd + 3328`. -/ + +namespace Layer3Bounded + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Result ControlFlow + +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +def inv_B + (bnd : Nat) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + acc.1.val = 15 + k.val + ∧ (∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ bnd + 3328) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!)) + +def step_post_B + (bnd : Nat) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv_B bnd re iter'.start acc').holds + | .done y => (inv_B bnd re 16#usize y).holds + +end Layer3Bounded + +private theorem ntt_at_layer_3_step_lemma_B + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (bnd : Nat) (h_bnd : bnd ≤ 29439) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd) + (acc : Layer3Bounded.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_zeta_acc : acc.1.val = 15 + k.val) + (h_acc_done : ∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ bnd + 3328) + (h_acc_undone : ∀ j : Nat, k.val ≤ j → j < 16 → + acc.2.coefficients.val[j]! = re.coefficients.val[j]!) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer3Bounded.step_post_B bnd re k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.2.coefficients.length = 16 := + Std.Array.length_eq _ + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- Some round = k branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := iter_next_some_eq k h_lt + have h_acc1_lt : acc.1.val ≤ 30 := by rw [h_zeta_acc]; omega + have h_um : (1#usize : Std.Usize).val = 1 := rfl + have h_z_max : acc.1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um]; scalar_tac + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := + usize_add_ok_eq acc.1 1#usize h_z_max + have h_idx : + Aeneas.Std.Array.index_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!) := + array_index_usize_ok_eq acc.2.coefficients k (by rw [h_coef_len]; exact hk_16) + have h_imt_ok : + Aeneas.Std.Array.index_mut_usize acc.2.coefficients k + = .ok (acc.2.coefficients.val[k.val]!, acc.2.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx] + rfl + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.2.coefficients.val[k.val]! with ht_def + have h_zi1_val_arith : zi1.val = acc.1.val + 1 := by rw [h_zi1_val, h_um] + have h_zi1_lt : zi1.val < 128 := by rw [h_zi1_val_arith]; omega + obtain ⟨z1, h_z1_eq, h_z1_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zi1 h_zi1_lt) + -- `OpsInst.ntt_layer_3_step t z1` — `_bnd` form. + have h_t_eq : t = re.coefficients.val[k.val]! := by + show acc.2.coefficients.val[k.val]! = re.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ bnd := by + intro ℓ hℓ + rw [h_t_eq]; exact h_pre k.val hk_16 ℓ hℓ + obtain ⟨t1, h_t1_eq, h_t1_bd⟩ := + triple_exists_ok_l3 (ntt_layer_3_step_spec_bnd t z1 bnd h_bnd h_z1_bd h_t_bd) + set a : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.2.coefficients.set k t1 with ha_def + set acc' : Layer3Bounded.Acc := (zi1, { coefficients := a }) with hacc'_def + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp [bind_tc_ok, h_zi1_eq, h_imt_ok, h_z1_eq] + show (do + let t1' ← + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_3_step t z1 + ok (cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + zi1, + ({ coefficients := acc.2.coefficients.set k t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_t1_eq] + rfl + apply triple_of_ok_l3 h_body + show Layer3Bounded.step_post_B bnd re k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) + unfold Layer3Bounded.step_post_B + refine ⟨h_lt, rfl, hs_val, ?_⟩ + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_⟩ + · show zi1.val = 15 + s.val + rw [h_zi1_val_arith, h_zeta_acc, hs_val]; ring + · intro j hj ℓ hℓ + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne : + (acc.2.coefficients.set k t1)[j]! = (acc.2.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show ((acc.2.coefficients.set k t1).val[j]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_set_ne_val] + exact h_acc_done j hj_lt_k ℓ hℓ + · subst hj_eq_k + have h_lt' : k.val < acc.2.coefficients.length := by + rw [h_coef_len]; exact hk_16 + have h_set_eq : + (acc.2.coefficients.set k t1)[k.val]! = t1 := + Aeneas.Std.Array.getElem!_Nat_set_eq acc.2.coefficients k k.val t1 + ⟨rfl, h_lt'⟩ + have h_set_eq_val : + (acc.2.coefficients.set k t1).val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_eq + show ((acc.2.coefficients.set k t1).val[k.val]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_set_eq_val] + exact h_t1_bd ℓ hℓ + · intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne : + (acc.2.coefficients.set k t1)[j]! = (acc.2.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.2.coefficients k j t1 h_ne + have h_set_ne_val : + (acc.2.coefficients.set k t1).val[j]! = acc.2.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show (acc.2.coefficients.set k t1).val[j]! = re.coefficients.val[j]! + rw [h_set_ne_val] + exact h_acc_undone j h_ge' hj_lt + · -- None branch (k ≥ 16). + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_l3 h_body + show Layer3Bounded.step_post_B bnd re k (.done acc) + unfold Layer3Bounded.step_post_B + show (Layer3Bounded.inv_B bnd re 16#usize acc).holds + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_⟩ + · rw [hk_eq] at h_zeta_acc; rw [show (16#usize : Std.Usize).val = 16 from rfl] + exact h_zeta_acc + · intro j hj ℓ hℓ; rw [h16] at hj + apply h_acc_done j _ ℓ hℓ; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt; rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + +set_option maxHeartbeats 16000000 in +@[spec] +theorem ntt_at_layer_3_spec_B + (zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (initial_coefficient_bound : Std.Usize) + (bnd : Nat) (h_bnd : bnd ≤ 29439) + (h_zeta : zeta_i.val = 15) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_3 + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + zeta_i re initial_coefficient_bound + ⦃ ⇓ p => ⌜ p.1.val = 31 + ∧ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd + 3328 ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3 + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop + apply Std.Do.Triple.of_entails_right _ + (loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_3_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + iter1 acc1.1 acc1.2) + (β := Layer3Bounded.Acc) + (zeta_i, re) + 0#usize 16#usize + (Layer3Bounded.inv_B bnd re) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (pure_prop_holds_l3 ⟨by rw [h_zeta]; rfl, + fun j hj _ _ => absurd hj (Nat.not_lt_zero j), + fun _ _ _ => rfl⟩) + ?_) + · rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_zeta_eq, h_done, _h_undone⟩ := of_pure_prop_holds_l3 h + refine ⟨?_, ?_⟩ + · have h16 : (16#usize : Std.Usize).val = 16 := rfl + rw [h16] at h_zeta_eq; omega + · intro i hi j hj + apply h_done i (by rw [show (16#usize : Std.Usize).val = 16 from rfl]; exact hi) j hj + · intro acc k h_ge h_le hinv + obtain ⟨h_zeta_acc, h_acc_done, h_acc_undone⟩ := of_pure_prop_holds_l3 hinv + have h_step := ntt_at_layer_3_step_lemma_B re bnd h_bnd h_pre acc k h_le h_zeta_acc + h_acc_done h_acc_undone + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer3Bounded.step_post_B bnd re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer3Bounded.step_post_B] using hP + · have hP : Layer3Bounded.step_post_B bnd re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer3Bounded.step_post_B] using hP + +/-! ## L3.5 — `ntt_at_layer_7_spec` + +Outermost layer of the forward NTT. No `zeta_i` (single fixed Montgomery +multiplier `-1600`). Driver loop runs 8 iterations (`step = 16/2 = 8`) +over the first half of `re.coefficients`; per iter touches lanes `j` +and `j+8`. Per-coefficient bound goes `3 → 4803` (= `3 + 1600·3`). + +Per-iter body (j = k.val, i = j+8, all reads/writes from re/acc): + scratch1 := re[i] + scratch2 := -1600 * scratch1 (L1.7) + re[i] := re[j] (lane swap) + t2 := re[j] + scratch2 (L1.1) + re[j] := t2 + t4 := - scratch2 (L1.2) + re[i] := t4 +So new re[j] = old re[j] + (-1600) * old re[i]; new re[i] = old re[j] - +(-1600) * old re[i] = old re[j] + 1600 * old re[i]. Both bounded by +3 + 4800 = 4803 in absolute value under |old re[*][ℓ]| ≤ 3. -/ + + +/-! ### Local helpers: `Usize.div` reduction + generic-`«end»` iter-next. -/ + +private theorem usize_div_ok_eq (x y : Std.Usize) (hy : y.val ≠ 0) : + ∃ z : Std.Usize, (x / y : Result Std.Usize) = .ok z ∧ z.val = x.val / y.val := by + have hT := Std.UScalar.div_spec x hy + obtain ⟨z, h_eq, h_v⟩ := hT + exact ⟨z, h_eq, h_v⟩ + +/-- `i.val < e.val`: `IteratorRange.next` returns `.ok (some i, iter')` with + `iter'.end = e` and `iter'.start.val = i.val + 1`. Generic version of + `iter_next_some_eq` (which is specialised to `«end» := 16#usize`). -/ +private theorem iter_next_some_eq_gen + (i e : Std.Usize) (h_lt : i.val < e.val) : + ∃ s : Std.Usize, s.val = i.val + 1 ∧ + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := i, «end» := e } : CoreModels.core.ops.range.Range Std.Usize) + = .ok (some i, + ({ start := s, «end» := e } : CoreModels.core.ops.range.Range Std.Usize)) := by + have hT := IteratorRange_next_spec_usize i e + (Q := PostCond.noThrow fun (oi : Option Std.Usize × _) => ⌜ + ∃ s : Std.Usize, s.val = i.val + 1 + ∧ oi = (some i, + ({ start := s, «end» := e } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝) + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v, hveq, hP⟩ := triple_exists_ok_l3 hT + obtain ⟨s, hs_val, hpair⟩ := hP + exact ⟨s, hs_val, by rw [hveq, hpair]⟩ + +/-- `i.val ≥ e.val`: `IteratorRange.next` returns `.ok (none, _)`. Generic + version of `iter_next_none_eq`. -/ +private theorem iter_next_none_eq_gen + (i e : Std.Usize) (h_ge : i.val ≥ e.val) : + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := i, «end» := e } : CoreModels.core.ops.range.Range Std.Usize) + = .ok ((none : Option Std.Usize), + ({ start := i, «end» := e } + : CoreModels.core.ops.range.Range Std.Usize)) := by + have hT := IteratorRange_next_spec_usize i e + (Q := PostCond.noThrow fun (oi : Option Std.Usize × _) => ⌜ + oi = ((none : Option Std.Usize), + ({ start := i, «end» := e } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝) + (fun hlt => absurd hlt (Nat.not_lt.mpr h_ge)) + (fun _ => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v, hveq, hP⟩ := triple_exists_ok_l3 hT + rw [hveq, hP] + +namespace Layer7 + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Result ControlFlow + +/-- Step-local accumulator: a `PolynomialRingElement × scratch`. -/ +abbrev Acc := + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector × + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- Loop invariant after `k` iterations (`k.val ∈ [0, 8]`): + - Lanes `j ∈ [0, k.val)` are processed: bounded by 4803. + - Lanes `j ∈ [8, 8 + k.val)` are processed: bounded by 4803. + - Lanes `j ∈ [k.val, 8)` are untouched: bounded by 3 (from `re`). + - Lanes `j ∈ [8 + k.val, 16)` are untouched: bounded by 3 (from `re`). -/ +def inv + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + (∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 4803) + ∧ (∀ j : Nat, 8 ≤ j → j < 8 + k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 4803) + ∧ (∀ j : Nat, k.val ≤ j → j < 8 → + acc.1.coefficients.val[j]! = re.coefficients.val[j]!) + ∧ (∀ j : Nat, 8 + k.val ≤ j → j < 16 → + acc.1.coefficients.val[j]! = re.coefficients.val[j]!)) + +/-- Per-iter step post. -/ +def step_post + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (8#usize : Std.Usize).val ∧ iter'.«end» = 8#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv re iter'.start acc').holds + | .done y => (inv re 8#usize y).holds + +end Layer7 + +set_option maxHeartbeats 16000000 in +private theorem ntt_at_layer_7_step_lemma + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3) + (acc : Layer7.Acc) + (k : Std.Usize) (h_le : k.val ≤ (8#usize : Std.Usize).val) + (h_done_lo : ∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 4803) + (h_done_hi : ∀ j : Nat, 8 ≤ j → j < 8 + k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 4803) + (h_undone_lo : ∀ j : Nat, k.val ≤ j → j < 8 → + acc.1.coefficients.val[j]! = re.coefficients.val[j]!) + (h_undone_hi : ∀ j : Nat, 8 + k.val ≤ j → j < 16 → + acc.1.coefficients.val[j]! = re.coefficients.val[j]!) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + 8#usize { start := k, «end» := 8#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer7.step_post re k r ⌝ ⦄ := by + have h8 : (8#usize : Std.Usize).val = 8 := rfl + have h_coef_len : acc.1.coefficients.length = 16 := + Std.Array.length_eq _ + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop.body + by_cases h_lt : k.val < (8#usize : Std.Usize).val + · -- Some round = k branch. + have hk_8 : k.val < 8 := by rw [h8] at h_lt; exact h_lt + have hk_16 : k.val < 16 := by omega + obtain ⟨s, hs_val, h_iter_some⟩ := iter_next_some_eq_gen k 8#usize h_lt + -- 1) `i = j + 8`: where j = k. + have h_um8 : (8#usize : Std.Usize).val = 8 := rfl + have h_i_max : k.val + (8#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um8]; scalar_tac + obtain ⟨i, h_i_eq, h_i_val⟩ := usize_add_ok_eq k 8#usize h_i_max + have h_i_val_arith : i.val = k.val + 8 := by rw [h_i_val, h_um8] + have h_i_lt_16 : i.val < 16 := by rw [h_i_val_arith]; omega + have h_i_lt_coef : i.val < acc.1.coefficients.length := by rw [h_coef_len]; exact h_i_lt_16 + -- 2) Read scratch1 = acc.1[i]. + have h_idx_i : + Aeneas.Std.Array.index_usize acc.1.coefficients i + = .ok (acc.1.coefficients.val[i.val]!) := + array_index_usize_ok_eq acc.1.coefficients i h_i_lt_coef + have h_acc_i_eq : acc.1.coefficients.val[i.val]! = re.coefficients.val[i.val]! := by + apply h_undone_hi i.val + · rw [h_i_val_arith]; omega + · exact h_i_lt_16 + have h_scratch1_bd : ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[i.val]!).elements.val[ℓ]!).val.natAbs ≤ 3 := by + intro ℓ hℓ + rw [h_acc_i_eq] + exact h_pre i.val h_i_lt_16 ℓ hℓ + -- 3) scratch2 = multiply_by_constant scratch1 (-1600). L1.7. + have h_neg1600_val : ((-1600)#i16 : Std.I16).val = -1600 := by decide + have h_mul_pre : ∀ ℓ : Nat, ℓ < 16 → + (((acc.1.coefficients.val[i.val]!).elements.val[ℓ]!).val + * ((-1600)#i16 : Std.I16).val : Int).natAbs ≤ 2 ^ 15 - 1 := by + intro ℓ hℓ + have h_x_abs : ((acc.1.coefficients.val[i.val]!).elements.val[ℓ]!).val.natAbs ≤ 3 := + h_scratch1_bd ℓ hℓ + rw [h_neg1600_val] + have h_abs_mul : (((acc.1.coefficients.val[i.val]!).elements.val[ℓ]!).val * (-1600) : Int).natAbs + = ((acc.1.coefficients.val[i.val]!).elements.val[ℓ]!).val.natAbs * 1600 := by + rw [Int.natAbs_mul]; rfl + rw [h_abs_mul] + have h_mul : ((acc.1.coefficients.val[i.val]!).elements.val[ℓ]!).val.natAbs * 1600 + ≤ 3 * 1600 := + Nat.mul_le_mul_right 1600 h_x_abs + omega + obtain ⟨scratch2, h_scratch2_eq, h_scratch2_post⟩ := + triple_exists_ok_l3 (multiply_by_constant_spec (acc.1.coefficients.val[i.val]!) + ((-1600)#i16 : Std.I16) h_mul_pre) + have h_scratch2_bd : ∀ ℓ : Nat, ℓ < 16 → + (scratch2.elements.val[ℓ]!).val + = (acc.1.coefficients.val[i.val]!).elements.val[ℓ]!.val * (-1600 : Int) + ∧ (scratch2.elements.val[ℓ]!).val.natAbs ≤ 4800 := by + intro ℓ hℓ + have h_per := h_scratch2_post ℓ hℓ + have h_v_eq : (scratch2.elements.val[ℓ]!).val + = (acc.1.coefficients.val[i.val]!).elements.val[ℓ]!.val * (-1600 : Int) := by + rw [h_per.1, h_neg1600_val] + refine ⟨h_v_eq, ?_⟩ + have h_x_abs : ((acc.1.coefficients.val[i.val]!).elements.val[ℓ]!).val.natAbs ≤ 3 := + h_scratch1_bd ℓ hℓ + have h_abs_eq : (scratch2.elements.val[ℓ]!).val.natAbs + = ((acc.1.coefficients.val[i.val]!).elements.val[ℓ]!).val.natAbs * 1600 := by + rw [h_v_eq, Int.natAbs_mul]; rfl + rw [h_abs_eq] + have h_mul : ((acc.1.coefficients.val[i.val]!).elements.val[ℓ]!).val.natAbs * 1600 + ≤ 3 * 1600 := + Nat.mul_le_mul_right 1600 h_x_abs + omega + -- 4) t = re.coefficients[j] = acc.1.coef[k]! + have h_k_lt_coef : k.val < acc.1.coefficients.length := by rw [h_coef_len]; exact hk_16 + have h_idx_k : + Aeneas.Std.Array.index_usize acc.1.coefficients k + = .ok (acc.1.coefficients.val[k.val]!) := + array_index_usize_ok_eq acc.1.coefficients k h_k_lt_coef + have h_acc_k_eq : acc.1.coefficients.val[k.val]! = re.coefficients.val[k.val]! := + h_undone_lo k.val (Nat.le_refl _) hk_8 + -- Per-element bound at lane k: ≤ 3 from h_pre via h_acc_k_eq. + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[k.val]!).elements.val[ℓ]!).val.natAbs ≤ 3 := by + intro ℓ hℓ + rw [h_acc_k_eq] + exact h_pre k.val hk_16 ℓ hℓ + -- Set t (for readability in downstream bounds). + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.1.coefficients.val[k.val]! with ht_def + -- 5) a = acc.1.coef.set i t. + have h_upd_i : + Aeneas.Std.Array.update acc.1.coefficients i t + = .ok (acc.1.coefficients.set i t) := + array_update_ok_eq acc.1.coefficients i t h_i_lt_coef + set a : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.1.coefficients.set i t with ha_def + -- 6) (t1, back1) = index_mut_usize a j. + have h_a_k_lt : k.val < a.length := by + change k.val < (acc.1.coefficients.set i t).length + rw [Std.Array.set_length]; rw [h_coef_len]; exact hk_16 + have h_a_k_idx : + Aeneas.Std.Array.index_usize a k = .ok (a.val[k.val]!) := + array_index_usize_ok_eq a k h_a_k_lt + have h_imt_a_k : + Aeneas.Std.Array.index_mut_usize a k = .ok (a.val[k.val]!, a.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_a_k_idx]; rfl + have h_k_ne_i : k.val ≠ i.val := by rw [h_i_val_arith]; omega + have h_i_ne_k : i.val ≠ k.val := fun h => h_k_ne_i h.symm + have h_a_k_val_eq : a.val[k.val]! = acc.1.coefficients.val[k.val]! := by + change (acc.1.coefficients.set i t).val[k.val]! = acc.1.coefficients.val[k.val]! + have h_ne : (acc.1.coefficients.set i t)[k.val]! = (acc.1.coefficients)[k.val]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.1.coefficients i k.val t h_i_ne_k + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_ne + -- t1 = a.val[k.val]! (result of index_mut). Bound: t1 = acc.1.coef[k]! = t. + have h_t1_eq_t : a.val[k.val]! = t := by + change a.val[k.val]! = acc.1.coefficients.val[k.val]! + exact h_a_k_val_eq + have h_add_pre : ∀ ℓ : Nat, ℓ < 16 → + (((a.val[k.val]!).elements.val[ℓ]!).val + (scratch2.elements.val[ℓ]!).val : Int).natAbs + ≤ 2 ^ 15 - 1 := by + intro ℓ hℓ + rw [h_t1_eq_t] + have h_t_b := h_t_bd ℓ hℓ + have h_s2_b := (h_scratch2_bd ℓ hℓ).2 + have h_t_int_lb : -(3 : Int) ≤ (t.elements.val[ℓ]!).val := by omega + have h_t_int_ub : (t.elements.val[ℓ]!).val ≤ (3 : Int) := by omega + have h_s2_int_lb : -(4800 : Int) ≤ (scratch2.elements.val[ℓ]!).val := by omega + have h_s2_int_ub : (scratch2.elements.val[ℓ]!).val ≤ (4800 : Int) := by omega + omega + obtain ⟨t2, h_t2_eq, h_t2_post⟩ := + triple_exists_ok_l3 (add_spec (a.val[k.val]!) scratch2 h_add_pre) + have h_t2_bd : ∀ ℓ : Nat, ℓ < 16 → (t2.elements.val[ℓ]!).val.natAbs ≤ 4803 := by + intro ℓ hℓ + have h_per := h_t2_post ℓ hℓ + have h_v := h_per.1 + rw [h_t1_eq_t] at h_v + have h_t_b := h_t_bd ℓ hℓ + have h_s2_b := (h_scratch2_bd ℓ hℓ).2 + have h_t_int_lb : -(3 : Int) ≤ (t.elements.val[ℓ]!).val := by omega + have h_t_int_ub : (t.elements.val[ℓ]!).val ≤ (3 : Int) := by omega + have h_s2_int_lb : -(4800 : Int) ≤ (scratch2.elements.val[ℓ]!).val := by omega + have h_s2_int_ub : (scratch2.elements.val[ℓ]!).val ≤ (4800 : Int) := by omega + omega + -- 8/9) a1 = a.set k t2 (definitional); (t3, back2) = index_mut_usize a1 i. + -- Use `have ... : ... := ` to define a1 as a syntactically-distinct term, + -- but state the index_mut hypothesis directly with `a.set k t2` so simp + -- can match the goal pattern. + have h_a1_i_lt : i.val < (a.set k t2).length := by + rw [Std.Array.set_length] + change i.val < (acc.1.coefficients.set i t).length + rw [Std.Array.set_length, h_coef_len]; exact h_i_lt_16 + have h_a1_i_idx : + Aeneas.Std.Array.index_usize (a.set k t2) i + = .ok ((a.set k t2).val[i.val]!) := + array_index_usize_ok_eq (a.set k t2) i h_a1_i_lt + have h_imt_a1_i : + Aeneas.Std.Array.index_mut_usize (a.set k t2) i + = .ok ((a.set k t2).val[i.val]!, (a.set k t2).set i) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_a1_i_idx]; rfl + have h_a1_i_val_eq : (a.set k t2).val[i.val]! = t := by + have h_set_k_ne : (a.set k t2)[i.val]! = a[i.val]! := + Aeneas.Std.Array.getElem!_Nat_set_ne a k i.val t2 h_k_ne_i + have h_a_i_eq : a[i.val]! = t := by + change (acc.1.coefficients.set i t)[i.val]! = t + exact Aeneas.Std.Array.getElem!_Nat_set_eq acc.1.coefficients i i.val t ⟨rfl, h_i_lt_coef⟩ + have h_chain : (a.set k t2).val[i.val]! = a.val[i.val]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_k_ne + rw [h_chain] + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_a_i_eq + -- 10) t4 = sub t3 scratch2. State with the raw `a.set k t2` term (no `let a1`). + have h_sub_pre : ∀ ℓ : Nat, ℓ < 16 → + ((((a.set k t2).val[i.val]!).elements.val[ℓ]!).val + - (scratch2.elements.val[ℓ]!).val : Int).natAbs ≤ 2 ^ 15 - 1 := by + intro ℓ hℓ + rw [h_a1_i_val_eq] + have h_t_b := h_t_bd ℓ hℓ + have h_s2_b := (h_scratch2_bd ℓ hℓ).2 + have h_t_int_lb : -(3 : Int) ≤ (t.elements.val[ℓ]!).val := by omega + have h_t_int_ub : (t.elements.val[ℓ]!).val ≤ (3 : Int) := by omega + have h_s2_int_lb : -(4800 : Int) ≤ (scratch2.elements.val[ℓ]!).val := by omega + have h_s2_int_ub : (scratch2.elements.val[ℓ]!).val ≤ (4800 : Int) := by omega + omega + obtain ⟨t4, h_t4_eq, h_t4_post⟩ := + triple_exists_ok_l3 (sub_spec ((a.set k t2).val[i.val]!) scratch2 h_sub_pre) + have h_t4_bd : ∀ ℓ : Nat, ℓ < 16 → (t4.elements.val[ℓ]!).val.natAbs ≤ 4803 := by + intro ℓ hℓ + have h_per := h_t4_post ℓ hℓ + have h_v := h_per.1 + rw [h_a1_i_val_eq] at h_v + have h_t_b := h_t_bd ℓ hℓ + have h_s2_b := (h_scratch2_bd ℓ hℓ).2 + have h_t_int_lb : -(3 : Int) ≤ (t.elements.val[ℓ]!).val := by omega + have h_t_int_ub : (t.elements.val[ℓ]!).val ≤ (3 : Int) := by omega + have h_s2_int_lb : -(4800 : Int) ≤ (scratch2.elements.val[ℓ]!).val := by omega + have h_s2_int_ub : (scratch2.elements.val[ℓ]!).val ≤ (4800 : Int) := by omega + omega + -- Introduce a1 alias AFTER stating h_t4_eq (so h_t4_eq has raw form). + let a1 : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + a.set k t2 + have ha1_def : a1 = a.set k t2 := rfl + -- 11) a2 = a1.set i t4. + set a2 : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + a1.set i t4 with ha2_def + set acc' : Layer7.Acc := ({ coefficients := a2 }, scratch2) with hacc'_def + -- Compose into a single .ok equation. + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + 8#usize { start := k, «end» := 8#usize } acc.1 acc.2 + = .ok (cont (({ start := s, «end» := 8#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 8#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 8#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + -- L3.1's pattern: full `simp` reduces match-on-Some + lets + Prod-binds + -- and threads all arithmetic equations in one call. Disable two simp + -- lemmas that would push the goal into a form our hypotheses don't match: + -- * `List.getElem!_eq_getElem?_getD` — would unfold `[i]!` to `?.getD default`. + -- * `Std.Array.set_val_eq` — would push `↑` inside `(a.set i x)` to + -- yield `(↑a).set (↑i) x`, breaking the `(a.set k t2).val[i.val]!` + -- pattern in our index_mut hypothesis. + simp [-List.getElem!_eq_getElem?_getD, -Std.Array.set_val_eq, + bind_tc_ok, + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations.multiply_by_constant, + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations.add, + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations.sub, + h_i_eq, h_idx_i, h_scratch2_eq, h_idx_k, h_upd_i, h_imt_a_k, h_t2_eq, + h_imt_a1_i, h_t4_eq] + -- Goal collapses to `({coefficients := (a.set k t2).set i t4}, scratch2) = acc'`; + -- `acc'` is `({coefficients := (a.set k t2).set i t4}, scratch2)` definitionally + -- (via the `let`-chain a1 = a.set k t2, a2 = a1.set i t4). + rfl + apply triple_of_ok_l3 h_body + show Layer7.step_post re k + (.cont (({ start := s, «end» := 8#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) + unfold Layer7.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_, ?_⟩ + · -- Lanes j ∈ [0, s.val): processed. + intro j hj ℓ hℓ + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k: unchanged by both writes. + have h_k_ne_j : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_i_ne_j : i.val ≠ j := by rw [h_i_val_arith]; omega + have h_chain : + acc'.1.coefficients.val[j]! = acc.1.coefficients.val[j]! := by + show (a1.set i t4).val[j]! = acc.1.coefficients.val[j]! + have h1 : (a1.set i t4)[j]! = a1[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne a1 i j t4 h_i_ne_j + have h2 : (a.set k t2)[j]! = a[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne a k j t2 h_k_ne_j + have h3 : (acc.1.coefficients.set i t)[j]! = (acc.1.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.1.coefficients i j t h_i_ne_j + have h1' : (a1.set i t4).val[j]! = a1.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h1 + have h2' : a1.val[j]! = a.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h2 + have h3' : a.val[j]! = acc.1.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h3 + rw [h1', h2', h3'] + rw [h_chain] + exact h_done_lo j hj_lt_k ℓ hℓ + · -- j = k: new value is t2. + subst hj_eq_k + have h_chain : acc'.1.coefficients.val[k.val]! = t2 := by + show (a1.set i t4).val[k.val]! = t2 + have h1 : (a1.set i t4)[k.val]! = a1[k.val]! := + Aeneas.Std.Array.getElem!_Nat_set_ne a1 i k.val t4 h_i_ne_k + have h1' : (a1.set i t4).val[k.val]! = a1.val[k.val]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h1 + rw [h1'] + show (a.set k t2).val[k.val]! = t2 + have h2 : (a.set k t2)[k.val]! = t2 := + Aeneas.Std.Array.getElem!_Nat_set_eq a k k.val t2 ⟨rfl, h_a_k_lt⟩ + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h2 + rw [h_chain] + exact h_t2_bd ℓ hℓ + · -- Lanes j ∈ [8, 8 + s.val): processed. + intro j hj_lo hj_hi ℓ hℓ + rw [hs_val] at hj_hi + have hj_lt : j < 8 + k.val + 1 := by omega + rcases Nat.lt_succ_iff_lt_or_eq.mp hj_lt with hj_lt_ki | hj_eq_ki + · -- j ∈ [8, 8 + k.val): unchanged by both writes. + have h_k_ne_j : k.val ≠ j := by omega + have h_i_ne_j : i.val ≠ j := by rw [h_i_val_arith]; omega + have h_chain : + acc'.1.coefficients.val[j]! = acc.1.coefficients.val[j]! := by + show (a1.set i t4).val[j]! = acc.1.coefficients.val[j]! + have h1 : (a1.set i t4)[j]! = a1[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne a1 i j t4 h_i_ne_j + have h2 : (a.set k t2)[j]! = a[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne a k j t2 h_k_ne_j + have h3 : (acc.1.coefficients.set i t)[j]! = (acc.1.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.1.coefficients i j t h_i_ne_j + have h1' : (a1.set i t4).val[j]! = a1.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h1 + have h2' : a1.val[j]! = a.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h2 + have h3' : a.val[j]! = acc.1.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h3 + rw [h1', h2', h3'] + rw [h_chain] + exact h_done_hi j hj_lo hj_lt_ki ℓ hℓ + · -- j = 8 + k.val = i.val: new value is t4. + have hj_eq_i : j = i.val := by rw [h_i_val_arith]; omega + subst hj_eq_i + have h_chain : acc'.1.coefficients.val[i.val]! = t4 := by + show (a1.set i t4).val[i.val]! = t4 + have h1 : (a1.set i t4)[i.val]! = t4 := + Aeneas.Std.Array.getElem!_Nat_set_eq a1 i i.val t4 ⟨rfl, h_a1_i_lt⟩ + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h1 + rw [h_chain] + exact h_t4_bd ℓ hℓ + · -- Lanes j ∈ [s.val, 8): untouched. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_k_lt_j : k.val < j := by omega + have h_k_ne_j : k.val ≠ j := by omega + have h_i_ne_j : i.val ≠ j := by rw [h_i_val_arith]; omega + have h_chain : + acc'.1.coefficients.val[j]! = acc.1.coefficients.val[j]! := by + show (a1.set i t4).val[j]! = acc.1.coefficients.val[j]! + have h1 : (a1.set i t4)[j]! = a1[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne a1 i j t4 h_i_ne_j + have h2 : (a.set k t2)[j]! = a[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne a k j t2 h_k_ne_j + have h3 : (acc.1.coefficients.set i t)[j]! = (acc.1.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.1.coefficients i j t h_i_ne_j + have h1' : (a1.set i t4).val[j]! = a1.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h1 + have h2' : a1.val[j]! = a.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h2 + have h3' : a.val[j]! = acc.1.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h3 + rw [h1', h2', h3'] + rw [h_chain] + have h_undone_j : k.val ≤ j := by omega + exact h_undone_lo j h_undone_j hj_lt + · -- Lanes j ∈ [8 + s.val, 16): untouched. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_i_lt_j : i.val < j := by rw [h_i_val_arith]; omega + have h_k_ne_j : k.val ≠ j := by omega + have h_i_ne_j : i.val ≠ j := by omega + have h_chain : + acc'.1.coefficients.val[j]! = acc.1.coefficients.val[j]! := by + show (a1.set i t4).val[j]! = acc.1.coefficients.val[j]! + have h1 : (a1.set i t4)[j]! = a1[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne a1 i j t4 h_i_ne_j + have h2 : (a.set k t2)[j]! = a[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne a k j t2 h_k_ne_j + have h3 : (acc.1.coefficients.set i t)[j]! = (acc.1.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.1.coefficients i j t h_i_ne_j + have h1' : (a1.set i t4).val[j]! = a1.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h1 + have h2' : a1.val[j]! = a.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h2 + have h3' : a.val[j]! = acc.1.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h3 + rw [h1', h2', h3'] + rw [h_chain] + have h_undone_j : 8 + k.val ≤ j := by omega + exact h_undone_hi j h_undone_j hj_lt + · -- None branch (k ≥ 8). + have hk_ge : k.val ≥ (8#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 8 := by rw [h8] at hk_ge; omega + have h_iter_none := iter_next_none_eq_gen k 8#usize hk_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + 8#usize { start := k, «end» := 8#usize } acc.1 acc.2 + = .ok (done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 8#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 8#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_l3 h_body + show Layer7.step_post re k (.done acc) + unfold Layer7.step_post + show (Layer7.inv re 8#usize acc).holds + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_, ?_⟩ + · intro j hj ℓ hℓ + rw [show (8#usize : Std.Usize).val = 8 from rfl] at hj + apply h_done_lo j _ ℓ hℓ; rw [hk_eq]; exact hj + · intro j hj_lo hj_hi ℓ hℓ + rw [show (8#usize : Std.Usize).val = 8 from rfl] at hj_hi + apply h_done_hi j hj_lo _ ℓ hℓ; rw [hk_eq]; exact hj_hi + · intro j hj_ge hj_lt + rw [show (8#usize : Std.Usize).val = 8 from rfl] at hj_ge + apply h_undone_lo j _ hj_lt; rw [hk_eq]; exact hj_ge + · intro j hj_ge hj_lt + rw [show (8#usize : Std.Usize).val = 8 from rfl] at hj_ge + apply h_undone_hi j _ hj_lt; rw [hk_eq]; exact hj_ge + +private theorem vectors_in_ring_element_eq : + libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT = .ok (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + rfl + +set_option maxHeartbeats 16000000 in +@[spec] +theorem ntt_at_layer_7_spec + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_7 + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + re scratch + ⦃ ⇓ p => ⌜ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.1.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 4803 ⌝ ⦄ := by + -- Reduce the top wrapper: i = VECTORS_IN_RING_ELEMENT = 16#usize, step = i/2 = 8#usize. + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_7 + rw [vectors_in_ring_element_eq] + simp only [bind_tc_ok] + -- step = 16#usize / 2#usize = 8#usize. + have h_two_nz : (2#usize : Std.Usize).val ≠ 0 := by decide + obtain ⟨step, h_step_eq, h_step_val⟩ := + usize_div_ok_eq (16#usize : Std.Usize) (2#usize : Std.Usize) h_two_nz + have h_step_val_8 : step.val = 8 := by + rw [h_step_val]; decide + have h_step_eq_8 : step = 8#usize := by + apply Std.UScalar.eq_of_val_eq + rw [h_step_val_8]; rfl + rw [h_step_eq] + simp only [bind_tc_ok] + rw [h_step_eq_8] + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop + apply Std.Do.Triple.of_entails_right _ + (loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_7_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + 8#usize iter1 acc1.1 acc1.2) + (β := Layer7.Acc) + (re, scratch) + 0#usize 8#usize + (Layer7.inv re) + (by decide : (0#usize : Std.Usize).val ≤ (8#usize : Std.Usize).val) + (pure_prop_holds_l3 + ⟨fun j hj _ _ => absurd hj (Nat.not_lt_zero j), + fun j _hj_lo hj_hi _ _ => by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + rw [h0] at hj_hi + omega, + fun _ _ _ => rfl, + fun _ _ _ => rfl⟩) + ?_) + · -- Post entailment. + rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_done_lo, h_done_hi, _h_undone_lo, _h_undone_hi⟩ := of_pure_prop_holds_l3 h + intro i hi j hj + have h8 : (8#usize : Std.Usize).val = 8 := rfl + by_cases hi_lt_8 : i < 8 + · apply h_done_lo i (by rw [h8]; exact hi_lt_8) j hj + · have hi_ge_8 : 8 ≤ i := Nat.not_lt.mp hi_lt_8 + apply h_done_hi i hi_ge_8 (by rw [h8]; omega) j hj + · -- Step lemma application. + intro acc k h_ge h_le hinv + obtain ⟨h_done_lo, h_done_hi, h_undone_lo, h_undone_hi⟩ := of_pure_prop_holds_l3 hinv + have h_step := ntt_at_layer_7_step_lemma re h_pre acc k h_le + h_done_lo h_done_hi h_undone_lo h_undone_hi + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer7.step_post re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer7.step_post] using hP + · have hP : Layer7.step_post re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer7.step_post] using hP + +/-! ## L3.4 — `ntt_at_layer_4_plus_spec` + +Generic outer-NTT layer for `layer ∈ {4, 5, 6}`. Nested loop: + - Outer loop iterates `outer_count = 128 >>> layer` rounds. + - Inner loop iterates `step_vec = (1 <<< layer) / 16` butterfly positions. + +For `layer = 4`: step=16, step_vec=1, outer_count=8. +For `layer = 5`: step=32, step_vec=2, outer_count=4. +For `layer = 6`: step=64, step_vec=4, outer_count=2. + +In all cases, the total butterflies executed = `outer_count * step_vec = 8`, +covering 16 coefficients (each lane touched once via pairs spanning `step_vec` +apart). Per-coefficient bound goes `bnd → bnd + 3328`. + +Per-iter body (inner loop, j = inner-counter): + i = a_offset + j (one of the "low" lanes of the butterfly pair) + i1 = b_offset + j (one of the "high" lanes) + i2 = polynomial.zeta zeta_i (the Montgomery-domain twiddle) + ntt_layer_int_vec_step OpsInst re.coefficients i i1 scratch i2 + +Per-iter body (outer loop, round = outer-counter): + zeta_i1 = zeta_i + 1 + i = round * 2 + a_offset = i * step_vec + b_offset = a_offset + step_vec + ntt_at_layer_4_plus_loop0_loop0 ... {0..step_vec} zeta_i1 ... +-/ + +/-! ### per-step helper `ntt_layer_int_vec_step_spec` -/ + +/-- A single butterfly on a `coefficients` array at lanes `a ≠ b`: + new `coefficients[a]` = old `coefficients[a]` + zeta*coefficients[b] + new `coefficients[b]` = old `coefficients[a]` - zeta*coefficients[b] + Per-element bound transformation: if `|coefficients[a][ℓ]| ≤ B*3328` + (`B ≤ 4`, so `B+1 ≤ 5`, well-within `9*3328 < 2^15`), the output at + both lanes a and b is bounded by `(B+1)*3328`. Other lanes unchanged. + Returns the new coefficients and the scratch value (unused downstream). -/ +private theorem ntt_layer_int_vec_step_spec + (coefficients : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize) + (a b : Std.Usize) (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta_r : Std.I16) + (h_a : a.val < 16) (h_b : b.val < 16) (h_ne : a.val ≠ b.val) + (h_zeta : zeta_r.val.natAbs ≤ 1664) + (bnd : Nat) (h_bnd : bnd ≤ 8 * 3328) + (h_pre_a : ∀ ℓ : Nat, ℓ < 16 → + ((coefficients.val[a.val]!).elements.val[ℓ]!).val.natAbs ≤ bnd) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_layer_int_vec_step + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + coefficients a b scratch zeta_r + ⦃ ⇓ p => ⌜ (∀ ℓ : Nat, ℓ < 16 → + ((p.1.val[a.val]!).elements.val[ℓ]!).val.natAbs ≤ bnd + 3328) + ∧ (∀ ℓ : Nat, ℓ < 16 → + ((p.1.val[b.val]!).elements.val[ℓ]!).val.natAbs ≤ bnd + 3328) + ∧ (∀ k : Nat, k < 16 → k ≠ a.val → k ≠ b.val → + p.1.val[k]! = coefficients.val[k]!) ⌝ ⦄ := by + have h_coef_len : coefficients.length = 16 := Std.Array.length_eq _ + have h_a_lt : a.val < coefficients.length := by rw [h_coef_len]; exact h_a + have h_b_lt : b.val < coefficients.length := by rw [h_coef_len]; exact h_b + -- Read coefficients[b]. + have h_idx_b : Aeneas.Std.Array.index_usize coefficients b + = .ok (coefficients.val[b.val]!) := + array_index_usize_ok_eq coefficients b h_b_lt + set scratch1 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + coefficients.val[b.val]! with hs1_def + -- scratch2 = OpsInst.montgomery_multiply_by_constant scratch1 zeta_r. + -- This reduces to: do classify zeta_r; arithmetic.montgomery_multiply_by_constant scratch1 zeta_r. + -- Use L1.4. + have h_classify : libcrux_secrets.traits.Classify.Blanket.classify zeta_r = .ok zeta_r := rfl + obtain ⟨scratch2, h_scratch2_eq, h_scratch2_post⟩ := + triple_exists_ok_l3 (montgomery_multiply_by_constant_spec scratch1 zeta_r h_zeta) + have h_scratch2_bd : ∀ ℓ : Nat, ℓ < 16 → (scratch2.elements.val[ℓ]!).val.natAbs ≤ 3328 := by + intro ℓ hℓ; exact (h_scratch2_post ℓ hℓ).1 + -- Read coefficients[a] (= t). + have h_idx_a : Aeneas.Std.Array.index_usize coefficients a + = .ok (coefficients.val[a.val]!) := + array_index_usize_ok_eq coefficients a h_a_lt + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + coefficients.val[a.val]! with ht_def + -- Bound at t: per h_pre_a, |t[ℓ]| ≤ bnd. + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → (t.elements.val[ℓ]!).val.natAbs ≤ bnd := + h_pre_a + -- coefficients1 = coefficients.update b t. + have h_upd_b : Aeneas.Std.Array.update coefficients b t + = .ok (coefficients.set b t) := + array_update_ok_eq coefficients b t h_b_lt + set c1 : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + coefficients.set b t with hc1_def + -- index_mut_usize c1 a returns (c1.val[a.val]!, c1.set a). + have h_a_ne_b : a.val ≠ b.val := h_ne + have h_b_ne_a : b.val ≠ a.val := fun h => h_a_ne_b h.symm + have h_c1_a_lt : a.val < c1.length := by + change a.val < (coefficients.set b t).length; rw [Std.Array.set_length]; exact h_a_lt + have h_c1_a_idx : Aeneas.Std.Array.index_usize c1 a = .ok (c1.val[a.val]!) := + array_index_usize_ok_eq c1 a h_c1_a_lt + have h_imt_c1_a : Aeneas.Std.Array.index_mut_usize c1 a + = .ok (c1.val[a.val]!, c1.set a) := by + unfold Aeneas.Std.Array.index_mut_usize; rw [h_c1_a_idx]; rfl + have h_c1_a_val_eq : c1.val[a.val]! = t := by + show (coefficients.set b t).val[a.val]! = t + have h_ne_set : (coefficients.set b t)[a.val]! = coefficients[a.val]! := + Aeneas.Std.Array.getElem!_Nat_set_ne coefficients b a.val t h_b_ne_a + have h_eq_val : (coefficients.set b t).val[a.val]! = coefficients.val[a.val]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_ne_set + rw [h_eq_val] + -- t2 = OpsInst.add (c1.val[a.val]!) scratch2. + have h_add_pre : ∀ ℓ : Nat, ℓ < 16 → + (((c1.val[a.val]!).elements.val[ℓ]!).val + + (scratch2.elements.val[ℓ]!).val : Int).natAbs ≤ 2 ^ 15 - 1 := by + intro ℓ hℓ + rw [h_c1_a_val_eq] + have h_t_b := h_t_bd ℓ hℓ + have h_s2_b := h_scratch2_bd ℓ hℓ + have h_t_int_lb : -((bnd : Nat) : Int) ≤ (t.elements.val[ℓ]!).val := by omega + have h_t_int_ub : (t.elements.val[ℓ]!).val ≤ ((bnd : Nat) : Int) := by omega + have h_s2_int_lb : -(3328 : Int) ≤ (scratch2.elements.val[ℓ]!).val := by omega + have h_s2_int_ub : (scratch2.elements.val[ℓ]!).val ≤ (3328 : Int) := by omega + -- bnd + 3328 ≤ 4*3328 + 3328 = 5*3328 = 16640 < 2^15. + omega + obtain ⟨t2, h_t2_eq, h_t2_post⟩ := + triple_exists_ok_l3 (add_spec (c1.val[a.val]!) scratch2 h_add_pre) + -- Bound on t2: ≤ bnd + 3328. + have h_t2_bd : ∀ ℓ : Nat, ℓ < 16 → (t2.elements.val[ℓ]!).val.natAbs ≤ bnd + 3328 := by + intro ℓ hℓ + have h_per := h_t2_post ℓ hℓ + have h_v := h_per.1 + rw [h_c1_a_val_eq] at h_v + have h_t_b := h_t_bd ℓ hℓ + have h_s2_b := h_scratch2_bd ℓ hℓ + have h_t_int_lb : -((bnd : Nat) : Int) ≤ (t.elements.val[ℓ]!).val := by omega + have h_t_int_ub : (t.elements.val[ℓ]!).val ≤ ((bnd : Nat) : Int) := by omega + have h_s2_int_lb : -(3328 : Int) ≤ (scratch2.elements.val[ℓ]!).val := by omega + have h_s2_int_ub : (scratch2.elements.val[ℓ]!).val ≤ (3328 : Int) := by omega + omega + -- coefficients2 = c1.set a t2. + set c2 : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + c1.set a t2 with hc2_def + -- index_mut_usize c2 b returns (c2.val[b.val]!, c2.set b). + have h_c2_b_lt : b.val < c2.length := by + change b.val < (c1.set a t2).length; rw [Std.Array.set_length] + change b.val < (coefficients.set b t).length; rw [Std.Array.set_length]; exact h_b_lt + have h_c2_b_idx : Aeneas.Std.Array.index_usize c2 b = .ok (c2.val[b.val]!) := + array_index_usize_ok_eq c2 b h_c2_b_lt + have h_imt_c2_b : Aeneas.Std.Array.index_mut_usize c2 b + = .ok (c2.val[b.val]!, c2.set b) := by + unfold Aeneas.Std.Array.index_mut_usize; rw [h_c2_b_idx]; rfl + -- c2.val[b.val]! = c1.val[b.val]! (since c2 sets a ≠ b) + -- = (coefficients.set b t).val[b.val]! = t. + have h_c2_b_val_eq : c2.val[b.val]! = t := by + show (c1.set a t2).val[b.val]! = t + have h_ne1 : (c1.set a t2)[b.val]! = c1[b.val]! := + Aeneas.Std.Array.getElem!_Nat_set_ne c1 a b.val t2 h_a_ne_b + have h_ne1_val : (c1.set a t2).val[b.val]! = c1.val[b.val]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_ne1 + rw [h_ne1_val] + change (coefficients.set b t).val[b.val]! = t + have h_eq : (coefficients.set b t)[b.val]! = t := + Aeneas.Std.Array.getElem!_Nat_set_eq coefficients b b.val t ⟨rfl, h_b_lt⟩ + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_eq + -- t4 = OpsInst.sub (c2.val[b.val]!) scratch2. + have h_sub_pre : ∀ ℓ : Nat, ℓ < 16 → + (((c2.val[b.val]!).elements.val[ℓ]!).val + - (scratch2.elements.val[ℓ]!).val : Int).natAbs ≤ 2 ^ 15 - 1 := by + intro ℓ hℓ + rw [h_c2_b_val_eq] + have h_t_b := h_t_bd ℓ hℓ + have h_s2_b := h_scratch2_bd ℓ hℓ + have h_t_int_lb : -((bnd : Nat) : Int) ≤ (t.elements.val[ℓ]!).val := by omega + have h_t_int_ub : (t.elements.val[ℓ]!).val ≤ ((bnd : Nat) : Int) := by omega + have h_s2_int_lb : -(3328 : Int) ≤ (scratch2.elements.val[ℓ]!).val := by omega + have h_s2_int_ub : (scratch2.elements.val[ℓ]!).val ≤ (3328 : Int) := by omega + omega + obtain ⟨t4, h_t4_eq, h_t4_post⟩ := + triple_exists_ok_l3 (sub_spec (c2.val[b.val]!) scratch2 h_sub_pre) + have h_t4_bd : ∀ ℓ : Nat, ℓ < 16 → (t4.elements.val[ℓ]!).val.natAbs ≤ bnd + 3328 := by + intro ℓ hℓ + have h_per := h_t4_post ℓ hℓ + have h_v := h_per.1 + rw [h_c2_b_val_eq] at h_v + have h_t_b := h_t_bd ℓ hℓ + have h_s2_b := h_scratch2_bd ℓ hℓ + have h_t_int_lb : -((bnd : Nat) : Int) ≤ (t.elements.val[ℓ]!).val := by omega + have h_t_int_ub : (t.elements.val[ℓ]!).val ≤ ((bnd : Nat) : Int) := by omega + have h_s2_int_lb : -(3328 : Int) ≤ (scratch2.elements.val[ℓ]!).val := by omega + have h_s2_int_ub : (scratch2.elements.val[ℓ]!).val ≤ (3328 : Int) := by omega + omega + -- coefficients3 = c2.set b t4. + set c3 : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + c2.set b t4 with hc3_def + -- Compose into single .ok equation. + have h_body : + libcrux_iot_ml_kem.ntt.ntt_layer_int_vec_step + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + coefficients a b scratch zeta_r + = .ok (c3, scratch2) := by + unfold libcrux_iot_ml_kem.ntt.ntt_layer_int_vec_step + -- mont_mul_fe reduces to: classify zeta_r >>= λi → arithmetic.mont_mul scratch1 i. + simp only [bind_tc_ok, h_idx_b] + -- Force unfold the trait wrappers (montgomery_multiply_fe, .add, .sub). + unfold libcrux_iot_ml_kem.vector.traits.montgomery_multiply_fe + show + (do + let _scratch2 ← + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations.montgomery_multiply_by_constant + scratch1 zeta_r + let _t ← Aeneas.Std.Array.index_usize coefficients a + let _c1 ← Aeneas.Std.Array.update coefficients b _t + let (_t1, _back1) ← Aeneas.Std.Array.index_mut_usize _c1 a + let _t2 ← + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations.add + _t1 _scratch2 + let _c2 := _back1 _t2 + let (_t3, _back2) ← Aeneas.Std.Array.index_mut_usize _c2 b + let _t4 ← + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations.sub + _t3 _scratch2 + let _c3 := _back2 _t4 + ok (_c3, _scratch2)) + = _ + unfold + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations.montgomery_multiply_by_constant + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations.add + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations.sub + -- Now the body is in terms of `classify`, `arithmetic.mont_mul_by_const`, `arithmetic.add`, `arithmetic.sub`. + simp only [bind_tc_ok, h_classify] + -- Now scratch2 = arithmetic.mont_mul_by_const scratch1 zeta_r; + -- L1.4 gives h_scratch2_eq for that. + rw [h_scratch2_eq] + simp only [bind_tc_ok, h_idx_a, h_upd_b, h_imt_c1_a, h_t2_eq] + -- After threading through index_mut_back1 := c1.set a, the remaining + -- shape is `(c1.set a t2).index_mut_usize b >>= …`. Now apply h_imt_c2_b + -- (whose LHS is `c2.index_mut_usize b` = `(c1.set a t2).index_mut_usize b`). + show ((c1.set a t2).index_mut_usize b >>= _) = _ + rw [show (c1.set a t2).index_mut_usize b = c2.index_mut_usize b from rfl, h_imt_c2_b] + simp only [bind_tc_ok, h_t4_eq] + rfl + apply triple_of_ok_l3 h_body + -- Now prove the post for (c3, scratch2): + -- 1) c3[a] bounded by (B+1)*3328 + -- 2) c3[b] bounded by (B+1)*3328 + -- 3) c3[k] = coefficients[k] for k ≠ a, k ≠ b. + -- Key chain: c3 = c2.set b t4, c2 = c1.set a t2, c1 = coefficients.set b t. + have h_c3_a_val_eq : c3.val[a.val]! = t2 := by + show (c2.set b t4).val[a.val]! = t2 + have h_ne1 : (c2.set b t4)[a.val]! = c2[a.val]! := + Aeneas.Std.Array.getElem!_Nat_set_ne c2 b a.val t4 h_b_ne_a + have h_ne1_val : (c2.set b t4).val[a.val]! = c2.val[a.val]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_ne1 + rw [h_ne1_val] + -- c2.val[a.val]! = (c1.set a t2)[a.val]! = t2. + show (c1.set a t2).val[a.val]! = t2 + have h_eq : (c1.set a t2)[a.val]! = t2 := by + have h_a_lt_c1 : a.val < c1.length := h_c1_a_lt + exact Aeneas.Std.Array.getElem!_Nat_set_eq c1 a a.val t2 ⟨rfl, h_a_lt_c1⟩ + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_eq + have h_c3_b_val_eq : c3.val[b.val]! = t4 := by + show (c2.set b t4).val[b.val]! = t4 + have h_eq : (c2.set b t4)[b.val]! = t4 := + Aeneas.Std.Array.getElem!_Nat_set_eq c2 b b.val t4 ⟨rfl, h_c2_b_lt⟩ + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_eq + refine ⟨?_, ?_, ?_⟩ + · intro ℓ hℓ + rw [h_c3_a_val_eq]; exact h_t2_bd ℓ hℓ + · intro ℓ hℓ + rw [h_c3_b_val_eq]; exact h_t4_bd ℓ hℓ + · intro k h_k_lt h_k_ne_a h_k_ne_b + -- c3[k] = (c2.set b t4)[k] = c2[k] (k ≠ b) + -- = (c1.set a t2)[k] = c1[k] (k ≠ a) + -- = (coefficients.set b t)[k] = coefficients[k] (k ≠ b). + show (c2.set b t4).val[k]! = coefficients.val[k]! + have h1 : (c2.set b t4)[k]! = c2[k]! := + Aeneas.Std.Array.getElem!_Nat_set_ne c2 b k t4 (fun h => h_k_ne_b h.symm) + have h1' : (c2.set b t4).val[k]! = c2.val[k]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h1 + rw [h1'] + show (c1.set a t2).val[k]! = coefficients.val[k]! + have h2 : (c1.set a t2)[k]! = c1[k]! := + Aeneas.Std.Array.getElem!_Nat_set_ne c1 a k t2 (fun h => h_k_ne_a h.symm) + have h2' : (c1.set a t2).val[k]! = c1.val[k]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h2 + rw [h2'] + show (coefficients.set b t).val[k]! = coefficients.val[k]! + have h3 : (coefficients.set b t)[k]! = coefficients[k]! := + Aeneas.Std.Array.getElem!_Nat_set_ne coefficients b k t (fun h => h_k_ne_b h.symm) + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h3 + +/-! ### inner loop helper + +The inner loop `ntt_at_layer_4_plus_loop0_loop0` iterates over `j ∈ [0, step_vec)`, +each iter calling `ntt_layer_int_vec_step` on lanes `(a_offset+j, b_offset+j)` +with a **fixed** `zeta_r = polynomial.zeta zeta_i`. The invariant after `j` iters +has four zones (all bounds in absolute value): + - lanes `[a_offset, a_offset+j)` and `[b_offset, b_offset+j)`: processed, + each bounded by `(B+1)*3328`. + - lanes `[a_offset+j, a_offset+step_vec)` and `[b_offset+j, b_offset+step_vec)`: + untouched, equal to `re.coefficients` at the same index (so bounded by `B*3328`). + - other lanes: untouched. + +We require `[a_offset, a_offset+step_vec)` and `[b_offset, b_offset+step_vec)` +to be disjoint and lie within `[0, 16)` — i.e. `a_offset + step_vec ≤ b_offset` +and `b_offset + step_vec ≤ 16` (with `a_offset ≤ b_offset` from L3.4's caller). +-/ + +namespace Layer4PlusInner + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Result ControlFlow + +/-- Inner-loop accumulator: a `(PolynomialRingElement × scratch)`. -/ +abbrev Acc := + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector × + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- Inner-loop invariant after `j` iters. `bnd` is the absolute input bound; + processed lanes are at `bnd + 3328`. -/ +def inv + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (a_offset b_offset step_vec : Std.Usize) (bnd : Nat) : + Std.Usize → Acc → Result Prop := + fun j acc => pure ( + -- Processed-a zone: lanes [a_offset, a_offset + j). + (∀ ℓ' : Nat, ℓ' < j.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[a_offset.val + ℓ']!).elements.val[ℓ]!).val.natAbs + ≤ bnd + 3328) + -- Processed-b zone: lanes [b_offset, b_offset + j). + ∧ (∀ ℓ' : Nat, ℓ' < j.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[b_offset.val + ℓ']!).elements.val[ℓ]!).val.natAbs + ≤ bnd + 3328) + -- Untouched lanes [a_offset+j, a_offset+step_vec) match re. + ∧ (∀ ℓ' : Nat, j.val ≤ ℓ' → ℓ' < step_vec.val → + acc.1.coefficients.val[a_offset.val + ℓ']! + = re.coefficients.val[a_offset.val + ℓ']!) + -- Untouched lanes [b_offset+j, b_offset+step_vec) match re. + ∧ (∀ ℓ' : Nat, j.val ≤ ℓ' → ℓ' < step_vec.val → + acc.1.coefficients.val[b_offset.val + ℓ']! + = re.coefficients.val[b_offset.val + ℓ']!) + -- Lanes outside the two ranges are unchanged from re. + ∧ (∀ k : Nat, k < 16 → + (k < a_offset.val ∨ a_offset.val + step_vec.val ≤ k ∧ k < b_offset.val + ∨ b_offset.val + step_vec.val ≤ k) → + acc.1.coefficients.val[k]! = re.coefficients.val[k]!)) + +/-- Inner-loop step post. -/ +def step_post + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (a_offset b_offset step_vec : Std.Usize) (bnd : Nat) (j : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + j.val < step_vec.val ∧ iter'.«end» = step_vec + ∧ iter'.start.val = j.val + 1 + ∧ (inv re a_offset b_offset step_vec bnd iter'.start acc').holds + | .done y => (inv re a_offset b_offset step_vec bnd step_vec y).holds + +end Layer4PlusInner + +/-- Inner-loop step lemma. Each body iter calls `ntt_layer_int_vec_step` on + lanes `(a_offset+j, b_offset+j)`, transforming both bounds from `bnd` to + `bnd + 3328`. We only need the bound on `re` for lanes in the + a-window `[a_offset, a_offset+step_vec)`. -/ +private theorem ntt_at_layer_4_plus_inner_step_lemma + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta_i a_offset b_offset step_vec : Std.Usize) (bnd : Nat) (h_bnd : bnd ≤ 8 * 3328) + (h_zeta_lt : zeta_i.val < 128) + (h_ranges : a_offset.val + step_vec.val ≤ b_offset.val + ∧ b_offset.val + step_vec.val ≤ 16) + (h_pre_a : ∀ ℓ' : Nat, ℓ' < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((re.coefficients.val[a_offset.val + ℓ']!).elements.val[ℓ]!).val.natAbs ≤ bnd) + (acc : Layer4PlusInner.Acc) + (j : Std.Usize) (h_le : j.val ≤ step_vec.val) + (hinv : (Layer4PlusInner.inv re a_offset b_offset step_vec bnd j acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + zeta_i a_offset b_offset { start := j, «end» := step_vec } acc.1 acc.2 + ⦃ ⇓ r => ⌜ Layer4PlusInner.step_post re a_offset b_offset step_vec bnd j r ⌝ ⦄ := by + obtain ⟨h_a_disj, h_b_le_16⟩ := h_ranges + -- The 5 invariant conjuncts. + obtain ⟨h_a_done, h_b_done, h_a_undone, h_b_undone, h_other⟩ := + of_pure_prop_holds_l3 hinv + have h_coef_len : acc.1.coefficients.length = 16 := Std.Array.length_eq _ + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0.body + by_cases h_lt : j.val < step_vec.val + · -- Some j branch. + obtain ⟨s, hs_val, h_iter_some⟩ := iter_next_some_eq_gen j step_vec h_lt + -- 1) i = a_offset + j. + have h_um : (1#usize : Std.Usize).val = 1 := rfl + have h_step_vec_le_16 : step_vec.val ≤ 16 := by omega + have h_a_plus_j_lt_16 : a_offset.val + j.val < 16 := by omega + have h_b_plus_j_lt_16 : b_offset.val + j.val < 16 := by omega + have h_i_max : a_offset.val + j.val ≤ Std.Usize.max := by + have : (16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i, h_i_eq, h_i_val⟩ := usize_add_ok_eq a_offset j h_i_max + have h_i_val_arith : i.val = a_offset.val + j.val := h_i_val + -- 2) i1 = b_offset + j. + have h_i1_max : b_offset.val + j.val ≤ Std.Usize.max := by + have : (16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i1, h_i1_eq, h_i1_val⟩ := usize_add_ok_eq b_offset j h_i1_max + have h_i1_val_arith : i1.val = b_offset.val + j.val := h_i1_val + -- 3) i2 = polynomial.zeta zeta_i. + obtain ⟨zeta_r, h_zeta_eq, h_zeta_bd⟩ := + triple_exists_ok_l3 (polynomial.zeta_spec zeta_i h_zeta_lt) + -- 4) ntt_layer_int_vec_step on (i, i1). + have h_i_lt : i.val < 16 := by rw [h_i_val_arith]; exact h_a_plus_j_lt_16 + have h_i1_lt : i1.val < 16 := by rw [h_i1_val_arith]; exact h_b_plus_j_lt_16 + have h_i_ne_i1 : i.val ≠ i1.val := by rw [h_i_val_arith, h_i1_val_arith]; omega + -- Precondition for ntt_layer_int_vec_step: acc[i] (= acc[a_offset+j]) is at B*3328. + -- From h_a_undone, acc.coef[a_offset+j] = re.coef[a_offset+j], hence bounded by h_pre. + have h_acc_i_eq : + acc.1.coefficients.val[i.val]! = re.coefficients.val[i.val]! := by + rw [h_i_val_arith] + exact h_a_undone j.val (Nat.le_refl _) h_lt + have h_pre_i : ∀ ℓ : Nat, ℓ < 16 → + ((acc.1.coefficients.val[i.val]!).elements.val[ℓ]!).val.natAbs ≤ bnd := by + intro ℓ hℓ + rw [h_acc_i_eq] + show ((re.coefficients.val[i.val]!).elements.val[ℓ]!).val.natAbs ≤ bnd + rw [show i.val = a_offset.val + j.val from h_i_val_arith] + exact h_pre_a j.val h_lt ℓ hℓ + obtain ⟨step_out, h_step_eq, h_step_post⟩ := + triple_exists_ok_l3 + (ntt_layer_int_vec_step_spec acc.1.coefficients i i1 acc.2 zeta_r + h_i_lt h_i1_lt h_i_ne_i1 h_zeta_bd bnd h_bnd h_pre_i) + obtain ⟨h_step_a_bd, h_step_b_bd, h_step_other⟩ := h_step_post + -- Next-state. + set acc' : Layer4PlusInner.Acc := + ({ coefficients := step_out.1 }, step_out.2) with hacc'_def + -- Compose the body into one .ok. + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + zeta_i a_offset b_offset { start := j, «end» := step_vec } acc.1 acc.2 + = .ok (cont (({ start := s, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := j, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := j, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp [bind_tc_ok, h_i_eq, h_i1_eq, h_zeta_eq, h_step_eq, hacc'_def] + -- step_out is a Prod (Array PortableVector 16, PortableVector); the + -- `let (a, scratch1) := step_out` destructure equals (step_out.1, step_out.2) + -- definitionally. + rfl + apply triple_of_ok_l3 h_body + show Layer4PlusInner.step_post re a_offset b_offset step_vec bnd j + (.cont (({ start := s, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) + unfold Layer4PlusInner.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + apply pure_prop_holds_l3 + -- We need to show inv at s = j+1. + -- The 5 conjuncts: + refine ⟨?_, ?_, ?_, ?_, ?_⟩ + -- A. acc'.coef[a_offset+ℓ'] bounded for ℓ' < s.val = j.val + 1. + · intro ℓ' hℓ' ℓ hℓ + rw [hs_val] at hℓ' + rcases Nat.lt_succ_iff_lt_or_eq.mp hℓ' with hℓ'_lt_j | hℓ'_eq_j + · -- ℓ' < j.val: unchanged by step (i = a_offset+j; ℓ' < j ⇒ a_offset+ℓ' ≠ i, ≠ i1). + have h_idx_lt_16 : a_offset.val + ℓ' < 16 := by omega + have h_ne_i : a_offset.val + ℓ' ≠ i.val := by rw [h_i_val_arith]; omega + have h_ne_i1 : a_offset.val + ℓ' ≠ i1.val := by rw [h_i1_val_arith]; omega + have h_unchanged : + step_out.1.val[a_offset.val + ℓ']! = acc.1.coefficients.val[a_offset.val + ℓ']! := + h_step_other (a_offset.val + ℓ') h_idx_lt_16 h_ne_i h_ne_i1 + show (step_out.1.val[a_offset.val + ℓ']!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_unchanged] + exact h_a_done ℓ' hℓ'_lt_j ℓ hℓ + · -- ℓ' = j.val: this is lane i = a_offset+j. Apply h_step_a_bd. + subst hℓ'_eq_j + have h_eq : a_offset.val + j.val = i.val := h_i_val_arith.symm + show (step_out.1.val[a_offset.val + j.val]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_eq]; exact h_step_a_bd ℓ hℓ + -- B. acc'.coef[b_offset+ℓ'] bounded for ℓ' < s.val = j.val + 1. + · intro ℓ' hℓ' ℓ hℓ + rw [hs_val] at hℓ' + rcases Nat.lt_succ_iff_lt_or_eq.mp hℓ' with hℓ'_lt_j | hℓ'_eq_j + · have h_idx_lt_16 : b_offset.val + ℓ' < 16 := by omega + have h_ne_i : b_offset.val + ℓ' ≠ i.val := by rw [h_i_val_arith]; omega + have h_ne_i1 : b_offset.val + ℓ' ≠ i1.val := by rw [h_i1_val_arith]; omega + have h_unchanged : + step_out.1.val[b_offset.val + ℓ']! = acc.1.coefficients.val[b_offset.val + ℓ']! := + h_step_other (b_offset.val + ℓ') h_idx_lt_16 h_ne_i h_ne_i1 + show (step_out.1.val[b_offset.val + ℓ']!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_unchanged] + exact h_b_done ℓ' hℓ'_lt_j ℓ hℓ + · subst hℓ'_eq_j + have h_eq : b_offset.val + j.val = i1.val := h_i1_val_arith.symm + show (step_out.1.val[b_offset.val + j.val]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_eq]; exact h_step_b_bd ℓ hℓ + -- C. Untouched a-zone for ℓ' ≥ s.val. + · intro ℓ' hℓ'_ge hℓ'_lt + rw [hs_val] at hℓ'_ge + have hℓ'_gt_j : j.val < ℓ' := by omega + have h_ge' : j.val ≤ ℓ' := Nat.le_of_lt hℓ'_gt_j + have h_idx_lt_16 : a_offset.val + ℓ' < 16 := by omega + have h_ne_i : a_offset.val + ℓ' ≠ i.val := by rw [h_i_val_arith]; omega + have h_ne_i1 : a_offset.val + ℓ' ≠ i1.val := by rw [h_i1_val_arith]; omega + have h_unchanged : + step_out.1.val[a_offset.val + ℓ']! = acc.1.coefficients.val[a_offset.val + ℓ']! := + h_step_other (a_offset.val + ℓ') h_idx_lt_16 h_ne_i h_ne_i1 + show step_out.1.val[a_offset.val + ℓ']! = re.coefficients.val[a_offset.val + ℓ']! + rw [h_unchanged] + exact h_a_undone ℓ' h_ge' hℓ'_lt + -- D. Untouched b-zone for ℓ' ≥ s.val. + · intro ℓ' hℓ'_ge hℓ'_lt + rw [hs_val] at hℓ'_ge + have hℓ'_gt_j : j.val < ℓ' := by omega + have h_ge' : j.val ≤ ℓ' := Nat.le_of_lt hℓ'_gt_j + have h_idx_lt_16 : b_offset.val + ℓ' < 16 := by omega + have h_ne_i : b_offset.val + ℓ' ≠ i.val := by rw [h_i_val_arith]; omega + have h_ne_i1 : b_offset.val + ℓ' ≠ i1.val := by rw [h_i1_val_arith]; omega + have h_unchanged : + step_out.1.val[b_offset.val + ℓ']! = acc.1.coefficients.val[b_offset.val + ℓ']! := + h_step_other (b_offset.val + ℓ') h_idx_lt_16 h_ne_i h_ne_i1 + show step_out.1.val[b_offset.val + ℓ']! = re.coefficients.val[b_offset.val + ℓ']! + rw [h_unchanged] + exact h_b_undone ℓ' h_ge' hℓ'_lt + -- E. Other lanes unchanged from re. + · intro k h_k_lt h_k_other + have h_ne_i : k ≠ i.val := by + rw [h_i_val_arith] + rcases h_k_other with h1 | ⟨h2a, h2b⟩ | h3 + · omega + · omega + · omega + have h_ne_i1 : k ≠ i1.val := by + rw [h_i1_val_arith] + rcases h_k_other with h1 | ⟨h2a, h2b⟩ | h3 + · omega + · omega + · omega + have h_unchanged : + step_out.1.val[k]! = acc.1.coefficients.val[k]! := + h_step_other k h_k_lt h_ne_i h_ne_i1 + show step_out.1.val[k]! = re.coefficients.val[k]! + rw [h_unchanged] + exact h_other k h_k_lt h_k_other + · -- None branch (j ≥ step_vec). + have hj_ge : j.val ≥ step_vec.val := Nat.not_lt.mp h_lt + have hj_eq : j.val = step_vec.val := by omega + have h_iter_none := iter_next_none_eq_gen j step_vec hj_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + zeta_i a_offset b_offset { start := j, «end» := step_vec } acc.1 acc.2 + = .ok (done (acc.1, acc.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := j, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := j, «end» := step_vec } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_l3 h_body + show Layer4PlusInner.step_post re a_offset b_offset step_vec bnd j (.done acc) + unfold Layer4PlusInner.step_post + show (Layer4PlusInner.inv re a_offset b_offset step_vec bnd step_vec acc).holds + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_, ?_, ?_⟩ + · intro ℓ' hℓ' ℓ hℓ; rw [← hj_eq] at hℓ'; exact h_a_done ℓ' hℓ' ℓ hℓ + · intro ℓ' hℓ' ℓ hℓ; rw [← hj_eq] at hℓ'; exact h_b_done ℓ' hℓ' ℓ hℓ + · intro ℓ' hℓ'_ge hℓ'_lt + rw [← hj_eq] at hℓ'_ge; exact h_a_undone ℓ' hℓ'_ge hℓ'_lt + · intro ℓ' hℓ'_ge hℓ'_lt + rw [← hj_eq] at hℓ'_ge; exact h_b_undone ℓ' hℓ'_ge hℓ'_lt + · exact h_other + +set_option maxHeartbeats 16000000 in +/-- Inner-loop Triple. Closes by `loop_range_spec_usize` + the step lemma. -/ +private theorem ntt_at_layer_4_plus_inner_loop_lemma + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (acc_in : Layer4PlusInner.Acc) + (zeta_i a_offset b_offset step_vec : Std.Usize) (bnd : Nat) (h_bnd : bnd ≤ 8 * 3328) + (h_zeta_lt : zeta_i.val < 128) + (h_step_vec_pos : 0 < step_vec.val) + (h_step_vec_le_16 : step_vec.val ≤ 16) + (h_a_disj : a_offset.val + step_vec.val ≤ b_offset.val) + (h_b_le_16 : b_offset.val + step_vec.val ≤ 16) + (h_pre_a : ∀ ℓ' : Nat, ℓ' < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((re.coefficients.val[a_offset.val + ℓ']!).elements.val[ℓ]!).val.natAbs ≤ bnd) + (h_acc_in_eq : acc_in.1 = re) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0 + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := 0#usize, «end» := step_vec } zeta_i acc_in.1 acc_in.2 a_offset b_offset + ⦃ ⇓ p => ⌜ -- Both a-zone and b-zone fully processed. + (∀ ℓ' : Nat, ℓ' < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((p.1.coefficients.val[a_offset.val + ℓ']!).elements.val[ℓ]!).val.natAbs + ≤ bnd + 3328) + ∧ (∀ ℓ' : Nat, ℓ' < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((p.1.coefficients.val[b_offset.val + ℓ']!).elements.val[ℓ]!).val.natAbs + ≤ bnd + 3328) + -- Other lanes unchanged from re. + ∧ (∀ k : Nat, k < 16 → + (k < a_offset.val ∨ a_offset.val + step_vec.val ≤ k ∧ k < b_offset.val + ∨ b_offset.val + step_vec.val ≤ k) → + p.1.coefficients.val[k]! = re.coefficients.val[k]!) ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0 + apply Std.Do.Triple.of_entails_right _ + (loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + zeta_i a_offset b_offset iter1 acc1.1 acc1.2) + (β := Layer4PlusInner.Acc) + acc_in + 0#usize step_vec + (Layer4PlusInner.inv re a_offset b_offset step_vec bnd) + (by scalar_tac : (0#usize : Std.Usize).val ≤ step_vec.val) + (pure_prop_holds_l3 + ⟨fun ℓ' hℓ' _ _ => absurd hℓ' (Nat.not_lt_zero _), + fun ℓ' hℓ' _ _ => absurd hℓ' (Nat.not_lt_zero _), + fun ℓ' _ _ => by rw [h_acc_in_eq], + fun ℓ' _ _ => by rw [h_acc_in_eq], + fun k _ _ => by rw [h_acc_in_eq]⟩) + ?_) + · -- Post entailment. + rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_a_done, h_b_done, _h_a_undone, _h_b_undone, h_other⟩ := of_pure_prop_holds_l3 h + exact ⟨h_a_done, h_b_done, h_other⟩ + · -- Step lemma. + intro acc k h_ge h_le hinv + have h_step := ntt_at_layer_4_plus_inner_step_lemma + re zeta_i a_offset b_offset step_vec bnd h_bnd h_zeta_lt + ⟨h_a_disj, h_b_le_16⟩ h_pre_a acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer4PlusInner.step_post re a_offset b_offset step_vec bnd k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer4PlusInner.step_post] using hP + · have hP : Layer4PlusInner.step_post re a_offset b_offset step_vec bnd k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer4PlusInner.step_post] using hP + +/-! ### outer loop helper + +The outer loop `ntt_at_layer_4_plus_loop0` iterates `round ∈ [0, outer_count)`, +each iter increments `zeta_i` by 1 and calls the inner loop on the disjoint +lane-pair window `[a_offset, a_offset + step_vec) ∪ [b_offset, b_offset + step_vec)` +where `a_offset = round * 2 * step_vec`, `b_offset = a_offset + step_vec`. +The outer invariant after `round` iters: lanes `[0, 2*round*step_vec)` are +processed at `(B+1)*3328`; lanes `[2*round*step_vec, 16)` unchanged. + +We require `2 * outer_count * step_vec = 16` (the L3.4 caller invariant for +layer ∈ {4, 5, 6}). +-/ + +/-- Local helper: `x * y` reduces to `.ok z` with `z.val = x.val * y.val` under + no-overflow on `Usize`. Mirrors `usize_add_ok_eq` / `usize_div_ok_eq`. -/ +private theorem usize_mul_ok_eq (x y : Std.Usize) + (h_max : x.val * y.val ≤ Std.Usize.max) : + ∃ z : Std.Usize, (x * y : Result Std.Usize) = .ok z ∧ z.val = x.val * y.val := by + have hT := Std.Usize.mul_spec h_max + obtain ⟨z, h_eq, h_v⟩ := Std.WP.spec_imp_exists hT + refine ⟨z, h_eq, ?_⟩ + show z.val = x.val * y.val + exact h_v + +namespace Layer4PlusOuter + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Result ControlFlow + +/-- Outer-loop accumulator: `(zeta_i, PolynomialRingElement, scratch)`. -/ +abbrev Acc := Std.Usize × + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector × + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- Outer-loop invariant after `round` iters: lanes `[0, 2*round*step_vec)` + are processed; lanes `[2*round*step_vec, 16)` match `re`. -/ +def inv + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta_i_init step_vec : Std.Usize) (bnd : Nat) : + Std.Usize → Acc → Result Prop := + fun round acc => pure ( + acc.1.val = zeta_i_init.val + round.val + ∧ (∀ k : Nat, k < 2 * round.val * step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.1.coefficients.val[k]!).elements.val[ℓ]!).val.natAbs ≤ bnd + 3328) + ∧ (∀ k : Nat, 2 * round.val * step_vec.val ≤ k → k < 16 → + acc.2.1.coefficients.val[k]! = re.coefficients.val[k]!)) + +def step_post + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta_i_init step_vec outer_count : Std.Usize) (bnd : Nat) (round : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + round.val < outer_count.val ∧ iter'.«end» = outer_count + ∧ iter'.start.val = round.val + 1 + ∧ (inv re zeta_i_init step_vec bnd iter'.start acc').holds + | .done y => (inv re zeta_i_init step_vec bnd outer_count y).holds + +end Layer4PlusOuter + +set_option maxHeartbeats 16000000 in +/-- Outer-loop step lemma. Each iter calls the inner loop on the window + `[2*round*step_vec, (2*round+2)*step_vec)`, transforming all 2*step_vec + lanes in that window from `bnd` to `bnd + 3328`. -/ +private theorem ntt_at_layer_4_plus_outer_step_lemma + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta_i_init step_vec outer_count : Std.Usize) (bnd : Nat) (h_bnd : bnd ≤ 8 * 3328) + (h_step_vec_pos : 0 < step_vec.val) + (h_step_vec_le_16 : step_vec.val ≤ 16) + (h_outer_count_pos : 0 < outer_count.val) + (h_two_oc_sv_eq : 2 * outer_count.val * step_vec.val = 16) + (h_zeta_init_lt : zeta_i_init.val + outer_count.val < 128) + (h_pre : ∀ i : Nat, i < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((re.coefficients.val[i]!).elements.val[ℓ]!).val.natAbs ≤ bnd) + (acc : Layer4PlusOuter.Acc) + (round : Std.Usize) (h_le : round.val ≤ outer_count.val) + (hinv : (Layer4PlusOuter.inv re zeta_i_init step_vec bnd round acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + step_vec { start := round, «end» := outer_count } acc.1 acc.2.1 acc.2.2 + ⦃ ⇓ r => ⌜ Layer4PlusOuter.step_post re zeta_i_init step_vec outer_count bnd round r ⌝ ⦄ := by + obtain ⟨h_zeta_acc, h_done, h_undone⟩ := of_pure_prop_holds_l3 hinv + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0.body + by_cases h_lt : round.val < outer_count.val + · -- Some round branch. + obtain ⟨s, hs_val, h_iter_some⟩ := iter_next_some_eq_gen round outer_count h_lt + -- 1) zeta_i1 = zeta_i + 1. + have h_um : (1#usize : Std.Usize).val = 1 := rfl + have h_um2 : (2#usize : Std.Usize).val = 2 := rfl + have h_zeta_lt : acc.1.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um] + have : acc.1.val + 1 ≤ 128 := by rw [h_zeta_acc]; omega + have : (128 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨zi1, h_zi1_eq, h_zi1_val⟩ := usize_add_ok_eq acc.1 1#usize h_zeta_lt + have h_zi1_val_arith : zi1.val = acc.1.val + 1 := by rw [h_zi1_val, h_um] + have h_zi1_lt : zi1.val < 128 := by + rw [h_zi1_val_arith, h_zeta_acc]; omega + -- 2) i = round * 2. Bound: round.val < outer_count.val; round * 2 ≤ 16. + have h_round_2_max : round.val * (2#usize : Std.Usize).val ≤ Std.Usize.max := by + rw [h_um2] + have : round.val * 2 ≤ outer_count.val * 2 := Nat.mul_le_mul_right 2 (Nat.le_of_lt h_lt) + have : outer_count.val * 2 ≤ 16 := by + have hh : 2 * outer_count.val * step_vec.val = 16 := h_two_oc_sv_eq + -- 2 * o ≤ 2 * o * s = 16 (using s ≥ 1 via h_step_vec_pos) + have : 2 * outer_count.val ≤ 2 * outer_count.val * step_vec.val := + Nat.le_mul_of_pos_right _ h_step_vec_pos + omega + have : (16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨ri2, h_ri2_eq, h_ri2_val⟩ := usize_mul_ok_eq round 2#usize h_round_2_max + have h_ri2_val_arith : ri2.val = round.val * 2 := by rw [h_ri2_val, h_um2] + -- 3) a_offset = ri2 * step_vec. Bound: ri2.val * step_vec.val ≤ 16. + have h_ri2_lt_oc2 : ri2.val ≤ outer_count.val * 2 := by + rw [h_ri2_val_arith]; exact Nat.mul_le_mul_right 2 (Nat.le_of_lt h_lt) + have h_oc2_sv : outer_count.val * 2 * step_vec.val = 16 := by + have := h_two_oc_sv_eq; grind + have h_ri2_sv_le_16 : ri2.val * step_vec.val ≤ 16 := by + calc ri2.val * step_vec.val ≤ (outer_count.val * 2) * step_vec.val := + Nat.mul_le_mul_right _ h_ri2_lt_oc2 + _ = 16 := h_oc2_sv + have h_ao_max : ri2.val * step_vec.val ≤ Std.Usize.max := by + have : (16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨a_off, h_ao_eq, h_ao_val⟩ := usize_mul_ok_eq ri2 step_vec h_ao_max + -- 4) b_offset = a_off + step_vec. Bound: a_off + step_vec ≤ 16. + have h_ao_val_arith : a_off.val = ri2.val * step_vec.val := h_ao_val + have h_ao_eq_2rsv : a_off.val = 2 * round.val * step_vec.val := by + rw [h_ao_val_arith, h_ri2_val_arith]; ring + have h_bo_max : a_off.val + step_vec.val ≤ Std.Usize.max := by + have : a_off.val + step_vec.val ≤ 16 := by + rw [h_ao_eq_2rsv] + have hh : 2 * (round.val + 1) * step_vec.val ≤ 16 := by + have := h_two_oc_sv_eq + calc 2 * (round.val + 1) * step_vec.val + ≤ 2 * outer_count.val * step_vec.val := by + apply Nat.mul_le_mul_right; omega + _ = 16 := h_two_oc_sv_eq + grind + have : (16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨b_off, h_bo_eq, h_bo_val⟩ := usize_add_ok_eq a_off step_vec h_bo_max + have h_bo_val_arith : b_off.val = a_off.val + step_vec.val := h_bo_val + -- Disjointness: a_off + step_vec ≤ b_off (definitional). + have h_a_disj : a_off.val + step_vec.val ≤ b_off.val := by + rw [h_bo_val_arith] + -- b_off + step_vec ≤ 16 (definitional via 2*(round+1)*step_vec ≤ 16). + have h_b_le_16 : b_off.val + step_vec.val ≤ 16 := by + rw [h_bo_val_arith, h_ao_eq_2rsv] + have hh : 2 * (round.val + 1) * step_vec.val ≤ 16 := by + calc 2 * (round.val + 1) * step_vec.val + ≤ 2 * outer_count.val * step_vec.val := by + apply Nat.mul_le_mul_right; omega + _ = 16 := h_two_oc_sv_eq + grind + -- 5) Apply inner loop spec, using acc.2.1 as the "re" of the inner call. + -- We need to prove that the inner-loop precondition holds on acc.2.1's + -- window `[a_off, a_off+step_vec)` (a-side bound only). + -- For `ℓ' < step_vec`, the lane `a_off + ℓ' = 2*round*step_vec + ℓ'` is + -- within `[2*round*step_vec, (2*round+1)*step_vec) ⊆ [2*round*step_vec, 16)`. + -- The outer invariant's `h_undone` gives `acc.2.1[k] = re[k]` for those. + have h_pre_a_inner : ∀ ℓ' : Nat, ℓ' < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.2.1.coefficients.val[a_off.val + ℓ']!).elements.val[ℓ]!).val.natAbs ≤ bnd := by + intro ℓ' hℓ' ℓ hℓ + -- a_off + ℓ' = 2*round*step_vec + ℓ' ≥ 2*round*step_vec (undone) + -- a_off + ℓ' ≤ 2*round*step_vec + step_vec - 1 < 16. + have h_idx_ge : 2 * round.val * step_vec.val ≤ a_off.val + ℓ' := by + rw [h_ao_eq_2rsv]; omega + have h_idx_lt : a_off.val + ℓ' < 16 := by + rw [h_ao_eq_2rsv] + have hh : 2 * (round.val + 1) * step_vec.val ≤ 16 := by + calc 2 * (round.val + 1) * step_vec.val + ≤ 2 * outer_count.val * step_vec.val := by + apply Nat.mul_le_mul_right; omega + _ = 16 := h_two_oc_sv_eq + grind + have h_eq : acc.2.1.coefficients.val[a_off.val + ℓ']! + = re.coefficients.val[a_off.val + ℓ']! := + h_undone (a_off.val + ℓ') h_idx_ge h_idx_lt + rw [h_eq] + exact h_pre (a_off.val + ℓ') h_idx_lt ℓ hℓ + -- The inner loop spec. + have h_inner_spec : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0_loop0 + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := 0#usize, «end» := step_vec } zi1 acc.2.1 acc.2.2 a_off b_off + ⦃ ⇓ p => ⌜ + (∀ ℓ' : Nat, ℓ' < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((p.1.coefficients.val[a_off.val + ℓ']!).elements.val[ℓ]!).val.natAbs + ≤ bnd + 3328) + ∧ (∀ ℓ' : Nat, ℓ' < step_vec.val → ∀ ℓ : Nat, ℓ < 16 → + ((p.1.coefficients.val[b_off.val + ℓ']!).elements.val[ℓ]!).val.natAbs + ≤ bnd + 3328) + ∧ (∀ k : Nat, k < 16 → + (k < a_off.val ∨ a_off.val + step_vec.val ≤ k ∧ k < b_off.val + ∨ b_off.val + step_vec.val ≤ k) → + p.1.coefficients.val[k]! = acc.2.1.coefficients.val[k]!) ⌝ ⦄ := + ntt_at_layer_4_plus_inner_loop_lemma + acc.2.1 acc.2 zi1 a_off b_off step_vec bnd h_bnd h_zi1_lt + h_step_vec_pos h_step_vec_le_16 h_a_disj h_b_le_16 + h_pre_a_inner rfl + obtain ⟨inner_out, h_inner_eq, h_inner_post⟩ := triple_exists_ok_l3 h_inner_spec + obtain ⟨h_inner_a_bd, h_inner_b_bd, h_inner_other⟩ := h_inner_post + -- Next-state. + set acc' : Layer4PlusOuter.Acc := (zi1, inner_out.1, inner_out.2) with hacc'_def + -- Compose the body into one .ok. + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + step_vec { start := round, «end» := outer_count } acc.1 acc.2.1 acc.2.2 + = .ok (cont (({ start := s, «end» := outer_count } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := round, «end» := outer_count } + : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := round, «end» := outer_count } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp [bind_tc_ok, h_zi1_eq, h_ri2_eq, h_ao_eq, h_bo_eq, h_inner_eq, hacc'_def] + rfl + apply triple_of_ok_l3 h_body + show Layer4PlusOuter.step_post re zeta_i_init step_vec outer_count bnd round + (.cont (({ start := s, «end» := outer_count } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) + unfold Layer4PlusOuter.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_⟩ + -- Outer invariant after this iter. + -- A. acc'.1.val = zeta_i_init + s.val = zeta_i_init + round.val + 1. + · show zi1.val = zeta_i_init.val + s.val + rw [h_zi1_val_arith, h_zeta_acc, hs_val]; ring + -- B. Lanes [0, 2*s.val*step_vec) are processed at (B+1)*3328. + · intro k hk ℓ hℓ + -- 2*s*step_vec = 2*(round+1)*step_vec = 2*round*step_vec + 2*step_vec + -- = a_off + 2*step_vec = b_off + step_vec. + have h_2s_eq : 2 * s.val * step_vec.val + = b_off.val + step_vec.val := by + rw [hs_val, h_bo_val_arith, h_ao_eq_2rsv]; ring + -- Cases: k ∈ [0, a_off) [already processed]; k ∈ [a_off, a_off+step_vec) [inner a-zone]; + -- k ∈ [a_off+step_vec, b_off) [empty — a_disj]; k ∈ [b_off, b_off+step_vec) [inner b-zone]. + by_cases h_k_lt_a : k < a_off.val + · -- Already processed before this iter. Use h_inner_other + h_done. + have h_k_other : k < a_off.val ∨ a_off.val + step_vec.val ≤ k ∧ k < b_off.val + ∨ b_off.val + step_vec.val ≤ k := Or.inl h_k_lt_a + have h_k_lt_16 : k < 16 := by + rw [h_2s_eq] at hk + have : a_off.val ≤ b_off.val + step_vec.val := by + rw [h_bo_val_arith]; omega + omega + have h_eq : inner_out.1.coefficients.val[k]! = acc.2.1.coefficients.val[k]! := + h_inner_other k h_k_lt_16 h_k_other + show (inner_out.1.coefficients.val[k]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_eq] + -- k < a_off = 2*round*step_vec ⇒ k is in done zone of outer inv. + have h_k_lt_2rsv : k < 2 * round.val * step_vec.val := by + rw [h_ao_eq_2rsv] at h_k_lt_a; exact h_k_lt_a + exact h_done k h_k_lt_2rsv ℓ hℓ + · -- k ≥ a_off. Either in a-window, gap (empty), or b-window. + have h_k_ge_a : a_off.val ≤ k := Nat.not_lt.mp h_k_lt_a + by_cases h_k_lt_aps : k < a_off.val + step_vec.val + · -- In a-window: k - a_off < step_vec; apply h_inner_a_bd. + set ℓ' := k - a_off.val with hℓ'_def + have hℓ'_lt : ℓ' < step_vec.val := by omega + have h_k_eq : a_off.val + ℓ' = k := by omega + have := h_inner_a_bd ℓ' hℓ'_lt ℓ hℓ + rw [h_k_eq] at this + exact this + · -- k ≥ a_off + step_vec = b_off. Either in b-window or beyond. + have h_k_ge_aps : a_off.val + step_vec.val ≤ k := Nat.not_lt.mp h_k_lt_aps + have h_k_ge_b : b_off.val ≤ k := by rw [h_bo_val_arith]; exact h_k_ge_aps + have h_k_lt_bps : k < b_off.val + step_vec.val := by + have : k < 2 * s.val * step_vec.val := hk + rw [h_2s_eq] at this; exact this + -- In b-window: k - b_off < step_vec; apply h_inner_b_bd. + set ℓ' := k - b_off.val with hℓ'_def + have hℓ'_lt : ℓ' < step_vec.val := by omega + have h_k_eq : b_off.val + ℓ' = k := by omega + have := h_inner_b_bd ℓ' hℓ'_lt ℓ hℓ + rw [h_k_eq] at this + exact this + -- C. Lanes [2*s.val*step_vec, 16) match re. + · intro k hk_ge hk_lt + -- 2*s*step_vec = b_off + step_vec; so k ≥ b_off + step_vec. + have h_2s_eq : 2 * s.val * step_vec.val = b_off.val + step_vec.val := by + rw [hs_val, h_bo_val_arith, h_ao_eq_2rsv]; ring + have h_k_ge_bps : b_off.val + step_vec.val ≤ k := by rw [← h_2s_eq]; exact hk_ge + have h_k_other : k < a_off.val ∨ a_off.val + step_vec.val ≤ k ∧ k < b_off.val + ∨ b_off.val + step_vec.val ≤ k := Or.inr (Or.inr h_k_ge_bps) + have h_eq1 : inner_out.1.coefficients.val[k]! = acc.2.1.coefficients.val[k]! := + h_inner_other k hk_lt h_k_other + show inner_out.1.coefficients.val[k]! = re.coefficients.val[k]! + rw [h_eq1] + -- k ≥ b_off + step_vec ≥ 2*round*step_vec (since round ≤ outer_count - 1, ...) — actually + -- this isn't trivially related to outer h_undone. Let's compute: + -- k ≥ b_off + step_vec = 2*s*step_vec ≥ 2*round*step_vec. + have h_k_ge_2rsv : 2 * round.val * step_vec.val ≤ k := by + have h_aux : 2 * round.val * step_vec.val ≤ b_off.val + step_vec.val := by + rw [h_bo_val_arith, h_ao_eq_2rsv]; omega + omega + exact h_undone k h_k_ge_2rsv hk_lt + · -- None branch (round ≥ outer_count). + have hr_ge : round.val ≥ outer_count.val := Nat.not_lt.mp h_lt + have hr_eq : round.val = outer_count.val := by omega + have h_iter_none := iter_next_none_eq_gen round outer_count hr_ge + have h_body : + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + step_vec { start := round, «end» := outer_count } acc.1 acc.2.1 acc.2.2 + = .ok (done (acc.1, acc.2.1, acc.2.2)) := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := round, «end» := outer_count } + : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := round, «end» := outer_count } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + have h_acc_eq : (acc.1, acc.2.1, acc.2.2) = acc := rfl + rw [h_acc_eq] at h_body + apply triple_of_ok_l3 h_body + show Layer4PlusOuter.step_post re zeta_i_init step_vec outer_count bnd round (.done acc) + unfold Layer4PlusOuter.step_post + show (Layer4PlusOuter.inv re zeta_i_init step_vec bnd outer_count acc).holds + apply pure_prop_holds_l3 + refine ⟨?_, ?_, ?_⟩ + · rw [h_zeta_acc, hr_eq] + · intro k hk ℓ hℓ; rw [← hr_eq] at hk; exact h_done k hk ℓ hℓ + · intro k hk_ge hk_lt; rw [← hr_eq] at hk_ge; exact h_undone k hk_ge hk_lt + +set_option maxHeartbeats 16000000 in +/-- Outer-loop closure via `loop_range_spec_usize` + the outer step lemma. -/ +private theorem ntt_at_layer_4_plus_outer_loop_lemma + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta_i_init step_vec outer_count : Std.Usize) (bnd : Nat) (h_bnd : bnd ≤ 8 * 3328) + (h_step_vec_pos : 0 < step_vec.val) + (h_step_vec_le_16 : step_vec.val ≤ 16) + (h_outer_count_pos : 0 < outer_count.val) + (h_two_oc_sv_eq : 2 * outer_count.val * step_vec.val = 16) + (h_zeta_init_lt : zeta_i_init.val + outer_count.val < 128) + (h_pre : ∀ i : Nat, i < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((re.coefficients.val[i]!).elements.val[ℓ]!).val.natAbs ≤ bnd) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0 + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := 0#usize, «end» := outer_count } zeta_i_init re scratch step_vec + ⦃ ⇓ p => ⌜ p.1.val = zeta_i_init.val + outer_count.val + ∧ ∀ i : Nat, i < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((p.2.1.coefficients.val[i]!).elements.val[ℓ]!).val.natAbs ≤ bnd + 3328 ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0 + apply Std.Do.Triple.of_entails_right _ + (loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus_loop0.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + step_vec iter1 acc1.1 acc1.2.1 acc1.2.2) + (β := Layer4PlusOuter.Acc) + (zeta_i_init, re, scratch) + 0#usize outer_count + (Layer4PlusOuter.inv re zeta_i_init step_vec bnd) + (by scalar_tac : (0#usize : Std.Usize).val ≤ outer_count.val) + (pure_prop_holds_l3 + ⟨by show zeta_i_init.val = zeta_i_init.val + (0#usize : Std.Usize).val; rfl, + fun k hk _ _ => by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + rw [h0] at hk; omega, + fun _ _ _ => rfl⟩) + ?_) + · -- Post entailment. + rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_zeta_eq, h_done, _h_undone⟩ := of_pure_prop_holds_l3 h + refine ⟨h_zeta_eq, ?_⟩ + intro i hi ℓ hℓ + have h16 : 2 * outer_count.val * step_vec.val = 16 := h_two_oc_sv_eq + apply h_done i (by rw [h16]; exact hi) ℓ hℓ + · -- Step lemma. + intro acc k h_ge h_le hinv + have h_step := ntt_at_layer_4_plus_outer_step_lemma re zeta_i_init step_vec outer_count + bnd h_bnd h_step_vec_pos h_step_vec_le_16 h_outer_count_pos h_two_oc_sv_eq h_zeta_init_lt + h_pre acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : Layer4PlusOuter.step_post re zeta_i_init step_vec outer_count bnd k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer4PlusOuter.step_post] using hP + · have hP : Layer4PlusOuter.step_post re zeta_i_init step_vec outer_count bnd k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [Layer4PlusOuter.step_post] using hP + +/-! ### top-level `ntt_at_layer_4_plus_spec` -/ + +set_option maxHeartbeats 32000000 in +@[spec] +theorem ntt_at_layer_4_plus_spec + (layer zeta_i : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (bnd : Std.Usize) + (h_layer : 4 ≤ layer.val ∧ layer.val ≤ 7) + (h_bnd : bnd.val ≤ 8 * 3328) + (h_zeta : zeta_i.val = (1 <<< (7 - layer.val)) - 1) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd.val) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + zeta_i re layer scratch bnd + ⦃ ⇓ p => ⌜ p.1.val = zeta_i.val + 128 >>> layer.val + ∧ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.2.1.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ bnd.val + 3328 ⌝ ⦄ := by + obtain ⟨h_layer_ge, h_layer_le⟩ := h_layer + unfold libcrux_iot_ml_kem.ntt.ntt_at_layer_4_plus + -- Compute step = 1 <<< layer. + have h_layer_lt_numBits : layer.val < Std.UScalarTy.Usize.numBits := by + show layer.val < System.Platform.numBits + rcases System.Platform.numBits_eq with h32 | h64 + · rw [h32]; omega + · rw [h64]; omega + have h_shl_spec := + Std.Usize.ShiftLeft_spec (1#usize : Std.Usize) layer h_layer_lt_numBits + obtain ⟨step, h_step_eq, h_step_props⟩ := Std.WP.spec_imp_exists h_shl_spec + have h_step_val : step.val = ((1#usize : Std.Usize).val <<< layer.val) % Std.Usize.size := + h_step_props.1 + have h_1u_val : (1#usize : Std.Usize).val = 1 := rfl + have h_step_val_clean : step.val = (1 <<< layer.val) % Std.Usize.size := by + rw [h_step_val, h_1u_val] + -- Usize.size ≥ 2^32 > 128 ≥ 1 <<< 7, so modulus is identity. + have h_size_ge : (1 <<< layer.val : Nat) < Std.Usize.size := by + have h_pow : (1 <<< layer.val : Nat) = 2 ^ layer.val := by + rw [Nat.shiftLeft_eq, Nat.one_mul] + rw [h_pow] + have h_le_128 : (2 : Nat) ^ layer.val ≤ 2 ^ 7 := Nat.pow_le_pow_right (by omega) h_layer_le + have h_128_lt : (128 : Nat) < Std.Usize.size := by + have h_min : Std.Usize.size ≥ 2 ^ 32 := by scalar_tac + have : (128 : Nat) < 2 ^ 32 := by decide + omega + have : (2 : Nat) ^ 7 = 128 := by decide + omega + have h_step_val_eq : step.val = 1 <<< layer.val := by + rw [h_step_val_clean]; exact Nat.mod_eq_of_lt h_size_ge + -- step_vec = step / 16. + have h_16_nz : (16#usize : Std.Usize).val ≠ 0 := by decide + obtain ⟨step_vec, h_sv_eq, h_sv_val⟩ := usize_div_ok_eq step 16#usize h_16_nz + have h_sv_val_clean : step_vec.val = (1 <<< layer.val) / 16 := by + rw [h_sv_val, h_step_val_eq] + show (1 <<< layer.val) / (16#usize : Std.Usize).val = (1 <<< layer.val) / 16 + rfl + -- outer_count = 128 >>> layer. + have h_shr_spec := + Std.Usize.ShiftRight_spec (128#usize : Std.Usize) layer h_layer_lt_numBits + obtain ⟨outer_count, h_oc_eq, h_oc_props⟩ := Std.WP.spec_imp_exists h_shr_spec + have h_oc_val : outer_count.val = (128#usize : Std.Usize).val >>> layer.val := h_oc_props.1 + have h_128u_val : (128#usize : Std.Usize).val = 128 := rfl + have h_oc_val_clean : outer_count.val = 128 >>> layer.val := by + rw [h_oc_val, h_128u_val] + -- Per-layer arithmetic. + have h_two_oc_sv_eq : 2 * outer_count.val * step_vec.val = 16 := by + rw [h_oc_val_clean, h_sv_val_clean] + interval_cases layer.val <;> decide + have h_outer_count_pos : 0 < outer_count.val := by + rw [h_oc_val_clean] + interval_cases layer.val <;> decide + have h_step_vec_pos : 0 < step_vec.val := by + rw [h_sv_val_clean] + interval_cases layer.val <;> decide + have h_step_vec_le_16 : step_vec.val ≤ 16 := by + rw [h_sv_val_clean] + interval_cases layer.val <;> decide + have h_zeta_init_lt : zeta_i.val + outer_count.val < 128 := by + rw [h_zeta, h_oc_val_clean] + interval_cases layer.val <;> decide + -- The h_pre bound at `bnd.val` is exactly the precondition we need for the + -- outer loop lemma (with `bnd := bnd.val`). + rw [h_step_eq] + simp only [bind_tc_ok] + rw [h_sv_eq] + simp only [bind_tc_ok] + rw [h_oc_eq] + simp only [bind_tc_ok] + -- Apply outer loop lemma. + have h_outer := + ntt_at_layer_4_plus_outer_loop_lemma re scratch zeta_i step_vec outer_count + bnd.val h_bnd h_step_vec_pos h_step_vec_le_16 h_outer_count_pos + h_two_oc_sv_eq h_zeta_init_lt h_pre + apply Std.Do.Triple.of_entails_right _ h_outer + rw [PostCond.entails_noThrow] + intro r h + -- h : (⌜post⌝).down — a plain Prop, not a pure_prop_holds. + obtain ⟨h_zeta_post, h_bd_post⟩ := h + refine ⟨?_, ?_⟩ + · -- h_zeta_post : r.1.val = zeta_i.val + outer_count.val. + -- outer_count.val = 128 >>> layer.val (h_oc_val_clean). + rw [h_zeta_post, h_oc_val_clean] + · intro i hi j hj + exact h_bd_post i hi j hj + +/-! ## L3.6 — `ntt_binomially_sampled_ring_element_spec` + +Composes the eight forward-NTT driver stages plus the terminal +`poly_barrett_reduce`: + + L3.5 → L3.4(layer=6) → L3.4(layer=5) → L3.4(layer=4) + → L3.3_B → L3.2_B → L3.1_B → L6.1 + +Bound cascade (per coefficient): + ≤ 3 → ≤ 4803 → ≤ 14535 → ≤ 17863 → ≤ 21191 + → ≤ 24519 → ≤ 27847 → ≤ 31175 → ≤ 3328. + +Implements the §13.4 "independent equation chains" pattern: each step's +`.ok`-equation derived independently via `triple_exists_ok_l3`, then +composed via `rw` against the unfolded impl `do`-block. -/ + +set_option maxHeartbeats 32000000 in +@[spec] +theorem ntt_binomially_sampled_ring_element_spec + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_binomially_sampled_ring_element + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + re scratch + ⦃ ⇓ p => ⌜ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.1.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3328 ⌝ ⦄ := by + -- ============================================================ + -- Step 1: L3.5 (ntt_at_layer_7). re → re1, |re1| ≤ 4803. + -- ============================================================ + obtain ⟨⟨re1, scratch1⟩, h_step1_eq, h_re1_bd⟩ := + triple_exists_ok_l3 (ntt_at_layer_7_spec re scratch h_pre) + dsimp only at h_re1_bd + -- ============================================================ + -- Step 2: L3.4(layer=6, zeta_i=1, bnd=11207). re1 → re2. + -- zeta_i out: 1 + 128 >>> 6 = 1 + 2 = 3. |re2| ≤ 14535. + -- ============================================================ + have h_re1_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re1.coefficients.val[i]!).elements.val[j]!).val.natAbs + ≤ (11207#usize : Std.Usize).val := by + intro i hi j hj + have hb := h_re1_bd i hi j hj + have : (11207#usize : Std.Usize).val = 11207 := rfl + omega + obtain ⟨⟨zeta2, re2, scratch2⟩, h_step2_eq, h_zeta2_val, h_re2_bd⟩ := + triple_exists_ok_l3 (ntt_at_layer_4_plus_spec + (layer := 6#usize) (zeta_i := 1#usize) re1 scratch1 11207#usize + (by decide : 4 ≤ (6#usize : Std.Usize).val ∧ (6#usize : Std.Usize).val ≤ 7) + (by decide : (11207#usize : Std.Usize).val ≤ 8 * 3328) + (by decide : + (1#usize : Std.Usize).val = (1 <<< (7 - (6#usize : Std.Usize).val)) - 1) + h_re1_loose) + dsimp only at h_zeta2_val h_re2_bd + have h_zeta2_eq3 : zeta2.val = 3 := by + have : (1#usize : Std.Usize).val = 1 := rfl + have h6 : (6#usize : Std.Usize).val = 6 := rfl + rw [h_zeta2_val, this, h6]; decide + have h_re2_bd' : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 14535 := by + intro i hi j hj + have hb := h_re2_bd i hi j hj + have : (11207#usize : Std.Usize).val = 11207 := rfl + omega + -- ============================================================ + -- Step 2.5: i ← 11207 + 3328 = 14535. + -- ============================================================ + have h_add1_max : + (11207#usize : Std.Usize).val + (3328#usize : Std.Usize).val ≤ Std.Usize.max := by + have : (11207#usize : Std.Usize).val = 11207 := rfl + have h2 : (3328#usize : Std.Usize).val = 3328 := rfl + rw [this, h2]; scalar_tac + obtain ⟨i14535, h_i14535_eq, h_i14535_val⟩ := + usize_add_ok_eq (11207#usize : Std.Usize) (3328#usize : Std.Usize) h_add1_max + have h_i14535_eq_val : i14535.val = 14535 := by + rw [h_i14535_val]; decide + -- ============================================================ + -- Step 3: L3.4(layer=5, zeta_i=3, bnd=14535). re2 → re3. + -- zeta_i out: 3 + 128 >>> 5 = 3 + 4 = 7. |re3| ≤ 17863. + -- ============================================================ + have h_re2_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ i14535.val := by + intro i hi j hj + have hb := h_re2_bd' i hi j hj + omega + obtain ⟨⟨zeta3, re3, scratch3⟩, h_step3_eq, h_zeta3_val, h_re3_bd⟩ := + triple_exists_ok_l3 (ntt_at_layer_4_plus_spec + (layer := 5#usize) (zeta_i := zeta2) re2 scratch2 i14535 + (by decide : 4 ≤ (5#usize : Std.Usize).val ∧ (5#usize : Std.Usize).val ≤ 7) + (by + have h5 : (5#usize : Std.Usize).val = 5 := rfl + rw [h_i14535_eq_val]; decide) + (by + have h5 : (5#usize : Std.Usize).val = 5 := rfl + rw [h_zeta2_eq3, h5]; decide) + h_re2_loose) + dsimp only at h_zeta3_val h_re3_bd + have h_zeta3_eq7 : zeta3.val = 7 := by + have h5 : (5#usize : Std.Usize).val = 5 := rfl + rw [h_zeta3_val, h_zeta2_eq3, h5]; decide + have h_re3_bd' : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re3.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 17863 := by + intro i hi j hj + have hb := h_re3_bd i hi j hj + omega + -- ============================================================ + -- Step 3.5: i1 ← 2 * 3328 = 6656, i2 ← 11207 + i1 = 17863. + -- ============================================================ + have h_mul1_max : + (2#usize : Std.Usize).val * (3328#usize : Std.Usize).val ≤ Std.Usize.max := by + have : (2#usize : Std.Usize).val = 2 := rfl + have h2 : (3328#usize : Std.Usize).val = 3328 := rfl + rw [this, h2]; scalar_tac + obtain ⟨i6656, h_i6656_eq, h_i6656_val⟩ := + usize_mul_ok_eq (2#usize : Std.Usize) (3328#usize : Std.Usize) h_mul1_max + have h_i6656_eq_val : i6656.val = 6656 := by + rw [h_i6656_val]; decide + have h_add2_max : + (11207#usize : Std.Usize).val + i6656.val ≤ Std.Usize.max := by + have : (11207#usize : Std.Usize).val = 11207 := rfl + rw [this, h_i6656_eq_val]; scalar_tac + obtain ⟨i17863, h_i17863_eq, h_i17863_val⟩ := + usize_add_ok_eq (11207#usize : Std.Usize) i6656 h_add2_max + have h_i17863_eq_val : i17863.val = 17863 := by + rw [h_i17863_val, h_i6656_eq_val]; decide + -- ============================================================ + -- Step 4: L3.4(layer=4, zeta_i=7, bnd=17863). re3 → re4. + -- zeta_i out: 7 + 128 >>> 4 = 7 + 8 = 15. |re4| ≤ 21191. + -- ============================================================ + have h_re3_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re3.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ i17863.val := by + intro i hi j hj + have hb := h_re3_bd' i hi j hj + omega + obtain ⟨⟨zeta4, re4, scratch4⟩, h_step4_eq, h_zeta4_val, h_re4_bd⟩ := + triple_exists_ok_l3 (ntt_at_layer_4_plus_spec + (layer := 4#usize) (zeta_i := zeta3) re3 scratch3 i17863 + (by decide : 4 ≤ (4#usize : Std.Usize).val ∧ (4#usize : Std.Usize).val ≤ 7) + (by rw [h_i17863_eq_val]; decide) + (by + have h4 : (4#usize : Std.Usize).val = 4 := rfl + rw [h_zeta3_eq7, h4]; decide) + h_re3_loose) + dsimp only at h_zeta4_val h_re4_bd + have h_zeta4_eq15 : zeta4.val = 15 := by + have h4 : (4#usize : Std.Usize).val = 4 := rfl + rw [h_zeta4_val, h_zeta3_eq7, h4]; decide + have h_re4_bd' : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re4.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 21191 := by + intro i hi j hj + have hb := h_re4_bd i hi j hj + have h_iv : i17863.val = 17863 := h_i17863_eq_val + omega + -- ============================================================ + -- Step 4.5: i3 ← 3 * 3328 = 9984, i4 ← 11207 + i3 = 21191. + -- ============================================================ + have h_mul2_max : + (3#usize : Std.Usize).val * (3328#usize : Std.Usize).val ≤ Std.Usize.max := by + have : (3#usize : Std.Usize).val = 3 := rfl + have h2 : (3328#usize : Std.Usize).val = 3328 := rfl + rw [this, h2]; scalar_tac + obtain ⟨i9984, h_i9984_eq, h_i9984_val⟩ := + usize_mul_ok_eq (3#usize : Std.Usize) (3328#usize : Std.Usize) h_mul2_max + have h_i9984_eq_val : i9984.val = 9984 := by + rw [h_i9984_val]; decide + have h_add3_max : + (11207#usize : Std.Usize).val + i9984.val ≤ Std.Usize.max := by + have : (11207#usize : Std.Usize).val = 11207 := rfl + rw [this, h_i9984_eq_val]; scalar_tac + obtain ⟨i21191, h_i21191_eq, h_i21191_val⟩ := + usize_add_ok_eq (11207#usize : Std.Usize) i9984 h_add3_max + have h_i21191_eq_val : i21191.val = 21191 := by + rw [h_i21191_val, h_i9984_eq_val]; decide + -- ============================================================ + -- Step 5: L3.3_B(zeta_i=15, bnd=21191). re4 → re5. + -- zeta_i out: 31. |re5| ≤ 24519. + -- ============================================================ + have h_re4_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re4.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 21191 := h_re4_bd' + obtain ⟨⟨zeta5, re5⟩, h_step5_eq, h_zeta5_val, h_re5_bd⟩ := + triple_exists_ok_l3 (ntt_at_layer_3_spec_B + (zeta_i := zeta4) re4 i21191 + (bnd := 21191) (h_bnd := by decide) + (h_zeta := h_zeta4_eq15) + h_re4_loose) + dsimp only at h_zeta5_val h_re5_bd + -- ============================================================ + -- Step 5.5: i5 ← 4 * 3328 = 13312, i6 ← 11207 + i5 = 24519. + -- ============================================================ + have h_mul3_max : + (4#usize : Std.Usize).val * (3328#usize : Std.Usize).val ≤ Std.Usize.max := by + have : (4#usize : Std.Usize).val = 4 := rfl + have h2 : (3328#usize : Std.Usize).val = 3328 := rfl + rw [this, h2]; scalar_tac + obtain ⟨i13312, h_i13312_eq, h_i13312_val⟩ := + usize_mul_ok_eq (4#usize : Std.Usize) (3328#usize : Std.Usize) h_mul3_max + have h_i13312_eq_val : i13312.val = 13312 := by + rw [h_i13312_val]; decide + have h_add4_max : + (11207#usize : Std.Usize).val + i13312.val ≤ Std.Usize.max := by + have : (11207#usize : Std.Usize).val = 11207 := rfl + rw [this, h_i13312_eq_val]; scalar_tac + obtain ⟨i24519, h_i24519_eq, h_i24519_val⟩ := + usize_add_ok_eq (11207#usize : Std.Usize) i13312 h_add4_max + have h_i24519_eq_val : i24519.val = 24519 := by + rw [h_i24519_val, h_i13312_eq_val]; decide + -- ============================================================ + -- Step 6: L3.2_B(zeta_i=31, bnd=24519). re5 → re6. + -- zeta_i out: 63. |re6| ≤ 27847. + -- ============================================================ + obtain ⟨⟨zeta6, re6⟩, h_step6_eq, h_zeta6_val, h_re6_bd⟩ := + triple_exists_ok_l3 (ntt_at_layer_2_spec_B + (zeta_i := zeta5) re5 i24519 + (bnd := 24519) (h_bnd := by decide) + (h_zeta := h_zeta5_val) + h_re5_bd) + dsimp only at h_zeta6_val h_re6_bd + -- ============================================================ + -- Step 6.5: i7 ← 5 * 3328 = 16640, i8 ← 11207 + i7 = 27847. + -- ============================================================ + have h_mul4_max : + (5#usize : Std.Usize).val * (3328#usize : Std.Usize).val ≤ Std.Usize.max := by + have : (5#usize : Std.Usize).val = 5 := rfl + have h2 : (3328#usize : Std.Usize).val = 3328 := rfl + rw [this, h2]; scalar_tac + obtain ⟨i16640, h_i16640_eq, h_i16640_val⟩ := + usize_mul_ok_eq (5#usize : Std.Usize) (3328#usize : Std.Usize) h_mul4_max + have h_i16640_eq_val : i16640.val = 16640 := by + rw [h_i16640_val]; decide + have h_add5_max : + (11207#usize : Std.Usize).val + i16640.val ≤ Std.Usize.max := by + have : (11207#usize : Std.Usize).val = 11207 := rfl + rw [this, h_i16640_eq_val]; scalar_tac + obtain ⟨i27847, h_i27847_eq, h_i27847_val⟩ := + usize_add_ok_eq (11207#usize : Std.Usize) i16640 h_add5_max + have h_i27847_eq_val : i27847.val = 27847 := by + rw [h_i27847_val, h_i16640_eq_val]; decide + -- ============================================================ + -- Step 7: L3.1_B(zeta_i=63, bnd=27847). re6 → re7. + -- zeta_i out: 127. |re7| ≤ 31175. + -- ============================================================ + obtain ⟨⟨zeta7, re7⟩, h_step7_eq, _h_zeta7_val, h_re7_bd⟩ := + triple_exists_ok_l3 (ntt_at_layer_1_spec_B + (zeta_i := zeta6) re6 i27847 + (bnd := 27847) (h_bnd := by decide) + (h_zeta := h_zeta6_val) + h_re6_bd) + dsimp only at h_re7_bd + -- ============================================================ + -- Step 8: L6.1 poly_barrett_reduce. re7 → re8, |re8| ≤ 3328. + -- ============================================================ + have h_re7_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re7.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 32767 := by + intro i hi j hj + have hb := h_re7_bd i hi j hj + omega + obtain ⟨re8, h_step8_eq, h_re8_bd⟩ := + triple_exists_ok_l3 (PolynomialRingElement_poly_barrett_reduce_spec re7 h_re7_loose) + -- ============================================================ + -- Compose: derive the full impl `do`-block equation. + -- ============================================================ + have h_body : + libcrux_iot_ml_kem.ntt.ntt_binomially_sampled_ring_element + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + re scratch = .ok (re8, scratch4) := by + unfold libcrux_iot_ml_kem.ntt.ntt_binomially_sampled_ring_element + simp [h_step1_eq, h_step2_eq, h_step3_eq, h_step4_eq, + h_step5_eq, h_step6_eq, h_step7_eq, h_step8_eq, + h_i14535_eq, h_i6656_eq, h_i17863_eq, + h_i9984_eq, h_i21191_eq, + h_i13312_eq, h_i24519_eq, + h_i16640_eq, h_i27847_eq] + apply triple_of_ok_l3 h_body + intro i hi j hj + exact h_re8_bd i hi j hj + +/-! ## L3.7 — `ntt_vector_u_spec` + +Composes the seven forward-NTT driver stages plus the terminal +`poly_barrett_reduce`, starting from the already-decompressed bound +`≤ 3328`: + + L3.4(layer=7) → L3.4(layer=6) → L3.4(layer=5) → L3.4(layer=4) + → L3.3_B → L3.2_B → L3.1_B → L6.1 + +Bound cascade (per coefficient): + ≤ 3328 → ≤ 6656 → ≤ 9984 → ≤ 13312 → ≤ 16640 + → ≤ 19968 → ≤ 23296 → ≤ 26624 → ≤ 3328. + +Mirrors `ntt_binomially_sampled_ring_element_spec` (L3.6) above, with +one extra L3.4 step (layer=7) replacing the L3.5 (`ntt_at_layer_7`) +prefix that L3.6 uses. Implements the §13.4 "independent equation +chains" pattern. -/ + +set_option maxHeartbeats 32000000 in +@[spec] +theorem ntt_vector_u_spec + (VECTOR_U_COMPRESSION_FACTOR : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.ntt.ntt_vector_u + VECTOR_U_COMPRESSION_FACTOR + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + re scratch + ⦃ ⇓ p => ⌜ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((p.1.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3328 ⌝ ⦄ := by + -- ============================================================ + -- Step 1: L3.4(layer=7, zeta_i=0, bnd=3328). re → re1. + -- zeta_i out: 0 + 128 >>> 7 = 0 + 1 = 1. |re1| ≤ 6656. + -- ============================================================ + have h_re_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs + ≤ (3328#usize : Std.Usize).val := by + intro i hi j hj + have hb := h_pre i hi j hj + have : (3328#usize : Std.Usize).val = 3328 := rfl + omega + obtain ⟨⟨zeta1, re1, scratch1⟩, h_step1_eq, h_zeta1_val, h_re1_bd⟩ := + triple_exists_ok_l3 (ntt_at_layer_4_plus_spec + (layer := 7#usize) (zeta_i := 0#usize) re scratch 3328#usize + (by decide : 4 ≤ (7#usize : Std.Usize).val ∧ (7#usize : Std.Usize).val ≤ 7) + (by decide : (3328#usize : Std.Usize).val ≤ 8 * 3328) + (by decide : + (0#usize : Std.Usize).val = (1 <<< (7 - (7#usize : Std.Usize).val)) - 1) + h_re_loose) + dsimp only at h_zeta1_val h_re1_bd + have h_zeta1_eq1 : zeta1.val = 1 := by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + have h7 : (7#usize : Std.Usize).val = 7 := rfl + rw [h_zeta1_val, h0, h7]; decide + have h_re1_bd' : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re1.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 6656 := by + intro i hi j hj + have hb := h_re1_bd i hi j hj + have : (3328#usize : Std.Usize).val = 3328 := rfl + omega + -- ============================================================ + -- Step 1.5: i ← 2 * 3328 = 6656. + -- ============================================================ + have h_mul1_max : + (2#usize : Std.Usize).val * (3328#usize : Std.Usize).val ≤ Std.Usize.max := by + have : (2#usize : Std.Usize).val = 2 := rfl + have h2 : (3328#usize : Std.Usize).val = 3328 := rfl + rw [this, h2]; scalar_tac + obtain ⟨i6656, h_i6656_eq, h_i6656_val⟩ := + usize_mul_ok_eq (2#usize : Std.Usize) (3328#usize : Std.Usize) h_mul1_max + have h_i6656_eq_val : i6656.val = 6656 := by + rw [h_i6656_val]; decide + -- ============================================================ + -- Step 2: L3.4(layer=6, zeta_i=1, bnd=6656). re1 → re2. + -- zeta_i out: 1 + 128 >>> 6 = 1 + 2 = 3. |re2| ≤ 9984. + -- ============================================================ + have h_re1_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re1.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ i6656.val := by + intro i hi j hj + have hb := h_re1_bd' i hi j hj + omega + obtain ⟨⟨zeta2, re2, scratch2⟩, h_step2_eq, h_zeta2_val, h_re2_bd⟩ := + triple_exists_ok_l3 (ntt_at_layer_4_plus_spec + (layer := 6#usize) (zeta_i := zeta1) re1 scratch1 i6656 + (by decide : 4 ≤ (6#usize : Std.Usize).val ∧ (6#usize : Std.Usize).val ≤ 7) + (by rw [h_i6656_eq_val]; decide) + (by + have h6 : (6#usize : Std.Usize).val = 6 := rfl + rw [h_zeta1_eq1, h6]; decide) + h_re1_loose) + dsimp only at h_zeta2_val h_re2_bd + have h_zeta2_eq3 : zeta2.val = 3 := by + have h6 : (6#usize : Std.Usize).val = 6 := rfl + rw [h_zeta2_val, h_zeta1_eq1, h6]; decide + have h_re2_bd' : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 9984 := by + intro i hi j hj + have hb := h_re2_bd i hi j hj + have h_iv : i6656.val = 6656 := h_i6656_eq_val + omega + -- ============================================================ + -- Step 2.5: i1 ← 3 * 3328 = 9984. + -- ============================================================ + have h_mul2_max : + (3#usize : Std.Usize).val * (3328#usize : Std.Usize).val ≤ Std.Usize.max := by + have : (3#usize : Std.Usize).val = 3 := rfl + have h2 : (3328#usize : Std.Usize).val = 3328 := rfl + rw [this, h2]; scalar_tac + obtain ⟨i9984, h_i9984_eq, h_i9984_val⟩ := + usize_mul_ok_eq (3#usize : Std.Usize) (3328#usize : Std.Usize) h_mul2_max + have h_i9984_eq_val : i9984.val = 9984 := by + rw [h_i9984_val]; decide + -- ============================================================ + -- Step 3: L3.4(layer=5, zeta_i=3, bnd=9984). re2 → re3. + -- zeta_i out: 3 + 128 >>> 5 = 3 + 4 = 7. |re3| ≤ 13312. + -- ============================================================ + have h_re2_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re2.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ i9984.val := by + intro i hi j hj + have hb := h_re2_bd' i hi j hj + omega + obtain ⟨⟨zeta3, re3, scratch3⟩, h_step3_eq, h_zeta3_val, h_re3_bd⟩ := + triple_exists_ok_l3 (ntt_at_layer_4_plus_spec + (layer := 5#usize) (zeta_i := zeta2) re2 scratch2 i9984 + (by decide : 4 ≤ (5#usize : Std.Usize).val ∧ (5#usize : Std.Usize).val ≤ 7) + (by rw [h_i9984_eq_val]; decide) + (by + have h5 : (5#usize : Std.Usize).val = 5 := rfl + rw [h_zeta2_eq3, h5]; decide) + h_re2_loose) + dsimp only at h_zeta3_val h_re3_bd + have h_zeta3_eq7 : zeta3.val = 7 := by + have h5 : (5#usize : Std.Usize).val = 5 := rfl + rw [h_zeta3_val, h_zeta2_eq3, h5]; decide + have h_re3_bd' : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re3.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 13312 := by + intro i hi j hj + have hb := h_re3_bd i hi j hj + have h_iv : i9984.val = 9984 := h_i9984_eq_val + omega + -- ============================================================ + -- Step 3.5: i2 ← 4 * 3328 = 13312. + -- ============================================================ + have h_mul3_max : + (4#usize : Std.Usize).val * (3328#usize : Std.Usize).val ≤ Std.Usize.max := by + have : (4#usize : Std.Usize).val = 4 := rfl + have h2 : (3328#usize : Std.Usize).val = 3328 := rfl + rw [this, h2]; scalar_tac + obtain ⟨i13312, h_i13312_eq, h_i13312_val⟩ := + usize_mul_ok_eq (4#usize : Std.Usize) (3328#usize : Std.Usize) h_mul3_max + have h_i13312_eq_val : i13312.val = 13312 := by + rw [h_i13312_val]; decide + -- ============================================================ + -- Step 4: L3.4(layer=4, zeta_i=7, bnd=13312). re3 → re4. + -- zeta_i out: 7 + 128 >>> 4 = 7 + 8 = 15. |re4| ≤ 16640. + -- ============================================================ + have h_re3_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re3.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ i13312.val := by + intro i hi j hj + have hb := h_re3_bd' i hi j hj + omega + obtain ⟨⟨zeta4, re4, scratch4⟩, h_step4_eq, h_zeta4_val, h_re4_bd⟩ := + triple_exists_ok_l3 (ntt_at_layer_4_plus_spec + (layer := 4#usize) (zeta_i := zeta3) re3 scratch3 i13312 + (by decide : 4 ≤ (4#usize : Std.Usize).val ∧ (4#usize : Std.Usize).val ≤ 7) + (by rw [h_i13312_eq_val]; decide) + (by + have h4 : (4#usize : Std.Usize).val = 4 := rfl + rw [h_zeta3_eq7, h4]; decide) + h_re3_loose) + dsimp only at h_zeta4_val h_re4_bd + have h_zeta4_eq15 : zeta4.val = 15 := by + have h4 : (4#usize : Std.Usize).val = 4 := rfl + rw [h_zeta4_val, h_zeta3_eq7, h4]; decide + have h_re4_bd' : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re4.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 16640 := by + intro i hi j hj + have hb := h_re4_bd i hi j hj + have h_iv : i13312.val = 13312 := h_i13312_eq_val + omega + -- ============================================================ + -- Step 4.5: i3 ← 5 * 3328 = 16640. + -- ============================================================ + have h_mul4_max : + (5#usize : Std.Usize).val * (3328#usize : Std.Usize).val ≤ Std.Usize.max := by + have : (5#usize : Std.Usize).val = 5 := rfl + have h2 : (3328#usize : Std.Usize).val = 3328 := rfl + rw [this, h2]; scalar_tac + obtain ⟨i16640, h_i16640_eq, h_i16640_val⟩ := + usize_mul_ok_eq (5#usize : Std.Usize) (3328#usize : Std.Usize) h_mul4_max + have h_i16640_eq_val : i16640.val = 16640 := by + rw [h_i16640_val]; decide + -- ============================================================ + -- Step 5: L3.3_B(zeta_i=15, bnd=16640). re4 → re5. + -- zeta_i out: 31. |re5| ≤ 19968. + -- ============================================================ + obtain ⟨⟨zeta5, re5⟩, h_step5_eq, h_zeta5_val, h_re5_bd⟩ := + triple_exists_ok_l3 (ntt_at_layer_3_spec_B + (zeta_i := zeta4) re4 i16640 + (bnd := 16640) (h_bnd := by decide) + (h_zeta := h_zeta4_eq15) + h_re4_bd') + dsimp only at h_zeta5_val h_re5_bd + -- ============================================================ + -- Step 5.5: i4 ← 6 * 3328 = 19968. + -- ============================================================ + have h_mul5_max : + (6#usize : Std.Usize).val * (3328#usize : Std.Usize).val ≤ Std.Usize.max := by + have : (6#usize : Std.Usize).val = 6 := rfl + have h2 : (3328#usize : Std.Usize).val = 3328 := rfl + rw [this, h2]; scalar_tac + obtain ⟨i19968, h_i19968_eq, h_i19968_val⟩ := + usize_mul_ok_eq (6#usize : Std.Usize) (3328#usize : Std.Usize) h_mul5_max + have h_i19968_eq_val : i19968.val = 19968 := by + rw [h_i19968_val]; decide + -- ============================================================ + -- Step 6: L3.2_B(zeta_i=31, bnd=19968). re5 → re6. + -- zeta_i out: 63. |re6| ≤ 23296. + -- ============================================================ + obtain ⟨⟨zeta6, re6⟩, h_step6_eq, h_zeta6_val, h_re6_bd⟩ := + triple_exists_ok_l3 (ntt_at_layer_2_spec_B + (zeta_i := zeta5) re5 i19968 + (bnd := 19968) (h_bnd := by decide) + (h_zeta := h_zeta5_val) + h_re5_bd) + dsimp only at h_zeta6_val h_re6_bd + -- ============================================================ + -- Step 6.5: i5 ← 7 * 3328 = 23296. + -- ============================================================ + have h_mul6_max : + (7#usize : Std.Usize).val * (3328#usize : Std.Usize).val ≤ Std.Usize.max := by + have : (7#usize : Std.Usize).val = 7 := rfl + have h2 : (3328#usize : Std.Usize).val = 3328 := rfl + rw [this, h2]; scalar_tac + obtain ⟨i23296, h_i23296_eq, h_i23296_val⟩ := + usize_mul_ok_eq (7#usize : Std.Usize) (3328#usize : Std.Usize) h_mul6_max + have h_i23296_eq_val : i23296.val = 23296 := by + rw [h_i23296_val]; decide + -- ============================================================ + -- Step 7: L3.1_B(zeta_i=63, bnd=23296). re6 → re7. + -- zeta_i out: 127. |re7| ≤ 26624. + -- ============================================================ + obtain ⟨⟨zeta7, re7⟩, h_step7_eq, _h_zeta7_val, h_re7_bd⟩ := + triple_exists_ok_l3 (ntt_at_layer_1_spec_B + (zeta_i := zeta6) re6 i23296 + (bnd := 23296) (h_bnd := by decide) + (h_zeta := h_zeta6_val) + h_re6_bd) + dsimp only at h_re7_bd + -- ============================================================ + -- Step 8: L6.1 poly_barrett_reduce. re7 → re8, |re8| ≤ 3328. + -- ============================================================ + have h_re7_loose : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re7.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 32767 := by + intro i hi j hj + have hb := h_re7_bd i hi j hj + omega + obtain ⟨re8, h_step8_eq, h_re8_bd⟩ := + triple_exists_ok_l3 (PolynomialRingElement_poly_barrett_reduce_spec re7 h_re7_loose) + -- ============================================================ + -- Compose: derive the full impl `do`-block equation. + -- ============================================================ + have h_body : + libcrux_iot_ml_kem.ntt.ntt_vector_u + VECTOR_U_COMPRESSION_FACTOR + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + re scratch = .ok (re8, scratch4) := by + unfold libcrux_iot_ml_kem.ntt.ntt_vector_u + simp [h_step1_eq, h_step2_eq, h_step3_eq, h_step4_eq, + h_step5_eq, h_step6_eq, h_step7_eq, h_step8_eq, + h_i6656_eq, h_i9984_eq, h_i13312_eq, + h_i16640_eq, h_i19968_eq, h_i23296_eq] + apply triple_of_ok_l3 h_body + intro i hi j hj + exact h_re8_bd i hi j hj + +end libcrux_iot_ml_kem.Polynomial.NttDrivers \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/NttMultiply.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/NttMultiply.lean new file mode 100644 index 00000000..cb25bbfa --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/NttMultiply.lean @@ -0,0 +1,11529 @@ +/- + # `Polynomial/NttMultiply.lean` — extracted from `FCTargets.lean` §ntt_multiply. +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc + +set_option mvcgen.warning false +set_option linter.unusedVariables false +set_option linter.unusedSectionVars false + +namespace libcrux_iot_ml_kem.Polynomial.NttMultiply +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec + +/-! ## §L2.8 / §L6.3 — NTT-multiply scaffolding. + + Statement skeletons for the NTT-domain multiplication chain that + the L7 matrix-level targets depend on. + + Naming convention (distinguishes vector-level from polynomial-level + since both impl namespaces define `accumulating_ntt_multiply`): + L2.8 base : `accumulating_ntt_multiply_fc` (vector chunk, I32 slice) + L6.3 base : `accumulating_ntt_multiply_poly_fc` (polynomial, I32[256]) + + Helpers introduced here (also sorry-bodied, filled by sub-dispatches): + `ntt_multiply_base_case_post` : Prop predicate captured by L2.8. + Body (the per-pair degree-2 polynomial multiply mod (X²−ζ²) + equation) is filled by L2.8b (`ntt_multiply_base_case_alg`). + `Spec.multiply_ntts_pure` : pure projection of hacspec + `ntt.multiply_ntts`. Body is filled by an M.1 pre-stage + commit before L6.3b dispatches its FC equation. + + Cache variants (`_fill_cache`, `_use_cache`) are deferred to L2.8d + / L6.3c sibling-adaptation dispatches and are NOT locked here. -/ + +/-- Pure projection of `hacspec_ml_kem.ntt.multiply_ntts` (the N=256 + polynomial NTT-domain multiply spec). The `.ok` value of the + hacspec `Result` is the spec polynomial; on `.fail` (unreachable + for canonical inputs) we default to the zero polynomial. + + Used by L6.3 locked POST as the spec-side RHS, anchoring the + impl's I32 accumulator (after L1.10's Mont-reduce) to the hacspec + `multiply_ntts` projection. Composes with the L6.3a per-chunk + decomposition + L2.8's `Spec.chunk_reducing_from_i32_array_pure` + chain. -/ +noncomputable def Spec.multiply_ntts_pure + (p1 p2 : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + match hacspec_ml_kem.ntt.multiply_ntts p1 p2 with + | .ok r => r + | _ => default + +/-- Pure no-accumulate base-case NTT multiply (the "product" part). + + Given Mont-domain lifts of `lhs`, `rhs` (16 lanes each) and 4 + Mont-domain zetas, computes the 16-lane product of the per-pair + degree-2 polynomial multiplies mod (X²−ζ²). Each pair `j ∈ 0..7` + consumes effective zeta `[zeta0, -zeta0, zeta1, -zeta1, zeta2, + -zeta2, zeta3, -zeta3][j]` and produces + `product[2j] = a[2j]·b[2j] + a[2j+1]·b[2j+1]·ζ_j` + `product[2j+1] = a[2j]·b[2j+1] + a[2j+1]·b[2j]`. + + All arithmetic is in `FieldElement` (ZMod 3329). The accumulating + variant `ntt_multiply_base_case_alg` is the pointwise sum of this + product with an initial accumulator (`Spec.chunk_add_pure acc + product`). Separating the two simplifies the per-pair commute + (A.16/A.17/A.18 fire directly on the product) and makes the L7 + bridge to hacspec `multiply_ntts` (non-accumulating) trivial when + the initial accumulator is zero. -/ +noncomputable def Spec.ntt_multiply_pure_no_acc + (lhs_m rhs_m : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (zeta0_m zeta1_m zeta2_m zeta3_m : hacspec_ml_kem.parameters.FieldElement) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + let neg := libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure + let add := libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + let mul := libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + let zetas : List hacspec_ml_kem.parameters.FieldElement := + [zeta0_m, neg zeta0_m, zeta1_m, neg zeta1_m, + zeta2_m, neg zeta2_m, zeta3_m, neg zeta3_m] + Std.Array.make 16#usize + ((List.range 16).map (fun k => + let pair_idx := k / 2 + let zeta := zetas[pair_idx]! + let a0 := lhs_m.val[2 * pair_idx]! + let a1 := lhs_m.val[2 * pair_idx + 1]! + let b0 := rhs_m.val[2 * pair_idx]! + let b1 := rhs_m.val[2 * pair_idx + 1]! + if k % 2 = 0 then add (mul a0 b0) (mul (mul a1 b1) zeta) + else add (mul a0 b1) (mul a1 b0))) + (by simp) + +/-! ### §L6.3b — `Spec.multiply_ntts_pure` ↔ chunked `Spec.ntt_multiply_pure_no_acc` + bridge. + + Required by the L7 matrix-level FC theorems (compute_As_plus_e_fc et al.) + to connect the impl-side per-chunk Mont accumulator (which the L6.3 + family produces in `Spec.ntt_multiply_pure_no_acc` form) to the hacspec + `multiply_ntts`-based matrix product spec. The bridge is a SpecPure-side + algebraic identity — no impl side, no Triple, no Mont-domain crossing + beyond what `Spec.zeta_at` already absorbs. -/ + +set_option maxRecDepth 4000 in +set_option maxHeartbeats 16000000 in +/-- : `ntt.ZETAS` reduces to a concrete `.ok` value (since + `parameters.FieldElement.new` is unconditional), and for `i ∈ [64, 128)` + the `i`-th lookup of that value equals `Spec.zeta_at i`. The numeric + fact at each position is the keystone identity + `(ZETAS_TIMES_MONTGOMERY_R[i].val * 169) mod 3329 = ntt.ZETAS[i].val` + (i.e. impl-side Mont zeta times R⁻¹ = canonical zeta). -/ +theorem hacspec_ZETAS_ok_and_zeta_at : + ∃ zs : Aeneas.Std.Array hacspec_ml_kem.parameters.FieldElement 128#usize, + hacspec_ml_kem.ntt.ZETAS = .ok zs + ∧ (∀ i : Nat, 64 ≤ i → i < 128 → Spec.zeta_at i = zs.val[i]!) := by + unfold hacspec_ml_kem.ntt.ZETAS + refine ⟨_, rfl, ?_⟩ + intro i h_lo h_hi + interval_cases i <;> + · show lift_fe_mont _ = _ + unfold lift_fe_mont i16_to_spec_fe_mont feOfZMod + simp only [libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R] + rfl + +/-! ### §L6.3b — .2/3/4: per-lane reduction + `from_fn_pure_eq` lift + + chunked assembly. + + The chain below realises `Spec.multiply_ntts_pure_eq_chunked_no_acc` (the + canonical bridge between hacspec `ntt.multiply_ntts` and the impl-side + chunked `Spec.ntt_multiply_pure_no_acc` form, required by every L7 + matrix-level FC theorem). + + Architecture mirrors `LibcruxIotSha3/Sponge/` (the + `sponge_squeeze_byte_eq` yardstick): a per-call_mut `_eq_pure` Result + equation drives `libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq` to lift the entire 256-lane + `multiply_ntts` to a pure-list, then `Subtype.ext` + per-lane reduction + closes the chunked-decomposition equality. + + Helpers live inside `HelpersFC` namespace for hygiene; only the final + theorem is re-exported. -/ + +namespace HelpersFC + +/-- `feOfZMod` always produces a canonical FE (since `z.val < 3329`). The + `BitVec.ofNat 16` lift is in-range modulo `2^16 = 65536 > 3329`. -/ +theorem Canonical_feOfZMod (z : ZMod 3329) : + Spec.Pure.Canonical (feOfZMod z) := by + unfold Spec.Pure.Canonical feOfZMod hacspec_ml_kem.parameters.FIELD_MODULUS + have h_lt : z.val < 3329 := ZMod.val_lt z + show (BitVec.ofNat 16 z.val).toNat < 3329 + rw [BitVec.toNat_ofNat, Nat.mod_eq_of_lt] + · exact h_lt + · exact Nat.lt_of_lt_of_le h_lt (by decide) + +/-- Zeta projections (`Spec.zeta_at i = feOfZMod _`) are always canonical. -/ +theorem Canonical_zeta_at (i : Nat) : + Spec.Pure.Canonical (Spec.zeta_at i) := by + unfold Spec.zeta_at lift_fe_mont + exact Canonical_feOfZMod _ + +/-- `Slice.index_usize` reduces to `.ok (s.val[i.val]!)` for in-bounds index. -/ +theorem slice_index_usize_eq_ok' {α} [Inhabited α] + (s : Aeneas.Std.Slice α) (i : Std.Usize) (h : i.val < s.val.length) : + Aeneas.Std.Slice.index_usize s i = .ok (s.val[i.val]!) := by + unfold Aeneas.Std.Slice.index_usize + have h_eq : s[i]? = s.val[i.val]? := rfl + rw [h_eq, List.getElem?_eq_getElem h] + show Aeneas.Std.Result.ok _ = Aeneas.Std.Result.ok _ + congr + rw [List.getElem!_eq_getElem?_getD, List.getElem?_eq_getElem h]; rfl + +/-- `Aeneas.Std.Array.index_usize` reduces to `.ok (a.val[i.val]!)`. -/ +theorem array_index_usize_eq_ok' {α n} [Inhabited α] + (a : Aeneas.Std.Array α n) (i : Std.Usize) (h : i.val < a.val.length) : + Aeneas.Std.Array.index_usize a i = .ok (a.val[i.val]!) := by + unfold Aeneas.Std.Array.index_usize + have h_eq : a[i]? = a.val[i.val]? := rfl + rw [h_eq, List.getElem?_eq_getElem h] + show Aeneas.Std.Result.ok _ = Aeneas.Std.Result.ok _ + congr + rw [List.getElem!_eq_getElem?_getD, List.getElem?_eq_getElem h]; rfl + +/-- `base_case_multiply_even` reduces to `.ok (a0*b0 + (a1*b1)*ζ)` via the + three `mul_eq_ok`s and the final `add_eq_ok`. -/ +theorem base_case_multiply_even_eq + (a0 a1 b0 b1 zeta : hacspec_ml_kem.parameters.FieldElement) : + hacspec_ml_kem.ntt.base_case_multiply_even a0 a1 b0 b1 zeta = .ok + (Spec.Pure.FieldElement.add_pure (Spec.Pure.FieldElement.mul_pure a0 b0) + (Spec.Pure.FieldElement.mul_pure + (Spec.Pure.FieldElement.mul_pure a1 b1) zeta)) := by + unfold hacspec_ml_kem.ntt.base_case_multiply_even + rw [Spec.Pure.FieldElement.mul_eq_ok]; simp only [bind_tc_ok] + rw [Spec.Pure.FieldElement.mul_eq_ok]; simp only [bind_tc_ok] + rw [Spec.Pure.FieldElement.mul_eq_ok]; simp only [bind_tc_ok] + rw [Spec.Pure.FieldElement.add_eq_ok] + +/-- `base_case_multiply_odd` reduces to `.ok (a0*b1 + a1*b0)` via two + `mul_eq_ok`s and the final `add_eq_ok`. -/ +theorem base_case_multiply_odd_eq + (a0 a1 b0 b1 : hacspec_ml_kem.parameters.FieldElement) : + hacspec_ml_kem.ntt.base_case_multiply_odd a0 a1 b0 b1 = .ok + (Spec.Pure.FieldElement.add_pure (Spec.Pure.FieldElement.mul_pure a0 b1) + (Spec.Pure.FieldElement.mul_pure a1 b0)) := by + unfold hacspec_ml_kem.ntt.base_case_multiply_odd + rw [Spec.Pure.FieldElement.mul_eq_ok]; simp only [bind_tc_ok] + rw [Spec.Pure.FieldElement.mul_eq_ok]; simp only [bind_tc_ok] + rw [Spec.Pure.FieldElement.add_eq_ok] + +/-- Pure lane value of `multiply_ntts` at index `i ∈ [0, 256)`. + + Mirrors the impl `ntt_multiply_n_at` body: looks up zeta from the slice at + `i/4` (negated when `i % 4 ≥ 2`), then dispatches to + `base_case_multiply_{even,odd}` per `i % 2`. The pure form replaces the + `Result`-monad ops with their `_pure` projections (`add_pure`, `mul_pure`, + `neg_pure`). The zeta is taken from `Spec.zeta_at (64 + i/4)` to match + the impl's `zetas[64..128]` slice access. -/ +noncomputable def multiply_ntts_lane_pure + (p1 p2 : Aeneas.Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (i : Nat) : hacspec_ml_kem.parameters.FieldElement := + let group := i / 4 + let i1 := i % 4 + let zeta_base := Spec.zeta_at (64 + group) + let zeta := if i1 < 2 then zeta_base + else Spec.Pure.FieldElement.neg_pure zeta_base + if i % 2 = 0 then + let a0 := p1.val[i]! + let a1 := p1.val[i+1]! + let b0 := p2.val[i]! + let b1 := p2.val[i+1]! + Spec.Pure.FieldElement.add_pure (Spec.Pure.FieldElement.mul_pure a0 b0) + (Spec.Pure.FieldElement.mul_pure + (Spec.Pure.FieldElement.mul_pure a1 b1) zeta) + else + let a0 := p1.val[i-1]! + let a1 := p1.val[i]! + let b0 := p2.val[i-1]! + let b1 := p2.val[i]! + Spec.Pure.FieldElement.add_pure (Spec.Pure.FieldElement.mul_pure a0 b1) + (Spec.Pure.FieldElement.mul_pure a1 b0) + +set_option maxHeartbeats 16000000 in +/-- **Per-lane reduction of `ntt.ntt_multiply_n_at`.** + + For any slice `s` of length 64 satisfying `s.val[k]! = Spec.zeta_at + (64 + k)` for `k < 64`, the hacspec body `ntt.ntt_multiply_n_at p1 p2 s + i` succeeds with `multiply_ntts_lane_pure p1 p2 i.val`. Drives the + `from_fn_pure_eq` lift in .3. -/ +theorem ntt_multiply_n_at_eq_pure + (p1 p2 : Aeneas.Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (s : Aeneas.Std.Slice hacspec_ml_kem.parameters.FieldElement) + (h_slen : s.val.length = 64) + (h_zeta_eq : ∀ k : Nat, k < 64 → s.val[k]! = Spec.zeta_at (64 + k)) + (i : Std.Usize) (hi : i.val < 256) : + hacspec_ml_kem.ntt.ntt_multiply_n_at p1 p2 s i + = .ok (multiply_ntts_lane_pure p1 p2 i.val) := by + unfold hacspec_ml_kem.ntt.ntt_multiply_n_at + -- Step 1: group ← i / 4#usize. + obtain ⟨group, h_g_eq, h_g_v, _⟩ := + Std.UScalar.div_bv_spec i (show ((4#usize : Std.Usize)).val ≠ 0 by decide) + rw [h_g_eq]; simp only [bind_tc_ok] + have h_g_val : group.val = i.val / 4 := by rw [h_g_v]; rfl + have h_g_lt : group.val < 64 := by rw [h_g_val]; omega + have h_g_lt_slen : group.val < s.val.length := by rw [h_slen]; exact h_g_lt + -- Step 2: i1 ← i % 4#usize. + obtain ⟨i1, h_i1_eq, h_i1_v, _⟩ := + Std.WP.spec_imp_exists (Std.UScalar.rem_bv_spec i + (show ((4#usize : Std.Usize)).val ≠ 0 by decide)) + rw [h_i1_eq]; simp only [bind_tc_ok] + have h_i1_val : i1.val = i.val % 4 := by rw [h_i1_v]; rfl + -- Slice index lookup + zeta correspondence + canonicity. + have h_slice_idx_ok : + Aeneas.Std.Slice.index_usize s group = .ok (s.val[group.val]!) := + slice_index_usize_eq_ok' s group h_g_lt_slen + have h_zeta_val : s.val[group.val]! = Spec.zeta_at (64 + group.val) := + h_zeta_eq group.val h_g_lt + have h_canon_zeta : Spec.Pure.Canonical (s.val[group.val]!) := by + rw [h_zeta_val]; exact Canonical_zeta_at _ + -- Step 3: zeta. Collapse the zeta branch into a uniform tail. + set zeta_pure : hacspec_ml_kem.parameters.FieldElement := + if i.val % 4 < 2 then Spec.zeta_at (64 + i.val / 4) + else Spec.Pure.FieldElement.neg_pure (Spec.zeta_at (64 + i.val / 4)) + with h_zeta_pure_def + have h_zeta_result : + (do let z ← (if i1 < 2#usize then Aeneas.Std.Slice.index_usize s group + else do let fe ← Aeneas.Std.Slice.index_usize s group + hacspec_ml_kem.parameters.FieldElement.neg fe); + (do let i2 ← i % 2#usize; + if i2 = 0#usize + then do let fe ← Aeneas.Std.Array.index_usize p1 i; + let i3 ← i + 1#usize; + let fe1 ← Aeneas.Std.Array.index_usize p1 i3; + let fe2 ← Aeneas.Std.Array.index_usize p2 i; + let fe3 ← Aeneas.Std.Array.index_usize p2 i3; + hacspec_ml_kem.ntt.base_case_multiply_even fe fe1 fe2 fe3 z + else do let i3 ← i - 1#usize; + let fe ← Aeneas.Std.Array.index_usize p1 i3; + let fe1 ← Aeneas.Std.Array.index_usize p1 i; + let fe2 ← Aeneas.Std.Array.index_usize p2 i3; + let fe3 ← Aeneas.Std.Array.index_usize p2 i; + hacspec_ml_kem.ntt.base_case_multiply_odd fe fe1 fe2 fe3)) = + (do let i2 ← i % 2#usize; + if i2 = 0#usize + then do let fe ← Aeneas.Std.Array.index_usize p1 i; + let i3 ← i + 1#usize; + let fe1 ← Aeneas.Std.Array.index_usize p1 i3; + let fe2 ← Aeneas.Std.Array.index_usize p2 i; + let fe3 ← Aeneas.Std.Array.index_usize p2 i3; + hacspec_ml_kem.ntt.base_case_multiply_even fe fe1 fe2 fe3 zeta_pure + else do let i3 ← i - 1#usize; + let fe ← Aeneas.Std.Array.index_usize p1 i3; + let fe1 ← Aeneas.Std.Array.index_usize p1 i; + let fe2 ← Aeneas.Std.Array.index_usize p2 i3; + let fe3 ← Aeneas.Std.Array.index_usize p2 i; + hacspec_ml_kem.ntt.base_case_multiply_odd fe fe1 fe2 fe3) := by + rcases (Nat.lt_or_ge i1.val 2) with h_i1_lt | h_i1_ge + · rw [if_pos (show i1 < 2#usize from h_i1_lt)] + rw [h_slice_idx_ok]; simp only [bind_tc_ok] + rw [h_zeta_val] + simp only [h_zeta_pure_def, + show i.val % 4 < 2 from h_i1_val ▸ h_i1_lt, if_true] + rw [h_g_val] + · rw [if_neg (show ¬ i1 < 2#usize from by show ¬ i1.val < 2; omega)] + rw [h_slice_idx_ok]; simp only [bind_tc_ok] + rw [Spec.Pure.FieldElement.neg_eq_ok _ h_canon_zeta] + simp only [bind_tc_ok] + rw [h_zeta_val] + simp only [h_zeta_pure_def, + show ¬ i.val % 4 < 2 from by + have : i1.val = i.val % 4 := h_i1_val; omega, if_false] + rw [h_g_val] + rw [h_zeta_result] + -- Step 4: i2 ← i % 2#usize. + obtain ⟨i2, h_i2_eq, h_i2_v, _⟩ := + Std.WP.spec_imp_exists (Std.UScalar.rem_bv_spec i + (show ((2#usize : Std.Usize)).val ≠ 0 by decide)) + rw [h_i2_eq]; simp only [bind_tc_ok] + have h_i2_val : i2.val = i.val % 2 := by rw [h_i2_v]; rfl + -- Array bounds (p1.val.length = p2.val.length = 256). + have h_p1_len : p1.val.length = 256 := p1.property + have h_p2_len : p2.val.length = 256 := p2.property + have h_i_lt_p1 : i.val < p1.val.length := by rw [h_p1_len]; exact hi + have h_i_lt_p2 : i.val < p2.val.length := by rw [h_p2_len]; exact hi + have h_i2_lt_2 : i2.val < 2 := by + rw [h_i2_val]; exact Nat.mod_lt _ (by decide) + -- Step 5: branch on i2.val = 0 vs 1. + rcases (show i2.val = 0 ∨ i2.val = 1 from by omega) with h_i2_0 | h_i2_1 + · -- Even branch. + have h_i2_eq_0 : i2 = 0#usize := + Std.UScalar.eq_of_val_eq (by rw [h_i2_0]; rfl) + rw [if_pos h_i2_eq_0] + obtain ⟨i3, h_i3_eq, h_i3_v, _⟩ := + Std.WP.spec_imp_exists + (Std.UScalar.add_bv_spec (x := i) (y := 1#usize) (by scalar_tac)) + have h_i3_val : i3.val = i.val + 1 := by rw [h_i3_v]; rfl + have h_i3_lt_p1 : i3.val < p1.val.length := by + rw [h_p1_len, h_i3_val]; omega + have h_i3_lt_p2 : i3.val < p2.val.length := by + rw [h_p2_len, h_i3_val]; omega + rw [array_index_usize_eq_ok' p1 i h_i_lt_p1]; simp only [bind_tc_ok] + rw [h_i3_eq]; simp only [bind_tc_ok] + rw [array_index_usize_eq_ok' p1 i3 h_i3_lt_p1]; simp only [bind_tc_ok] + rw [array_index_usize_eq_ok' p2 i h_i_lt_p2]; simp only [bind_tc_ok] + rw [array_index_usize_eq_ok' p2 i3 h_i3_lt_p2]; simp only [bind_tc_ok] + rw [base_case_multiply_even_eq] + congr 1 + unfold multiply_ntts_lane_pure + have h_imod2 : i.val % 2 = 0 := by rw [← h_i2_val]; exact h_i2_0 + simp only [h_imod2, if_true, h_i3_val, h_zeta_pure_def] + · -- Odd branch. + have h_i2_ne_0 : ¬ (i2 = 0#usize) := by + intro heq + have h_zero : i2.val = 0 := by rw [heq]; rfl + omega + rw [if_neg h_i2_ne_0] + have h_i_ge_1 : 1 ≤ i.val := by + have h_imod2 : i.val % 2 = 1 := by rw [← h_i2_val]; exact h_i2_1 + omega + obtain ⟨i3, h_i3_eq, h_i3_v, _⟩ := + Std.WP.spec_imp_exists + (Std.UScalar.sub_bv_spec (x := i) (y := 1#usize) (by scalar_tac)) + have h_i3_val : i3.val = i.val - 1 := by rw [h_i3_v]; rfl + have h_i3_lt_p1 : i3.val < p1.val.length := by + rw [h_p1_len, h_i3_val]; omega + have h_i3_lt_p2 : i3.val < p2.val.length := by + rw [h_p2_len, h_i3_val]; omega + rw [h_i3_eq]; simp only [bind_tc_ok] + rw [array_index_usize_eq_ok' p1 i3 h_i3_lt_p1]; simp only [bind_tc_ok] + rw [array_index_usize_eq_ok' p1 i h_i_lt_p1]; simp only [bind_tc_ok] + rw [array_index_usize_eq_ok' p2 i3 h_i3_lt_p2]; simp only [bind_tc_ok] + rw [array_index_usize_eq_ok' p2 i h_i_lt_p2]; simp only [bind_tc_ok] + rw [base_case_multiply_odd_eq] + congr 1 + unfold multiply_ntts_lane_pure + have h_imod2 : i.val % 2 = 1 := by rw [← h_i2_val]; exact h_i2_1 + simp only [h_imod2, Nat.one_ne_zero, if_false, h_i3_val] + +/-! ### §L6.3b — .3: lift `multiply_ntts` to a pure 256-list. -/ + +/-- The 64-position slice extracted from a length-128 array has length 64. -/ +lemma slice_length_64 + (zs : Aeneas.Std.Array hacspec_ml_kem.parameters.FieldElement 128#usize) : + (List.slice 64 128 zs.val).length = 64 := by + have h : zs.val.length = 128 := zs.property + unfold List.slice + simp [List.length_take, h] + +/-- The hacspec slice-by-range extraction `zs[64..128]` reduces to the + explicit `List.slice 64 128 zs.val` slice. Drives the slice-step in + .3's reduction of `ntt.multiply_ntts`. -/ +lemma slice_zetas_succeeds + (zs : Aeneas.Std.Array hacspec_ml_kem.parameters.FieldElement 128#usize) : + core.Array.Insts.CoreOpsIndexIndex.index + (core.Slice.Insts.CoreOpsIndexIndex + (core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice + hacspec_ml_kem.parameters.FieldElement)) zs + { start := 64#usize, «end» := 128#usize } + = .ok (⟨List.slice 64 128 zs.val, by + rw [slice_length_64]; scalar_tac⟩ : + Aeneas.Std.Slice hacspec_ml_kem.parameters.FieldElement) := by + unfold core.Array.Insts.CoreOpsIndexIndex.index + core.slice.index.Slice.index + core.Slice.Insts.CoreOpsIndexIndex + core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice + show core.slice.index.SliceIndexRangeUsizeSlice.index + (core.cmRangeUsizeToAeneas _) zs.to_slice = _ + unfold core.slice.index.SliceIndexRangeUsizeSlice.index + core.cmRangeUsizeToAeneas + have h_alen : zs.val.length = 128 := zs.property + have h_cond : (64#usize : Std.Usize) ≤ (128#usize : Std.Usize) ∧ + (128#usize : Std.Usize).val ≤ zs.to_slice.val.length := by + refine ⟨by show (64 : Nat) ≤ 128; decide, by + show 128 ≤ zs.to_slice.val.length + show 128 ≤ zs.val.length; omega⟩ + rw [if_pos h_cond] + rfl + +/-- `List.slice a b l [k]! = l[a + k]!` when `a ≤ b`, `b ≤ l.length`, + and `k < b - a`. -/ +lemma slice_getElem_at {α} [Inhabited α] + (l : List α) (a b : Nat) (h_le_a : a ≤ b) (h_le_b : b ≤ l.length) + (k : Nat) (hk : k < b - a) : + (List.slice a b l)[k]! = l[a + k]! := by + unfold List.slice + have h_ak_lt : a + k < l.length := by omega + have h_drop_len : (l.drop a).length = l.length - a := by simp + have h_k_lt_drop : k < (l.drop a).length := by rw [h_drop_len]; omega + have h_take_idx : + ((l.drop a).take (b - a))[k]? = (l.drop a)[k]? := by + rw [List.getElem?_take, if_pos hk] + have h_drop_idx : (l.drop a)[k]? = l[a + k]? := by + rw [List.getElem?_drop] + rw [List.getElem!_eq_getElem?_getD, List.getElem!_eq_getElem?_getD, + h_take_idx, h_drop_idx] + +/-- `BitVec.ofNat _ k` round-trips through `Usize.val` when `k < 256`. -/ +lemma usize_ofNat_val_eq_self_of_lt_256 (k : Nat) (h : k < 256) : + (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := by + show (BitVec.ofNat System.Platform.numBits k).toNat = k + rw [BitVec.toNat_ofNat] + apply Nat.mod_eq_of_lt + have h_max : k ≤ Std.Usize.max := by scalar_tac + have h_max_def : Std.Usize.max + 1 = 2 ^ System.Platform.numBits := by scalar_tac + omega + +set_option maxHeartbeats 4000000 in +/-- **`hacspec_ml_kem.ntt.multiply_ntts p1 p2` reduces to the pure 256-lane + array.** + + Composes .1 (ZETAS = .ok zs + zeta correspondence), the + slice-extraction reduction, and .2's per-lane reduction via + `libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq`. The result is the pure FE-arithmetic list + `(List.range 256).map (multiply_ntts_lane_pure p1 p2)`. -/ +-- Public (exported for L7.4 `compute_message_acc_bridge`): per-multiply reduction +-- of the hacspec `ntt.multiply_ntts` to its pure-lane array form. Visibility-only +-- change (proof/statement unchanged). +theorem multiply_ntts_eq_pure_array + (p1 p2 : Aeneas.Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + hacspec_ml_kem.ntt.multiply_ntts p1 p2 + = .ok (⟨(List.range 256).map (multiply_ntts_lane_pure p1 p2), + by simp [List.length_map, List.length_range]⟩ : + Aeneas.Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) := by + unfold hacspec_ml_kem.ntt.multiply_ntts + -- Step 1: ntt.ZETAS = .ok zs. + obtain ⟨zs, h_zetas_eq, h_zeta_at⟩ := hacspec_ZETAS_ok_and_zeta_at + rw [h_zetas_eq]; simp only [bind_tc_ok] + -- Step 2: slice extraction. + rw [slice_zetas_succeeds]; simp only [bind_tc_ok] + -- Step 3: ntt.ntt_multiply_n p1 p2 s ⇒ parameters.createi 256 inst (p1, p2, s) + -- ⇒ core.array.from_fn 256 inst.FnMutInst (p1, p2, s). + unfold hacspec_ml_kem.ntt.ntt_multiply_n + hacspec_ml_kem.parameters.createi + -- Slice / properties. + set s : Aeneas.Std.Slice hacspec_ml_kem.parameters.FieldElement := + ⟨List.slice 64 128 zs.val, by rw [slice_length_64]; scalar_tac⟩ with h_s_def + have h_slen : s.val.length = 64 := slice_length_64 zs + have h_zeta_eq_slice : ∀ k : Nat, k < 64 → s.val[k]! = Spec.zeta_at (64 + k) := by + intro k hk + show (List.slice 64 128 zs.val)[k]! = _ + have h_zlen : zs.val.length = 128 := zs.property + rw [slice_getElem_at zs.val 64 128 (by omega) (by omega) k (by omega)] + exact (h_zeta_at (64 + k) (by omega) (by omega)).symm + -- Set f and build the per-call_mut equation. + set f : Nat → hacspec_ml_kem.parameters.FieldElement := + multiply_ntts_lane_pure p1 p2 with h_f_def + have h_call_mut_eq : ∀ k : Nat, k < (256#usize : Std.Usize).val → + ((hacspec_ml_kem.ntt.ntt_multiply_n.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement + 256#usize).FnMutInst).call_mut (p1, p2, s) ⟨BitVec.ofNat _ k⟩ + = .ok (f k, (p1, p2, s)) := by + intro k hk + have hk' : k < 256 := hk + have h_k_val : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := + usize_ofNat_val_eq_self_of_lt_256 k hk' + show (do let fe ← hacspec_ml_kem.ntt.ntt_multiply_n_at p1 p2 s + (⟨BitVec.ofNat _ k⟩ : Std.Usize); + .ok (fe, (p1, p2, s))) = _ + have h_lane := ntt_multiply_n_at_eq_pure p1 p2 s h_slen h_zeta_eq_slice + (⟨BitVec.ofNat _ k⟩ : Std.Usize) (by rw [h_k_val]; exact hk') + rw [h_lane]; simp only [bind_tc_ok] + rw [h_k_val] + rfl + -- Apply from_fn_pure_eq. + have h_from_fn := libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq 256#usize + (hacspec_ml_kem.ntt.ntt_multiply_n.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement + 256#usize).FnMutInst + (p1, p2, s) f h_call_mut_eq + exact h_from_fn + +/-! ### §L6.3b — .4: chunked assembly + final theorem. -/ + +set_option maxHeartbeats 8000000 in +/-- **Per-lane equality, `j/ℓ` form.** + + For `j < 16`, `ℓ < 16`, the flat lane value + `multiply_ntts_lane_pure p1 p2 (16 * j + ℓ)` equals the `ℓ`-th + lane of the per-chunk product + `Spec.ntt_multiply_pure_no_acc (chunk_at p1 j) (chunk_at p2 j) + ζ_{4j..4j+3}`. Closed via `interval_cases ℓ` (16 cases), each + `rfl` after unfolding `multiply_ntts_lane_pure`, + `Spec.ntt_multiply_pure_no_acc`, and `Spec.chunk_at`. -/ +theorem multiply_ntts_lane_pure_eq_chunked_aux + (p1 p2 : Aeneas.Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (j ℓ : Nat) (hj : j < 16) (hℓ : ℓ < 16) : + multiply_ntts_lane_pure p1 p2 (16 * j + ℓ) = + (Spec.ntt_multiply_pure_no_acc + (Spec.chunk_at p1 j) (Spec.chunk_at p2 j) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]! := by + unfold multiply_ntts_lane_pure + have h_div : (16 * j + ℓ) / 4 = 4 * j + ℓ / 4 := by omega + have h_mod4 : (16 * j + ℓ) % 4 = ℓ % 4 := by omega + have h_mod2 : (16 * j + ℓ) % 2 = ℓ % 2 := by omega + rw [h_div, h_mod4, h_mod2] + unfold Spec.ntt_multiply_pure_no_acc + -- Use `conv_rhs` to scope the index-reduction to the RHS so it doesn't + -- accidentally target an LHS `_[m]!` first. After `unfold`, the RHS has + -- the (List.range 16)-map structure wrapped in `Std.Array.make`/`↑`; + -- the `show` brings the outer projection inline so `rw` can match. + conv_rhs => + rw [show ∀ (l : List _) (h : l.length = (16#usize : Std.Usize).val) (k : Nat), + (↑(Std.Array.make 16#usize l h) : List _)[k]! = l[k]! from fun _ _ _ => rfl, + List.getElem!_eq_getElem?_getD, List.getElem?_map, List.getElem?_range hℓ] + unfold Spec.chunk_at + interval_cases ℓ <;> rfl + +/-- **Per-lane equality, `i` form (wrapper around `_aux`).** -/ +theorem multiply_ntts_lane_pure_eq_chunked + (p1 p2 : Aeneas.Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (i : Nat) (hi : i < 256) : + multiply_ntts_lane_pure p1 p2 i = + (Spec.ntt_multiply_pure_no_acc + (Spec.chunk_at p1 (i / 16)) (Spec.chunk_at p2 (i / 16)) + (Spec.zeta_at (64 + 4 * (i / 16))) + (Spec.zeta_at (64 + 4 * (i / 16) + 1)) + (Spec.zeta_at (64 + 4 * (i / 16) + 2)) + (Spec.zeta_at (64 + 4 * (i / 16) + 3))).val[i % 16]! := by + have h_i : i = 16 * (i / 16) + (i % 16) := by omega + conv_lhs => rw [h_i] + exact multiply_ntts_lane_pure_eq_chunked_aux p1 p2 (i / 16) (i % 16) + (by omega) (Nat.mod_lt _ (by decide)) + +end HelpersFC + +set_option maxHeartbeats 4000000 in +/-- **§L6.3b bridge: hacspec `multiply_ntts` ↔ chunked `ntt_multiply_pure_no_acc`.** + + Connects the spec-side projection of hacspec `ntt.multiply_ntts` + (canonical `Spec.multiply_ntts_pure`) to the impl-side per-chunk + Mont-domain product form `Spec.ntt_multiply_pure_no_acc` aggregated + via `Spec.flatten_chunks` over the 16 chunks. + + This is the bridge required by every L7 matrix-level FC theorem + (`compute_As_plus_e_fc`, `compute_vector_u_fc`, + `compute_ring_element_v_fc`, `compute_message_fc`): the impl + accumulator at row `(i, k)` produces `Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont row[k]) (lift_chunk_mont t[k]) Spec.zeta_at(64+4j..)` + per chunk `j`, and this theorem allows the matrix Triple to collapse + that decomposition into the hacspec `multiply_ntts` form used by + `Spec.compute_As_plus_e` and friends. + + Proof composes: + - .1 (`hacspec_ZETAS_ok_and_zeta_at`): zetas at [64..128) + correspondence. + - .3 (`HelpersFC.multiply_ntts_eq_pure_array`): lifts + `ntt.multiply_ntts p1 p2` to `.ok ⟨pure-list, _⟩`. + - .4 (`HelpersFC.multiply_ntts_lane_pure_eq_chunked`): per-lane + equality. + - Array extensionality (`Subtype.ext` + `List.map_congr_left`). -/ +theorem Spec.multiply_ntts_pure_eq_chunked_no_acc + (p1 p2 : Aeneas.Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + Spec.multiply_ntts_pure p1 p2 = + Spec.flatten_chunks + ⟨(List.range 16).map (fun j => + Spec.ntt_multiply_pure_no_acc + (Spec.chunk_at p1 j) (Spec.chunk_at p2 j) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))), + by simp⟩ := by + unfold Spec.multiply_ntts_pure + rw [HelpersFC.multiply_ntts_eq_pure_array] + -- Reduce `match .ok r with .ok r => r | _ => default` to `r` so the Array constructor + -- on the LHS aligns with the Spec.flatten_chunks Array on the RHS. + show (⟨(List.range 256).map (HelpersFC.multiply_ntts_lane_pure p1 p2), + by simp [List.length_map, List.length_range]⟩ : + Aeneas.Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) = _ + apply Subtype.ext + -- Goal: pure_list = (Spec.flatten_chunks ⟨chunks_list, _⟩).val + -- Reduce: ↑⟨L, _⟩ = L by `Subtype.coe_mk`-rfl, and `(Spec.flatten_chunks ⟨L, h⟩).val = + -- (List.range 256).map (fun j => ⟨L, h⟩.val[j/16]!.val[j%16]!) = (List.range 256).map + -- (fun j => L[j/16]!.val[j%16]!)` by unfolding Spec.flatten_chunks and Std.Array.make. + show (List.range 256).map (HelpersFC.multiply_ntts_lane_pure p1 p2) = + (List.range 256).map (fun j => + ((List.range 16).map (fun j' => + Spec.ntt_multiply_pure_no_acc + (Spec.chunk_at p1 j') (Spec.chunk_at p2 j') + (Spec.zeta_at (64 + 4 * j')) + (Spec.zeta_at (64 + 4 * j' + 1)) + (Spec.zeta_at (64 + 4 * j' + 2)) + (Spec.zeta_at (64 + 4 * j' + 3))))[j / 16]!.val[j % 16]!) + apply List.map_congr_left + intro i hi + have h_i_lt : i < 256 := List.mem_range.mp hi + have hi_div_lt : i / 16 < 16 := by omega + -- Reduce the chunks-list lookup at index `i / 16` to the explicit + -- `ntt_multiply_pure_no_acc` value, via List.getElem? expansion of the + -- inner `[i / 16]!`. We do this via a one-shot `have` lemma to scope the + -- rewrites to the inner index only (leaving the outer `[i % 16]!` intact). + have h_chunks_at : ((List.range 16).map (fun j' => + Spec.ntt_multiply_pure_no_acc + (Spec.chunk_at p1 j') (Spec.chunk_at p2 j') + (Spec.zeta_at (64 + 4 * j')) + (Spec.zeta_at (64 + 4 * j' + 1)) + (Spec.zeta_at (64 + 4 * j' + 2)) + (Spec.zeta_at (64 + 4 * j' + 3))))[i / 16]! = + Spec.ntt_multiply_pure_no_acc + (Spec.chunk_at p1 (i / 16)) (Spec.chunk_at p2 (i / 16)) + (Spec.zeta_at (64 + 4 * (i / 16))) + (Spec.zeta_at (64 + 4 * (i / 16) + 1)) + (Spec.zeta_at (64 + 4 * (i / 16) + 2)) + (Spec.zeta_at (64 + 4 * (i / 16) + 3)) := by + rw [List.getElem!_eq_getElem?_getD, List.getElem?_map, + List.getElem?_range hi_div_lt]; rfl + show HelpersFC.multiply_ntts_lane_pure p1 p2 i = + ((List.range 16).map (fun j' => + Spec.ntt_multiply_pure_no_acc + (Spec.chunk_at p1 j') (Spec.chunk_at p2 j') + (Spec.zeta_at (64 + 4 * j')) + (Spec.zeta_at (64 + 4 * j' + 1)) + (Spec.zeta_at (64 + 4 * j' + 2)) + (Spec.zeta_at (64 + 4 * j' + 3))))[i / 16]!.val[i % 16]! + rw [h_chunks_at] + exact HelpersFC.multiply_ntts_lane_pure_eq_chunked p1 p2 i h_i_lt + +/-- Accumulating base-case NTT multiply: pointwise sum of the initial + accumulator with the no-acc product. Defined as + `chunk_add_pure acc (ntt_multiply_pure_no_acc ...)`. The L2.8 POST + anchors against this; downstream provers reduce to the product + + a single trivial additive step. -/ +noncomputable def ntt_multiply_base_case_alg + (lhs_m rhs_m : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (zeta0_m zeta1_m zeta2_m zeta3_m : hacspec_ml_kem.parameters.FieldElement) + (acc_m : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Spec.chunk_add_pure acc_m + (Spec.ntt_multiply_pure_no_acc lhs_m rhs_m + zeta0_m zeta1_m zeta2_m zeta3_m) + +/-- Algebraic POST predicate for the L2.8 vector-level base-case NTT + multiply. Relates the resulting I32 accumulator slice `r` to the + inputs (`lhs`, `rhs`, 4 zetas, initial accumulator `out`) per the + per-pair degree-2 polynomial multiply equation mod (X²−ζ²). + + The impl chains 8 calls to `accumulating_ntt_multiply_binomials` + with effective zetas `[zeta0, -zeta0, zeta1, -zeta1, zeta2, -zeta2, + zeta3, -zeta3]` across pairs `(out[2k], out[2k+1])` for k = 0..7. + + Body uses `Spec.chunk_reducing_from_i32_array_pure` (per-lane + Montgomery reduction) to lift the I32 accumulator to Mont-domain + FE-array, then compares with `ntt_multiply_base_case_alg` applied + to the Mont-domain lifts of inputs. Mirrors the L1.10 + `lift_poly_mont = Spec.poly_reducing_from_i32_array_pure` idiom. -/ +noncomputable def ntt_multiply_base_case_post + (lhs rhs : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta0 zeta1 zeta2 zeta3 : Std.I16) + (out r : Aeneas.Std.Slice Std.I32) : Prop := + Spec.chunk_reducing_from_i32_array_pure r = + ntt_multiply_base_case_alg + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3) + (Spec.chunk_reducing_from_i32_array_pure out) + +/-! ### L2.8c — helper Triples and bridge lemmas. + + Per-pair binomial Triple + ZMod-side FE-equation closer. The + per-pair Triple isolates one `accumulating_ntt_multiply_binomials` + call; L2.8c chains 8 of these with alternating-sign zetas. -/ + +/-- I32 → ZMod 3329 cast bridge for sign-extending I16 to I32. + `(as_i32 x : I32).val = x.val` since I16 ⊆ I32. -/ +theorem L2_8c.as_i32_val_eq (x : Std.I16) : + libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 x + = .ok ((Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 x : Std.I32)) := by + unfold libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 + unfold libcrux_secrets.traits.Declassify.Blanket.declassify + unfold libcrux_secrets.traits.Classify.Blanket.classify + rfl + +/-- The `cast .I32` of an I16 carries the same Int value. -/ +theorem L2_8c.cast_I32_val (x : Std.I16) : + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 x : Std.I32).val = x.val := by + exact Aeneas.Std.IScalar.val_mod_pow_greater_numBits Aeneas.Std.IScalarTy.I32 x (by decide) + +/-- `classify` is the identity on its `.ok` value (mirror of + `ntt_step_fc.classify_ok_eq` for use by the binomials proof). -/ +theorem L2_8c.classify_ok_eq {T : Type} (x : T) : + libcrux_secrets.traits.Classify.Blanket.classify x = .ok x := rfl + +/-- Reduction of `core.num.I32.wrapping_mul` to its `.ok` + representation in terms of the underlying `Std.I32.wrapping_mul`. -/ +theorem L2_8c.cm_wrapping_mul_i32_ok_eq (x y : Std.I32) : + CoreModels.core.num.I32.wrapping_mul x y = .ok (Aeneas.Std.I32.wrapping_mul x y) := by + unfold CoreModels.core.num.I32.wrapping_mul + unfold rust_primitives.arithmetic.wrapping_mul_i32 + rfl + +/-- Reduction of `core.num.I32.wrapping_add` to the underlying + Aeneas `Std.I32.wrapping_add`. -/ +theorem L2_8c.cm_wrapping_add_i32_ok_eq (x y : Std.I32) : + CoreModels.core.num.I32.wrapping_add x y = .ok (Aeneas.Std.I32.wrapping_add x y) := by + unfold CoreModels.core.num.I32.wrapping_add + unfold rust_primitives.arithmetic.wrapping_add_i32 + rfl + +/-- Reduction of `core.num.I16.wrapping_neg` to its `.ok` rep + via `Std.I16.wrapping_sub 0 x`. Mirror of `negate_per_elem_spec`. -/ +theorem L2_8c.cm_wrapping_neg_i16_ok_eq (x : Std.I16) : + CoreModels.core.num.I16.wrapping_neg x = .ok (Aeneas.Std.I16.wrapping_sub (0#i16) x) := by + unfold CoreModels.core.num.I16.wrapping_neg + unfold rust_primitives.arithmetic.wrapping_sub_i16 + rfl + +/-- I16 wrapping-neg is exact when |x.val| < 2^15. + `(wrapping_sub 0 x).val = -x.val` when `-x.val ∈ [-2^15, 2^15)`, + i.e. when `x.val ∈ (-2^15, 2^15]`. We use `≤ 2^15 - 1` (strictly + inside, away from boundary). -/ +theorem L2_8c.wrapping_neg_val_eq (x : Std.I16) + (h : x.val.natAbs ≤ 2^15 - 1) : + (Aeneas.Std.I16.wrapping_sub (0#i16) x).val = -x.val := by + rw [Aeneas.Std.I16.wrapping_sub_val_eq] + show Int.bmod ((0#i16 : Std.I16).val - x.val) (2^16) = -x.val + have h_zero : (0#i16 : Std.I16).val = 0 := by decide + rw [h_zero] + show Int.bmod (0 - x.val) (2^16) = -x.val + have h_lb : -(2^15 : Int) ≤ 0 - x.val := by + have : x.val ≤ (2^15 - 1 : Int) := by + have h_abs := h + have : x.val.natAbs ≤ 2^15 - 1 := h_abs + omega + omega + have h_ub : (0 - x.val : Int) < (2^15 : Int) := by + have : -(2^15 - 1 : Int) ≤ x.val := by + have h_abs := h + have : x.val.natAbs ≤ 2^15 - 1 := h_abs + omega + omega + have h_bmod : Int.bmod (0 - x.val) (2^16) = 0 - x.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + rw [h_bmod]; ring + +/-- I32 wrapping multiplication is exact under the no-overflow bound. -/ +theorem L2_8c.wrapping_mul_i32_no_overflow (x y : Std.I32) + (h : (x.val * y.val).natAbs < 2^31) : + (Aeneas.Std.I32.wrapping_mul x y).val = x.val * y.val := by + rw [Aeneas.Std.I32.wrapping_mul_val_eq] + have h_abs_lt : |x.val * y.val| < (2^31 : Int) := by + rw [Int.abs_eq_natAbs]; exact_mod_cast h + have h_lb : -(2^31 : Int) ≤ x.val * y.val := by + have := neg_abs_le (x.val * y.val) + have h1 : -|x.val * y.val| ≤ x.val * y.val := this + have h2 : -(2^31 : Int) < -|x.val * y.val| := by linarith + linarith + have h_ub : x.val * y.val < (2^31 : Int) := by + have := le_abs_self (x.val * y.val) + linarith + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 32 _ (by decide) + · have h_red : ((2 : Int)^(32-1)) = (2 : Int)^31 := by decide + rw [h_red]; exact h_lb + · have h_red : ((2 : Int)^(32-1)) = (2 : Int)^31 := by decide + rw [h_red]; exact h_ub + +/-- I32 wrapping addition is exact under the no-overflow bound. -/ +theorem L2_8c.wrapping_add_i32_no_overflow (x y : Std.I32) + (h : (x.val + y.val).natAbs < 2^31) : + (Aeneas.Std.I32.wrapping_add x y).val = x.val + y.val := by + rw [Aeneas.Std.I32.wrapping_add_val_eq] + have h_abs_lt : |x.val + y.val| < (2^31 : Int) := by + rw [Int.abs_eq_natAbs]; exact_mod_cast h + have h_lb : -(2^31 : Int) ≤ x.val + y.val := by + have := neg_abs_le (x.val + y.val) + linarith + have h_ub : x.val + y.val < (2^31 : Int) := by + have := le_abs_self (x.val + y.val) + linarith + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 32 _ (by decide) + · have h_red : ((2 : Int)^(32-1)) = (2 : Int)^31 := by decide + rw [h_red]; exact h_lb + · have h_red : ((2 : Int)^(32-1)) = (2 : Int)^31 := by decide + rw [h_red]; exact h_ub + +/-- Mont-domain variant of `lift_fe_neg_pure_eq`. Under the bound + `|a.val| ≤ 2^15 - 1` (boundary excluded), the I16 negation + `r` of `a` satisfies `lift_fe_mont r = neg_pure (lift_fe_mont a)`. -/ +theorem L2_8c.lift_fe_mont_neg_pure_eq + (a r : Std.I16) + (hbnd : a.val.natAbs ≤ 2^15 - 1) + (hrv : r.val = -a.val) : + lift_fe_mont r + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont a) := by + set s : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont a) with hs_def + have h_lm_canon : libcrux_iot_ml_kem.Spec.Pure.Canonical (lift_fe_mont a) := by + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical + unfold hacspec_ml_kem.parameters.FIELD_MODULUS + show (lift_fe_mont a).val.val < 3329 + rw [lift_fe_mont_val_val] + exact ZMod.val_lt _ + have h_canon : s.val.val < 3329 := by + have h_cs := libcrux_iot_ml_kem.Spec.Pure.Canonical_neg_pure + (lift_fe_mont a) h_lm_canon + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cs + unfold hacspec_ml_kem.parameters.FIELD_MODULUS at h_cs; simpa using h_cs + have h_round_trip : feOfZMod (zmodOfFE s) = s := + feOfZMod_zmodOfFE_of_canonical s h_canon + -- LHS reduction. + have h_lhs : lift_fe_mont r = feOfZMod (-((a.val : ZMod 3329)) * 169) := by + unfold lift_fe_mont i16_to_spec_fe_mont + congr 1 + rw [hrv]; push_cast; ring + -- zmodOfFE s = -((a.val : ZMod q) * 169). + have h_lm_zmod : zmodOfFE (lift_fe_mont a) = (a.val : ZMod 3329) * 169 := by + unfold zmodOfFE + rw [lift_fe_mont_val_val] + rw [ZMod.natCast_zmod_val] + unfold i16_to_spec_fe_mont + rfl + -- Convert `(3329 - X : Nat)` as ZMod q to `-(X : ZMod q)`. + have h_nat_sub_zmod (X : Nat) (hX : X < 3329) : + (((3329 - X : Nat)) : ZMod 3329) = -((X : Nat) : ZMod 3329) := by + have h_sum_nat : (3329 - X : Nat) + X = 3329 := by omega + have h_sum_zmod : (((3329 - X : Nat) : ZMod 3329)) + ((X : ZMod 3329)) = 0 := by + rw [← Nat.cast_add, h_sum_nat]; exact ZMod.natCast_self 3329 + exact eq_neg_of_add_eq_zero_left h_sum_zmod + have h_zmod_s : zmodOfFE s = -((a.val : ZMod 3329) * 169) := by + unfold zmodOfFE + rw [neg_pure_val_eq _ h_lm_canon] + rw [ZMod.natCast_mod] + have h_lm_lt : (lift_fe_mont a).val.val < 3329 := by + rw [lift_fe_mont_val_val]; exact ZMod.val_lt _ + rw [h_nat_sub_zmod _ h_lm_lt] + -- Goal: -((lift_fe_mont a).val.val : ZMod q) = -((a.val : ZMod q) * 169). + rw [show ((lift_fe_mont a).val.val : ZMod 3329) = zmodOfFE (lift_fe_mont a) from by + unfold zmodOfFE; rfl] + rw [h_lm_zmod] + rw [h_lhs, ← h_round_trip, h_zmod_s] + congr 1; ring + +/-! ### ZMod 3329 projection lemmas — used by L2.8 / L6.3 / L7 closures. + These are pure-arithmetic facts about how `zmodOfFE` distributes + over the SpecPure FE operations and the lift functions. Factored + out of `mont_reduce_{even,odd}_fe_eq` so future closures (L6.3a/b/c, + L2.8d cache variants) can reuse them without inlining. -/ + +/-- `zmodOfFE` of `Spec.mont_reduce_pure ∘ lift_fe_int`: in ZMod 3329, + `mont_reduce_pure (lift_fe_int v) = v · 169²` (i.e., `v · R⁻²`). -/ +theorem L2_8c.zmodOfFE_mont_reduce_lift_fe_int (v : Int) : + zmodOfFE (Spec.mont_reduce_pure (lift_fe_int v)) + = (v : ZMod 3329) * 169 * 169 := by + rw [mont_reduce_pure_lift_fe_int] + rw [zmodOfFE_feOfZMod] + +/-- `zmodOfFE` of `lift_fe_mont`: in ZMod 3329, + `lift_fe_mont x = x.val · 169` (i.e., `x · R⁻¹`). -/ +theorem L2_8c.zmodOfFE_lift_fe_mont (x : Std.I16) : + zmodOfFE (lift_fe_mont x) = (x.val : ZMod 3329) * 169 := by + unfold lift_fe_mont + rw [zmodOfFE_feOfZMod] + rfl + +/-- `zmodOfFE` distributes over `FieldElement.mul_pure` in ZMod 3329. -/ +theorem L2_8c.zmodOfFE_mul_pure + (a b : hacspec_ml_kem.parameters.FieldElement) : + zmodOfFE (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a b) + = zmodOfFE a * zmodOfFE b := by + unfold zmodOfFE + rw [mul_pure_val_eq] + rw [ZMod.natCast_mod] + push_cast + rfl + +/-- `zmodOfFE` distributes over `FieldElement.add_pure` in ZMod 3329. -/ +theorem L2_8c.zmodOfFE_add_pure + (a b : hacspec_ml_kem.parameters.FieldElement) : + zmodOfFE (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure a b) + = zmodOfFE a + zmodOfFE b := by + unfold zmodOfFE + rw [add_pure_val_eq] + rw [ZMod.natCast_mod] + push_cast + rfl + +set_option maxHeartbeats 400000 in +/-- Mont-domain FE equation builder for the L2.8c per-pair Triple: + if the new accumulator lane `r` (as I32) and the per-pair operands + (as I16) satisfy the ZMod 3329 modular equation + `r * 2^16 ≡ out * 2^16 + ai * bi * 2^16 + aj * bj * zeta (mod q)` + (the impl-side raw I32 equation projected to ZMod q), then the + FE-level equation + `mont_reduce_pure (lift_fe_int r.val) + = add_pure (mont_reduce_pure (lift_fe_int out.val)) + (add_pure (mul_pure ai_m bi_m) + (mul_pure (mul_pure aj_m bj_m) zeta_m))` + holds, where each `x_m = lift_fe_mont x`. + + Algebra: both sides reduce (via `mont_reduce_pure_lift_fe_int` + + add/mul_pure round-trip) to a `feOfZMod` of a ZMod q expression. + The Mont-inversion identity `2285 · 169 ≡ 1 (mod q)` (since + `2^16 ≡ 2285 (mod q)`, `R⁻¹ = 169`) collapses the powers; `ring` + closes after. -/ +theorem L2_8c.mont_reduce_even_fe_eq + (out r : Std.I32) (ai bi aj bj zeta : Std.I16) + (h_zmod : ((r.val * (2 ^ 16 : Int)) : ZMod 3329) + = ((out.val * (2 ^ 16 : Int) + ai.val * bi.val * (2 ^ 16 : Int) + + aj.val * bj.val * zeta.val) : ZMod 3329)) : + Spec.mont_reduce_pure (lift_fe_int r.val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int out.val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont ai) (lift_fe_mont bi)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont aj) (lift_fe_mont bj)) + (lift_fe_mont zeta))) := by + -- LHS: feOfZMod ((r.val : ZMod q) * 169 * 169). + rw [mont_reduce_pure_lift_fe_int] + -- RHS: round-trip via canonicity. + set s : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int out.val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont ai) (lift_fe_mont bi)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont aj) (lift_fe_mont bj)) + (lift_fe_mont zeta))) with hs_def + have h_canon : s.val.val < 3329 := by + have h_cs := libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure + (Spec.mont_reduce_pure (lift_fe_int out.val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont ai) (lift_fe_mont bi)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont aj) (lift_fe_mont bj)) + (lift_fe_mont zeta))) + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cs + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at h_cs + exact h_cs + have h_round_trip : feOfZMod (zmodOfFE s) = s := + feOfZMod_zmodOfFE_of_canonical s h_canon + -- Push s through the 4 zmodOfFE projection lemmas to get a pure ZMod expression. + have h_zmod_s : zmodOfFE s + = (out.val : ZMod 3329) * 169 * 169 + + ((ai.val : ZMod 3329) * 169 * ((bi.val : ZMod 3329) * 169) + + ((aj.val : ZMod 3329) * 169 * ((bj.val : ZMod 3329) * 169)) + * ((zeta.val : ZMod 3329) * 169)) := by + simp only [hs_def, + L2_8c.zmodOfFE_add_pure, + L2_8c.zmodOfFE_mont_reduce_lift_fe_int, + L2_8c.zmodOfFE_mul_pure, + L2_8c.zmodOfFE_lift_fe_mont] + -- Push h_zmod through ZMod, then `ring` closes. + -- Goal after rw [← h_round_trip]: feOfZMod (((r.val : Int) : ZMod 3329) * 169 * 169) = feOfZMod (zmodOfFE s). + rw [← h_round_trip, h_zmod_s] + -- Goal: feOfZMod (... LHS ...) = feOfZMod (... RHS ...). Closed by congr 1 + ring on h_zmod. + congr 1 + -- ZMod q equation: (r.val : ZMod q) * 169 * 169 = out * 169² + ai*169*(bi*169) + ((aj*169)*(bj*169))*(zeta*169). + -- From h_zmod: r.val * R = out * R + ai*bi*R + aj*bj*zeta. + -- The Mont-inversion identity: 2^16 * 169^2 ≡ 169 (mod q) and 2^16 * 169 ≡ 1 (mod q). + have h_inv : ((2285 : ZMod 3329)) * 169 = 1 := by decide + -- Push the cast `(2^16 : Int) : ZMod 3329` to 2285 in h_zmod. + push_cast at h_zmod + -- h_zmod : (r.val : ZMod q) * 2285 = out*2285 + ai*bi*2285 + aj*bj*zeta in ZMod q. + -- Strategy: rewrite h_zmod with `* 169 * 169 * 169` on both sides, then use + -- inv to collapse `2285 * 169 = 1` in each term, then `ring`. + have h_mul_169_cubed : + (r.val : ZMod 3329) * (2^16 : Int) * 169 * 169 * 169 + = ((out.val : ZMod 3329) * (2^16 : Int) + (ai.val : ZMod 3329) * (bi.val : ZMod 3329) * (2^16 : Int) + + (aj.val : ZMod 3329) * (bj.val : ZMod 3329) * (zeta.val : ZMod 3329)) * 169 * 169 * 169 := by + have := h_zmod + push_cast at this ⊢ + rw [this] + -- (2^16 : Int) : ZMod 3329 = 2285. + have h_2_16 : ((2^16 : Int) : ZMod 3329) = 2285 := by decide + rw [h_2_16] at h_mul_169_cubed + -- Now: r * 2285 * 169 * 169 * 169 = (out * 2285 + ai*bi*2285 + aj*bj*zeta) * 169 * 169 * 169. + -- LHS reduces: r * (2285*169) * 169 * 169 = r * 169 * 169 (using 2285*169 = 1). + -- We want: r * 169 * 169 = out*169*169 + ai*169*(bi*169) + (aj*169*(bj*169))*(zeta*169). + have h_lhs : + (r.val : ZMod 3329) * 169 * 169 + = (r.val : ZMod 3329) * 2285 * 169 * 169 * 169 := by + have : (r.val : ZMod 3329) * 169 * 169 = (r.val : ZMod 3329) * (2285 * 169) * 169 * 169 := by + rw [h_inv]; ring + rw [this]; ring + rw [h_lhs, h_mul_169_cubed] + -- Goal: (out*2285 + ai*bi*2285 + aj*bj*zeta) * 169 * 169 * 169 + -- = out*169*169 + (ai*169*(bi*169) + (aj*169*(bj*169))*(zeta*169)). + -- Reorganize LHS by extracting `2285 * (169*169*169) = 169*169`: + have h_expand : ((out.val : ZMod 3329) * 2285 + + (ai.val : ZMod 3329) * (bi.val : ZMod 3329) * 2285 + + (aj.val : ZMod 3329) * (bj.val : ZMod 3329) * (zeta.val : ZMod 3329)) + * 169 * 169 * 169 + = (out.val : ZMod 3329) * (2285 * (169 * 169 * 169)) + + (ai.val : ZMod 3329) * (bi.val : ZMod 3329) * (2285 * (169 * 169 * 169)) + + (aj.val : ZMod 3329) * (bj.val : ZMod 3329) * (zeta.val : ZMod 3329) * (169 * 169 * 169) := by + ring + have h_collapse : ((2285 : ZMod 3329)) * (169 * 169 * 169) = 169 * 169 := by decide + rw [h_expand, h_collapse] + ring + +set_option maxHeartbeats 400000 in +/-- Odd-half version of `mont_reduce_even_fe_eq`. -/ +theorem L2_8c.mont_reduce_odd_fe_eq + (out r : Std.I32) (ai bi aj bj : Std.I16) + (h_zmod : ((r.val * (2 ^ 16 : Int)) : ZMod 3329) + = ((out.val * (2 ^ 16 : Int) + + ai.val * bj.val * (2 ^ 16 : Int) + + aj.val * bi.val * (2 ^ 16 : Int)) : ZMod 3329)) : + Spec.mont_reduce_pure (lift_fe_int r.val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int out.val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont ai) (lift_fe_mont bj)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont aj) (lift_fe_mont bi))) := by + rw [mont_reduce_pure_lift_fe_int] + set s : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int out.val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont ai) (lift_fe_mont bj)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont aj) (lift_fe_mont bi))) with hs_def + have h_canon : s.val.val < 3329 := by + have h_cs := libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure + (Spec.mont_reduce_pure (lift_fe_int out.val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont ai) (lift_fe_mont bj)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont aj) (lift_fe_mont bi))) + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cs + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at h_cs + exact h_cs + have h_round_trip : feOfZMod (zmodOfFE s) = s := + feOfZMod_zmodOfFE_of_canonical s h_canon + have h_zmod_s : zmodOfFE s + = (out.val : ZMod 3329) * 169 * 169 + + ((ai.val : ZMod 3329) * 169 * ((bj.val : ZMod 3329) * 169) + + (aj.val : ZMod 3329) * 169 * ((bi.val : ZMod 3329) * 169)) := by + simp only [hs_def, + L2_8c.zmodOfFE_add_pure, + L2_8c.zmodOfFE_mont_reduce_lift_fe_int, + L2_8c.zmodOfFE_mul_pure, + L2_8c.zmodOfFE_lift_fe_mont] + rw [← h_round_trip, h_zmod_s] + congr 1 + have h_inv : ((2285 : ZMod 3329)) * 169 = 1 := by decide + -- Multiply h_zmod by 169^3 on both sides; cast (2^16 : Int) : ZMod q = 2285; + -- collapse 2285*169 = 1 to leave 169^2 multipliers. + have h_mul_169_cubed : + (r.val : ZMod 3329) * (2^16 : Int) * 169 * 169 * 169 + = ((out.val : ZMod 3329) * (2^16 : Int) + + (ai.val : ZMod 3329) * (bj.val : ZMod 3329) * (2^16 : Int) + + (aj.val : ZMod 3329) * (bi.val : ZMod 3329) * (2^16 : Int)) * 169 * 169 * 169 := by + have := h_zmod + push_cast at this ⊢ + rw [this] + have h_2_16 : ((2^16 : Int) : ZMod 3329) = 2285 := by decide + rw [h_2_16] at h_mul_169_cubed + have h_lhs : + (r.val : ZMod 3329) * 169 * 169 + = (r.val : ZMod 3329) * 2285 * 169 * 169 * 169 := by + have : (r.val : ZMod 3329) * 169 * 169 = (r.val : ZMod 3329) * (2285 * 169) * 169 * 169 := by + rw [h_inv]; ring + rw [this]; ring + rw [h_lhs, h_mul_169_cubed] + have : ((out.val : ZMod 3329) * 2285 + + (ai.val : ZMod 3329) * (bj.val : ZMod 3329) * 2285 + + (aj.val : ZMod 3329) * (bi.val : ZMod 3329) * 2285) + * 169 * 169 * 169 + = (out.val : ZMod 3329) * (2285 * (169 * 169 * 169)) + + (ai.val : ZMod 3329) * (bj.val : ZMod 3329) * (2285 * (169 * 169 * 169)) + + (aj.val : ZMod 3329) * (bi.val : ZMod 3329) * (2285 * (169 * 169 * 169)) := by + ring + rw [this] + rw [show ((2285 : ZMod 3329)) * (169 * 169 * 169) = 169 * 169 from by decide] + ring + +/-- Bound-propagation step for the L2.8c (and L2.8d) chained binomial + composition: given a per-pair-update relation between `prev` and + `next` slices (untouched lanes equal, touched lanes bounded), and + a universal bound on `prev`, conclude the universal bound on `next`. + + Refactored from the 8-fold 16-way `interval_cases` boilerplate in + the original L2.8c body. Each step now uses a 4-arg invocation + instead of a 20-line case split. Also reusable by L2.8d cache + variants (same impl structure: 8 binomial-pair updates). -/ +theorem L2_8c.bnd_universal_step + (prev next : Aeneas.Std.Slice Std.I32) (i : Nat) (hi : i < 8) + (h_prev_universal : ∀ k : Fin 16, + (prev.val[k.val]!).val.natAbs ≤ 2^30 + 2^25) + (h_unc : ∀ k : Nat, k < 16 → k ≠ 2 * i → k ≠ 2 * i + 1 → + next.val[k]! = prev.val[k]!) + (h_at_even : (next.val[2 * i]!).val.natAbs ≤ 2^30 + 2^25) + (h_at_odd : (next.val[2 * i + 1]!).val.natAbs ≤ 2^30 + 2^25) : + ∀ k : Fin 16, (next.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := by + intro k + by_cases h1 : k.val = 2 * i + · rw [show k.val = 2 * i from h1]; exact h_at_even + · by_cases h2 : k.val = 2 * i + 1 + · rw [show k.val = 2 * i + 1 from h2]; exact h_at_odd + · rw [h_unc k.val k.isLt h1 h2]; exact h_prev_universal k + +set_option maxHeartbeats 8000000 in +/-- Per-pair Triple for `accumulating_ntt_multiply_binomials`. Models the + impl's per-pair contribution to the accumulator: reads `a[2i], a[2i+1]`, + `b[2i], b[2i+1]`, multiplies + Montgomery-reduces to form an even and + odd I32 delta, then `wrapping_add`s onto `out[2i], out[2i+1]`. + + POST exposes: + - `r.length = 16` (Slice.update preserves length); + - Untouched-lane preservation outside `{2i, 2i+1}`; + - Relative bound: `|r.val[2i]!| ≤ |out.val[2i]!| + 2^25` (and 2i+1); + - FE equation: the Mont-domain per-pair update agrees with + `add (mont_reduce_pure old) (add (mul a₀_m b₀_m) (mul (mul a₁_m b₁_m) zeta_m))` + for the even half, and the odd-half analog. + + Helper for L2.8c — 8 chained applications give + `accumulating_ntt_multiply_fc`. -/ +theorem accumulating_ntt_multiply_binomials_fc + (a b : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta : Std.I16) (i : Std.Usize) + (out : Aeneas.Std.Slice Std.I32) + (h_i : i.val < 8) + (h_out_len : out.length = 16) + (h_a : ∀ j : Fin 16, (a.elements.val[j.val]!).val.natAbs ≤ 3328) + (h_b : ∀ j : Fin 16, (b.elements.val[j.val]!).val.natAbs ≤ 3328) + (h_zeta : zeta.val.natAbs ≤ 1664) + (h_out_bnd : ∀ k : Fin 16, (out.val[k.val]!).val.natAbs ≤ 2^30 + 2^25) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials + a b zeta i out + ⦃ ⇓ r => ⌜ r.length = 16 + ∧ (∀ k : Nat, k < 16 → k ≠ 2 * i.val → k ≠ 2 * i.val + 1 → + r.val[k]! = out.val[k]!) + ∧ (r.val[2 * i.val]!).val.natAbs + ≤ (out.val[2 * i.val]!).val.natAbs + 2^25 + ∧ (r.val[2 * i.val + 1]!).val.natAbs + ≤ (out.val[2 * i.val + 1]!).val.natAbs + 2^25 + ∧ Spec.mont_reduce_pure (lift_fe_int (r.val[2 * i.val]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (out.val[2 * i.val]!).val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (a.elements.val[2 * i.val]!)) + (lift_fe_mont (b.elements.val[2 * i.val]!))) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (a.elements.val[2 * i.val + 1]!)) + (lift_fe_mont (b.elements.val[2 * i.val + 1]!))) + (lift_fe_mont zeta))) + ∧ Spec.mont_reduce_pure (lift_fe_int (r.val[2 * i.val + 1]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (out.val[2 * i.val + 1]!).val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (a.elements.val[2 * i.val]!)) + (lift_fe_mont (b.elements.val[2 * i.val + 1]!))) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (a.elements.val[2 * i.val + 1]!)) + (lift_fe_mont (b.elements.val[2 * i.val]!)))) ⌝ ⦄ := by + -- ===== Setup ===== + have h_2i_lt : 2 * i.val < 16 := by omega + have h_2i1_lt : 2 * i.val + 1 < 16 := by omega + have h_a_len : a.elements.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length a + have h_b_len : b.elements.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length b + have h_out_val_len : out.val.length = 16 := h_out_len + -- Set up bound abbreviations. + set ai_v : Std.I16 := a.elements.val[2 * i.val]! with hai_def + set bi_v : Std.I16 := b.elements.val[2 * i.val]! with hbi_def + set aj_v : Std.I16 := a.elements.val[2 * i.val + 1]! with haj_def + set bj_v : Std.I16 := b.elements.val[2 * i.val + 1]! with hbj_def + have h_ai : ai_v.val.natAbs ≤ 3328 := h_a ⟨2 * i.val, h_2i_lt⟩ + have h_bi : bi_v.val.natAbs ≤ 3328 := h_b ⟨2 * i.val, h_2i_lt⟩ + have h_aj : aj_v.val.natAbs ≤ 3328 := h_a ⟨2 * i.val + 1, h_2i1_lt⟩ + have h_bj : bj_v.val.natAbs ≤ 3328 := h_b ⟨2 * i.val + 1, h_2i1_lt⟩ + set old_e : Std.I32 := out.val[2 * i.val]! with hoe_def + set old_o : Std.I32 := out.val[2 * i.val + 1]! with hoo_def + have h_old_e_bnd : old_e.val.natAbs ≤ 2^30 + 2^25 := h_out_bnd ⟨2 * i.val, h_2i_lt⟩ + have h_old_o_bnd : old_o.val.natAbs ≤ 2^30 + 2^25 := h_out_bnd ⟨2 * i.val + 1, h_2i1_lt⟩ + -- ===== Index arithmetic ===== + obtain ⟨i1, h_i1_eq, h_i1_val⟩ := + usize_mul_ok_eq_fc 2#usize i (by scalar_tac) + have h_i1_val' : i1.val = 2 * i.val := by + rw [h_i1_val]; rfl + obtain ⟨i2, h_i2_eq, h_i2_val⟩ := + usize_add_ok_eq_fc i1 1#usize (by scalar_tac) + have h_i2_val' : i2.val = 2 * i.val + 1 := by + rw [h_i2_val, h_i1_val']; rfl + -- ===== Reads (with index_usize_ok_eq) ===== + have h_read_ai : + Aeneas.Std.Array.index_usize a.elements i1 = .ok ai_v := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a.elements i1 + (by rw [h_a_len, h_i1_val']; exact h_2i_lt) + rw [h, h_i1_val'] + have h_read_bi : + Aeneas.Std.Array.index_usize b.elements i1 = .ok bi_v := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq b.elements i1 + (by rw [h_b_len, h_i1_val']; exact h_2i_lt) + rw [h, h_i1_val'] + have h_read_aj : + Aeneas.Std.Array.index_usize a.elements i2 = .ok aj_v := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a.elements i2 + (by rw [h_a_len, h_i2_val']; exact h_2i1_lt) + rw [h, h_i2_val'] + have h_read_bj : + Aeneas.Std.Array.index_usize b.elements i2 = .ok bj_v := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq b.elements i2 + (by rw [h_b_len, h_i2_val']; exact h_2i1_lt) + rw [h, h_i2_val'] + -- ===== as_i32 casts ===== + set ai32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 ai_v with hai32_def + set bi32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 bi_v with hbi32_def + set aj32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 aj_v with haj32_def + set bj32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 bj_v with hbj32_def + set zeta32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 zeta with hzeta32_def + have h_ai32_val : ai32.val = ai_v.val := L2_8c.cast_I32_val ai_v + have h_bi32_val : bi32.val = bi_v.val := L2_8c.cast_I32_val bi_v + have h_aj32_val : aj32.val = aj_v.val := L2_8c.cast_I32_val aj_v + have h_bj32_val : bj32.val = bj_v.val := L2_8c.cast_I32_val bj_v + have h_zeta32_val : zeta32.val = zeta.val := L2_8c.cast_I32_val zeta + -- as_i32 → .ok cast. + have h_as_ai : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 ai_v = .ok ai32 := + L2_8c.as_i32_val_eq ai_v + have h_as_bi : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 bi_v = .ok bi32 := + L2_8c.as_i32_val_eq bi_v + have h_as_aj : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 aj_v = .ok aj32 := + L2_8c.as_i32_val_eq aj_v + have h_as_bj : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 bj_v = .ok bj32 := + L2_8c.as_i32_val_eq bj_v + have h_as_zeta : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 zeta = .ok zeta32 := + L2_8c.as_i32_val_eq zeta + -- ===== Step: ai_bi = wrapping_mul ai32 bi32, value = ai.val * bi.val ===== + set ai_bi : Std.I32 := Aeneas.Std.I32.wrapping_mul ai32 bi32 with habi_def + have h_ai_bi_eq : CoreModels.core.num.I32.wrapping_mul ai32 bi32 = .ok ai_bi := + L2_8c.cm_wrapping_mul_i32_ok_eq ai32 bi32 + have h_ai_bi_val : ai_bi.val = ai_v.val * bi_v.val := by + have h_bnd : (ai32.val * bi32.val).natAbs < 2^31 := by + rw [h_ai32_val, h_bi32_val] + have h := Int.natAbs_mul ai_v.val bi_v.val + have : ai_v.val.natAbs * bi_v.val.natAbs ≤ 3328 * 3328 := by + exact Nat.mul_le_mul h_ai h_bi + rw [h] + have : (3328 * 3328 : Nat) < 2^31 := by decide + omega + have := L2_8c.wrapping_mul_i32_no_overflow ai32 bi32 h_bnd + rw [this, h_ai32_val, h_bi32_val] + -- ===== Step: bj_zeta_ = wrapping_mul bj32 zeta32, value = bj.val * zeta.val ===== + set bj_zeta_ : Std.I32 := Aeneas.Std.I32.wrapping_mul bj32 zeta32 with hbjz_def + have h_bj_zeta_eq : CoreModels.core.num.I32.wrapping_mul bj32 zeta32 = .ok bj_zeta_ := + L2_8c.cm_wrapping_mul_i32_ok_eq bj32 zeta32 + have h_bj_zeta_val : bj_zeta_.val = bj_v.val * zeta.val := by + have h_bnd : (bj32.val * zeta32.val).natAbs < 2^31 := by + rw [h_bj32_val, h_zeta32_val] + rw [Int.natAbs_mul] + have h_mul : bj_v.val.natAbs * zeta.val.natAbs ≤ 3328 * 1664 := + Nat.mul_le_mul h_bj h_zeta + have : (3328 * 1664 : Nat) < 2^31 := by decide + omega + have := L2_8c.wrapping_mul_i32_no_overflow bj32 zeta32 h_bnd + rw [this, h_bj32_val, h_zeta32_val] + -- ===== Step: bj_zeta = montgomery_reduce_element bj_zeta_, |bj_zeta| ≤ 4993 ===== + have h_bj_zeta_pre : bj_zeta_.val.natAbs ≤ 2^16 * 3328 := by + rw [h_bj_zeta_val] + rw [Int.natAbs_mul] + have h_mul : bj_v.val.natAbs * zeta.val.natAbs ≤ 3328 * 1664 := + Nat.mul_le_mul h_bj h_zeta + have : (3328 * 1664 : Nat) ≤ 2^16 * 3328 := by decide + omega + obtain ⟨bj_zeta, h_bj_zeta_ok, h_bj_zeta_bnd, h_bj_zeta_lift⟩ := + triple_exists_ok_fc (montgomery_reduce_element_fc bj_zeta_ h_bj_zeta_pre) + -- Also recover the legacy modq form: bj_zeta * 2^16 ≡ bj_zeta_ (mod q). + -- We get it via the legacy spec. + have h_bj_zeta_pre' : bj_zeta_.val.natAbs ≤ 3328 * 2^16 := by + rw [show (3328 * 2^16 : Nat) = 2^16 * 3328 from by decide]; exact h_bj_zeta_pre + obtain ⟨bj_zeta', h_bj_zeta_ok', _h_bnd', _h_tight, h_bj_zeta_modq⟩ := + triple_exists_ok_fc + (libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement.montgomery_reduce_element_spec bj_zeta_ h_bj_zeta_pre') + have h_bj_zeta_eq2 : bj_zeta = bj_zeta' := by + have h_both : (Result.ok bj_zeta : Result _) = Result.ok bj_zeta' := by + rw [← h_bj_zeta_ok, h_bj_zeta_ok'] + cases h_both; rfl + -- ===== Step: aj_bj_zeta = wrapping_mul aj32 (as_i32 bj_zeta), value = aj.val * bj_zeta.val ===== + set bj_zeta32 : Std.I32 := + Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 bj_zeta with hbjz32_def + have h_bj_zeta32_val : bj_zeta32.val = bj_zeta.val := L2_8c.cast_I32_val bj_zeta + have h_as_bj_zeta : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 bj_zeta + = .ok bj_zeta32 := L2_8c.as_i32_val_eq bj_zeta + set aj_bj_zeta : Std.I32 := Aeneas.Std.I32.wrapping_mul aj32 bj_zeta32 with habjz_def + have h_aj_bj_zeta_eq : CoreModels.core.num.I32.wrapping_mul aj32 bj_zeta32 = .ok aj_bj_zeta := + L2_8c.cm_wrapping_mul_i32_ok_eq aj32 bj_zeta32 + have h_aj_bj_zeta_val : aj_bj_zeta.val = aj_v.val * bj_zeta.val := by + have h_bnd : (aj32.val * bj_zeta32.val).natAbs < 2^31 := by + rw [h_aj32_val, h_bj_zeta32_val, Int.natAbs_mul] + have h_mul : aj_v.val.natAbs * bj_zeta.val.natAbs ≤ 3328 * (3328 + 1665) := + Nat.mul_le_mul h_aj h_bj_zeta_bnd + have : (3328 * (3328 + 1665) : Nat) < 2^31 := by decide + omega + have := L2_8c.wrapping_mul_i32_no_overflow aj32 bj_zeta32 h_bnd + rw [this, h_aj32_val, h_bj_zeta32_val] + -- ===== Step: ai_bi_aj_bj = wrapping_add ai_bi aj_bj_zeta ===== + set ai_bi_aj_bj : Std.I32 := Aeneas.Std.I32.wrapping_add ai_bi aj_bj_zeta with hsum_e_def + have h_sum_e_eq : CoreModels.core.num.I32.wrapping_add ai_bi aj_bj_zeta = .ok ai_bi_aj_bj := + L2_8c.cm_wrapping_add_i32_ok_eq ai_bi aj_bj_zeta + -- Even-delta bound: |ai*bi + aj*bj_zeta| ≤ 3328² + 3328·4993 ≤ 2^25 (precise: ~28M < 33.5M). + have h_sum_e_bnd : (ai_bi.val + aj_bj_zeta.val).natAbs ≤ 3328 * 3328 + 3328 * (3328 + 1665) := by + rw [h_ai_bi_val, h_aj_bj_zeta_val] + have h_e1 : (ai_v.val * bi_v.val).natAbs ≤ 3328 * 3328 := by + rw [Int.natAbs_mul]; exact Nat.mul_le_mul h_ai h_bi + have h_e2 : (aj_v.val * bj_zeta.val).natAbs ≤ 3328 * (3328 + 1665) := by + rw [Int.natAbs_mul]; exact Nat.mul_le_mul h_aj h_bj_zeta_bnd + have h_tri : ((ai_v.val * bi_v.val) + (aj_v.val * bj_zeta.val)).natAbs + ≤ (ai_v.val * bi_v.val).natAbs + (aj_v.val * bj_zeta.val).natAbs := + Int.natAbs_add_le _ _ + omega + have h_sum_e_val : ai_bi_aj_bj.val = ai_bi.val + aj_bj_zeta.val := by + have h_bnd : (ai_bi.val + aj_bj_zeta.val).natAbs < 2^31 := by + have h_le : (3328 * 3328 + 3328 * (3328 + 1665) : Nat) < 2^31 := by decide + omega + exact L2_8c.wrapping_add_i32_no_overflow ai_bi aj_bj_zeta h_bnd + -- Bound the delta_even by 2^25: + have h_delta_e_bnd : ai_bi_aj_bj.val.natAbs ≤ 2^25 := by + rw [h_sum_e_val] + have : (3328 * 3328 + 3328 * (3328 + 1665) : Nat) ≤ 2^25 := by decide + omega + -- ===== Step: ai_bj = wrapping_mul ai32 bj32, value = ai*bj ===== + set ai_bj_p : Std.I32 := Aeneas.Std.I32.wrapping_mul ai32 bj32 with haibj_def + have h_ai_bj_eq : CoreModels.core.num.I32.wrapping_mul ai32 bj32 = .ok ai_bj_p := + L2_8c.cm_wrapping_mul_i32_ok_eq ai32 bj32 + have h_ai_bj_val : ai_bj_p.val = ai_v.val * bj_v.val := by + have h_bnd : (ai32.val * bj32.val).natAbs < 2^31 := by + rw [h_ai32_val, h_bj32_val, Int.natAbs_mul] + have h_mul : ai_v.val.natAbs * bj_v.val.natAbs ≤ 3328 * 3328 := + Nat.mul_le_mul h_ai h_bj + have : (3328 * 3328 : Nat) < 2^31 := by decide + omega + have := L2_8c.wrapping_mul_i32_no_overflow ai32 bj32 h_bnd + rw [this, h_ai32_val, h_bj32_val] + -- ===== Step: aj_bi = wrapping_mul aj32 bi32 ===== + set aj_bi_p : Std.I32 := Aeneas.Std.I32.wrapping_mul aj32 bi32 with hajbi_def + have h_aj_bi_eq : CoreModels.core.num.I32.wrapping_mul aj32 bi32 = .ok aj_bi_p := + L2_8c.cm_wrapping_mul_i32_ok_eq aj32 bi32 + have h_aj_bi_val : aj_bi_p.val = aj_v.val * bi_v.val := by + have h_bnd : (aj32.val * bi32.val).natAbs < 2^31 := by + rw [h_aj32_val, h_bi32_val, Int.natAbs_mul] + have h_mul : aj_v.val.natAbs * bi_v.val.natAbs ≤ 3328 * 3328 := + Nat.mul_le_mul h_aj h_bi + have : (3328 * 3328 : Nat) < 2^31 := by decide + omega + have := L2_8c.wrapping_mul_i32_no_overflow aj32 bi32 h_bnd + rw [this, h_aj32_val, h_bi32_val] + -- ===== Step: ai_bj_aj_bi = wrapping_add ai_bj aj_bi, value = ai*bj + aj*bi ===== + set ai_bj_aj_bi : Std.I32 := Aeneas.Std.I32.wrapping_add ai_bj_p aj_bi_p with hsum_o_def + have h_sum_o_eq : CoreModels.core.num.I32.wrapping_add ai_bj_p aj_bi_p = .ok ai_bj_aj_bi := + L2_8c.cm_wrapping_add_i32_ok_eq ai_bj_p aj_bi_p + have h_sum_o_bnd : (ai_bj_p.val + aj_bi_p.val).natAbs ≤ 2 * 3328 * 3328 := by + rw [h_ai_bj_val, h_aj_bi_val] + have h_e1 : (ai_v.val * bj_v.val).natAbs ≤ 3328 * 3328 := by + rw [Int.natAbs_mul]; exact Nat.mul_le_mul h_ai h_bj + have h_e2 : (aj_v.val * bi_v.val).natAbs ≤ 3328 * 3328 := by + rw [Int.natAbs_mul]; exact Nat.mul_le_mul h_aj h_bi + have h_tri := Int.natAbs_add_le (ai_v.val * bj_v.val) (aj_v.val * bi_v.val) + omega + have h_sum_o_val : ai_bj_aj_bi.val = ai_bj_p.val + aj_bi_p.val := by + have h_bnd : (ai_bj_p.val + aj_bi_p.val).natAbs < 2^31 := by + have : (2 * 3328 * 3328 : Nat) < 2^31 := by decide + omega + exact L2_8c.wrapping_add_i32_no_overflow ai_bj_p aj_bi_p h_bnd + have h_delta_o_bnd : ai_bj_aj_bi.val.natAbs ≤ 2^25 := by + rw [h_sum_o_val] + have : (2 * 3328 * 3328 : Nat) ≤ 2^25 := by decide + omega + -- ===== Slice reads + writes for `out` ===== + -- Step: i10 = out[i1] (= old_e at i1.val = 2*i.val). + have h_read_old_e : Aeneas.Std.Slice.index_usize out i1 = .ok old_e := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq out i1 + (by rw [h_out_val_len, h_i1_val']; exact h_2i_lt) + rw [h, h_i1_val'] + -- Step: i11 = wrapping_add old_e ai_bi_aj_bj (the new lane 2i value). + set new_e : Std.I32 := Aeneas.Std.I32.wrapping_add old_e ai_bi_aj_bj with hne_def + have h_new_e_eq : CoreModels.core.num.I32.wrapping_add old_e ai_bi_aj_bj = .ok new_e := + L2_8c.cm_wrapping_add_i32_ok_eq old_e ai_bi_aj_bj + -- new_e.val = old_e.val + delta_e (no overflow: |old_e| ≤ 2^30, |delta_e| ≤ 2^25). + have h_new_e_val : new_e.val = old_e.val + ai_bi_aj_bj.val := by + have h_bnd : (old_e.val + ai_bi_aj_bj.val).natAbs < 2^31 := by + have h_tri := Int.natAbs_add_le old_e.val ai_bi_aj_bj.val + have : (2^30 + 2^25 + 2^25 : Nat) < 2^31 := by decide + omega + exact L2_8c.wrapping_add_i32_no_overflow old_e ai_bi_aj_bj h_bnd + have h_new_e_bnd : new_e.val.natAbs ≤ old_e.val.natAbs + 2^25 := by + rw [h_new_e_val] + have h_tri := Int.natAbs_add_le old_e.val ai_bi_aj_bj.val + omega + -- Step: out1 = Slice.update out i1 new_e (= out.set i1 new_e). + have h_upd_e : Aeneas.Std.Slice.update out i1 new_e = .ok (out.set i1 new_e) := by + have hT := Aeneas.Std.Slice.update_spec out i1 new_e (by rw [h_out_len, h_i1_val']; exact h_2i_lt) + obtain ⟨v', h_eq, h_v'⟩ := Aeneas.Std.WP.spec_imp_exists hT + rw [h_eq, h_v'] + set out1 : Aeneas.Std.Slice Std.I32 := out.set i1 new_e with hout1_def + -- The impl computes `i12 = i1 + 1#usize` again (extracted as identical + -- to i2). After `simp only [h_i2_eq]` in the body composition, all four + -- `i1 + 1#usize` occurrences collapse to i2. So we state subsequent + -- reads/writes directly with i2. + have h_out1_len : out1.length = 16 := by simp [hout1_def]; exact h_out_len + have h_out1_val_len : out1.val.length = 16 := h_out1_len + have h_old_o_in_out1 : out1.val[i2.val]! = old_o := by + have h_set_val : out1.val = out.val.set i1.val new_e := by + simp [hout1_def, Aeneas.Std.Slice.set_val_eq] + have h_ne : 2 * i.val + 1 ≠ i1.val := by rw [h_i1_val']; omega + have h_lt : 2 * i.val + 1 < out.val.length := by rw [h_out_val_len]; exact h_2i1_lt + rw [h_set_val, h_i2_val', hoo_def] + have h_lt_set : 2 * i.val + 1 < (out.val.set i1.val new_e).length := by + rw [List.length_set]; exact h_lt + rw [getElem!_pos (out.val.set i1.val new_e) (2 * i.val + 1) h_lt_set] + rw [getElem!_pos out.val (2 * i.val + 1) h_lt] + rw [List.getElem_set_ne (Ne.symm h_ne)] + have h_read_old_o : Aeneas.Std.Slice.index_usize out1 i2 = .ok old_o := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq out1 i2 + (by rw [h_out1_val_len, h_i2_val']; exact h_2i1_lt) + rw [h, h_old_o_in_out1] + -- Step: i14 = wrapping_add old_o ai_bj_aj_bi (new lane 2i+1 value). + set new_o : Std.I32 := Aeneas.Std.I32.wrapping_add old_o ai_bj_aj_bi with hno_def + have h_new_o_eq : CoreModels.core.num.I32.wrapping_add old_o ai_bj_aj_bi = .ok new_o := + L2_8c.cm_wrapping_add_i32_ok_eq old_o ai_bj_aj_bi + have h_new_o_val : new_o.val = old_o.val + ai_bj_aj_bi.val := by + have h_bnd : (old_o.val + ai_bj_aj_bi.val).natAbs < 2^31 := by + have h_tri := Int.natAbs_add_le old_o.val ai_bj_aj_bi.val + have : (2^30 + 2^25 + 2^25 : Nat) < 2^31 := by decide + omega + exact L2_8c.wrapping_add_i32_no_overflow old_o ai_bj_aj_bi h_bnd + have h_new_o_bnd : new_o.val.natAbs ≤ old_o.val.natAbs + 2^25 := by + rw [h_new_o_val] + have h_tri := Int.natAbs_add_le old_o.val ai_bj_aj_bi.val + omega + have h_upd_o : Aeneas.Std.Slice.update out1 i2 new_o = .ok (out1.set i2 new_o) := by + have hT := Aeneas.Std.Slice.update_spec out1 i2 new_o + (by rw [h_out1_len, h_i2_val']; exact h_2i1_lt) + obtain ⟨v', h_eq, h_v'⟩ := Aeneas.Std.WP.spec_imp_exists hT + rw [h_eq, h_v'] + set out2 : Aeneas.Std.Slice Std.I32 := out1.set i2 new_o with hout2_def + -- ===== Compose the monadic chain ===== + -- The four `i1 + 1#usize` invocations all yield i2 (same Lean expression). + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials + a b zeta i out = .ok out2 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials + simp only [h_i1_eq, h_i2_eq, h_read_ai, h_read_bi, h_read_aj, h_read_bj, + h_as_ai, h_as_bi, h_as_aj, h_as_bj, h_as_zeta, h_as_bj_zeta, + h_ai_bi_eq, h_bj_zeta_eq, h_bj_zeta_ok, h_aj_bj_zeta_eq, + h_sum_e_eq, h_ai_bj_eq, h_aj_bi_eq, h_sum_o_eq, + h_read_old_e, h_new_e_eq, h_upd_e, + h_read_old_o, h_new_o_eq, h_upd_o, + Aeneas.Std.bind_tc_ok] + apply triple_of_ok_fc h_body + -- ===== POST: 6-conjunct ===== + -- Useful: out2.val unfolding. + have h_out2_val : out2.val = (out.val.set i1.val new_e).set i2.val new_o := by + show ((out.set i1 new_e).set i2 new_o).val = _ + rw [Aeneas.Std.Slice.set_val_eq, Aeneas.Std.Slice.set_val_eq] + have h_out2_len : out2.length = 16 := by + show ((out.set i1 new_e).set i2 new_o).length = 16 + rw [Aeneas.Std.Slice.set_length, Aeneas.Std.Slice.set_length]; exact h_out_len + have h_out2_val_len : out2.val.length = 16 := h_out2_len + -- Out2 at 2*i (= i1.val) = new_e. Out2 at 2*i+1 (= i2.val) = new_o. + have h_out2_at_2i : out2.val[2 * i.val]! = new_e := by + rw [h_out2_val, ← h_i1_val'] + have h_lt_out : i1.val < out.val.length := by rw [h_out_val_len, h_i1_val']; exact h_2i_lt + have h_lt1 : i1.val < (out.val.set i1.val new_e).length := by + rw [List.length_set]; exact h_lt_out + have h_lt2 : i1.val < ((out.val.set i1.val new_e).set i2.val new_o).length := by + rw [List.length_set]; exact h_lt1 + rw [getElem!_pos ((out.val.set i1.val new_e).set i2.val new_o) i1.val h_lt2] + rw [List.getElem_set_ne (by rw [h_i2_val', h_i1_val']; omega)] + rw [List.getElem_set_self] + have h_out2_at_2i1 : out2.val[2 * i.val + 1]! = new_o := by + rw [h_out2_val, ← h_i2_val'] + have h_lt_out : i2.val < out.val.length := by rw [h_out_val_len, h_i2_val']; exact h_2i1_lt + have h_lt1 : i2.val < (out.val.set i1.val new_e).length := by + rw [List.length_set]; exact h_lt_out + have h_lt2 : i2.val < ((out.val.set i1.val new_e).set i2.val new_o).length := by + rw [List.length_set]; exact h_lt1 + rw [getElem!_pos ((out.val.set i1.val new_e).set i2.val new_o) i2.val h_lt2] + rw [List.getElem_set_self] + -- Untouched: for k ∉ {2i, 2i+1}, out2.val[k]! = out.val[k]!. + have h_out2_untouched : ∀ k : Nat, k < 16 → k ≠ 2 * i.val → k ≠ 2 * i.val + 1 → + out2.val[k]! = out.val[k]! := by + intro k hk hki hkj + rw [h_out2_val] + have h_lt_out : k < out.val.length := by rw [h_out_val_len]; exact hk + have h_lt1 : k < (out.val.set i1.val new_e).length := by rw [List.length_set]; exact h_lt_out + have h_lt2 : k < ((out.val.set i1.val new_e).set i2.val new_o).length := by + rw [List.length_set]; exact h_lt1 + rw [getElem!_pos ((out.val.set i1.val new_e).set i2.val new_o) k h_lt2] + rw [getElem!_pos out.val k h_lt_out] + rw [List.getElem_set_ne (by rw [h_i2_val']; omega)] + rw [List.getElem_set_ne (by rw [h_i1_val']; omega)] + -- Now produce the 6-conjunct. + refine ⟨h_out2_len, ?_, ?_, ?_, ?_, ?_⟩ + · -- Untouched lanes. + exact h_out2_untouched + · -- Bound at 2*i. + rw [h_out2_at_2i] + -- new_e.val.natAbs ≤ old_e.val.natAbs + 2^25; old_e = out.val[2*i]!. + rw [hoe_def] at h_new_e_bnd + exact h_new_e_bnd + · -- Bound at 2*i+1. + rw [h_out2_at_2i1] + rw [hoo_def] at h_new_o_bnd + exact h_new_o_bnd + · -- FE eq (even half). + rw [h_out2_at_2i, hoe_def] + -- Goal: mont_reduce_pure (lift_fe_int new_e.val) = ... + -- Convert modq form `bj_zeta'.val ≡ bj_zeta_.val * 169` into ZMod eq. + have h_modq_cast : ((bj_zeta'.val : Int) : ZMod 3329) + = ((bj_zeta_.val * 169 : Int) : ZMod 3329) := + modq_eq_cast_zmod _ _ h_bj_zeta_modq + rw [h_bj_zeta_eq2.symm] at h_modq_cast + rw [h_bj_zeta_val] at h_modq_cast + push_cast at h_modq_cast + -- h_modq_cast : (bj_zeta.val : ZMod 3329) = (bj_v.val : ZMod q) * zeta.val * 169. + apply L2_8c.mont_reduce_even_fe_eq + (out := out.val[2 * i.val]!) (r := new_e) + (ai := ai_v) (bi := bi_v) (aj := aj_v) (bj := bj_v) (zeta := zeta) + -- Goal: (new_e.val * 2^16 : ZMod q) = (out * 2^16 + ai*bi*2^16 + aj*bj*zeta : ZMod q). + rw [← hoe_def, h_new_e_val, h_sum_e_val, h_ai_bi_val, h_aj_bj_zeta_val] + push_cast + -- LHS: (old_e + ai*bi + aj*bj_zeta) * 2^16 in ZMod q. + -- Use h_modq_cast to substitute bj_zeta.val = bj.val * zeta.val * 169. + rw [h_modq_cast] + -- 2^16 * 169 ≡ 1 (mod q), so 2285 * 169 = 1 in ZMod q. + have h_inv : ((2285 : ZMod 3329)) * 169 = 1 := by decide + -- Algebraic identity: (old + ai*bi + aj*(bj*zeta*169)) * 2285 + -- = old*2285 + ai*bi*2285 + aj*bj*zeta*(2285*169) + -- = old*2285 + ai*bi*2285 + aj*bj*zeta. + calc ((old_e.val : ZMod 3329) + ((ai_v.val : ZMod 3329) * (bi_v.val : ZMod 3329) + + (aj_v.val : ZMod 3329) * ((bj_v.val : ZMod 3329) * (zeta.val : ZMod 3329) * 169))) + * 2285 + = (old_e.val : ZMod 3329) * 2285 + + (ai_v.val : ZMod 3329) * (bi_v.val : ZMod 3329) * 2285 + + (aj_v.val : ZMod 3329) * (bj_v.val : ZMod 3329) * (zeta.val : ZMod 3329) + * (2285 * 169) := by ring + _ = (old_e.val : ZMod 3329) * 2285 + + (ai_v.val : ZMod 3329) * (bi_v.val : ZMod 3329) * 2285 + + (aj_v.val : ZMod 3329) * (bj_v.val : ZMod 3329) * (zeta.val : ZMod 3329) := by + rw [h_inv]; ring + · -- FE eq (odd half). + rw [h_out2_at_2i1, hoo_def] + apply L2_8c.mont_reduce_odd_fe_eq + (out := out.val[2 * i.val + 1]!) (r := new_o) + (ai := ai_v) (bi := bi_v) (aj := aj_v) (bj := bj_v) + rw [← hoo_def, h_new_o_val, h_sum_o_val, h_ai_bj_val, h_aj_bi_val] + push_cast + ring + + +set_option maxHeartbeats 4000000 in +/-- L2.8 — `vector.portable.ntt.accumulating_ntt_multiply`: base-case + NTT-domain multiply on a 16-lane vector chunk. + + The impl chains 8 + `accumulating_ntt_multiply_binomials` calls, each accumulating one + coefficient pair via the degree-2 polynomial multiply mod (X²−ζ²). + The 4 input zetas yield 8 effective zetas with alternating + positive/negative signs across consecutive pair positions. + + POST defers algebraic shape to `ntt_multiply_base_case_post`. + Preconditions: input chunks canonical (`natAbs ≤ 3328`), zetas + bounded by the table range (`natAbs ≤ 1664`), accumulator slice + length 16 (so the 8 pair indices 0..15 are all in range), AND + each accumulator lane bounded by `2^30` (wrap-protection for the + 8 `wrapping_add` calls — per-lane delta is ≤ 2^25 so output stays + well within I32 range; `2^30` headroom supports ~32 chained calls). + + POST adds a relative bound conjunct (`|r[k]| ≤ |out[k]| + 2^25`) + so callers (L6.3, then L7) can chain L2.8 invocations without + losing track of the accumulator's I32 bound. Mirrors the inverse-NTT bound-infra cascade (see + `[[project_inverse_ntt_bound_infra_asymmetry]]`). + + [F*-port: Vector.Portable.Ntt.ntt_multiply_binomials + ntt_multiply + (lines 432-584; Chunk.fst:587-625 commute lemma). F*-pre: + vector/portable/ntt.rs:339-345 — each accumulator lane within + i32 range.] -/ +@[spec] +theorem accumulating_ntt_multiply_fc + (lhs rhs : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (out : Aeneas.Std.Slice Std.I32) + (zeta0 zeta1 zeta2 zeta3 : Std.I16) + (h_out_len : out.length = 16) + (h_lhs : ∀ j : Fin 16, (lhs.elements.val[j.val]!).val.natAbs ≤ 3328) + (h_rhs : ∀ j : Fin 16, (rhs.elements.val[j.val]!).val.natAbs ≤ 3328) + (h_zeta0 : zeta0.val.natAbs ≤ 1664) + (h_zeta1 : zeta1.val.natAbs ≤ 1664) + (h_zeta2 : zeta2.val.natAbs ≤ 1664) + (h_zeta3 : zeta3.val.natAbs ≤ 1664) + (h_out_bnd : ∀ k : Fin 16, (out.val[k.val]!).val.natAbs ≤ 2^30) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply + lhs rhs out zeta0 zeta1 zeta2 zeta3 + ⦃ ⇓ r => ⌜ r.length = 16 ∧ + (∀ k : Fin 16, (r.val[k.val]!).val.natAbs + ≤ (out.val[k.val]!).val.natAbs + 2^25) ∧ + ntt_multiply_base_case_post lhs rhs + zeta0 zeta1 zeta2 zeta3 out r ⌝ ⦄ := by + have h_zeta_within (z : Std.I16) (hz : z.val.natAbs ≤ 1664) : + z.val.natAbs ≤ 2^15 - 1 := by omega + have h_n0_val := L2_8c.wrapping_neg_val_eq zeta0 (h_zeta_within _ h_zeta0) + have h_n1_val := L2_8c.wrapping_neg_val_eq zeta1 (h_zeta_within _ h_zeta1) + have h_n2_val := L2_8c.wrapping_neg_val_eq zeta2 (h_zeta_within _ h_zeta2) + have h_n3_val := L2_8c.wrapping_neg_val_eq zeta3 (h_zeta_within _ h_zeta3) + set nzeta0 : Std.I16 := Aeneas.Std.I16.wrapping_sub (0#i16) zeta0 with hn0_def + set nzeta1 : Std.I16 := Aeneas.Std.I16.wrapping_sub (0#i16) zeta1 with hn1_def + set nzeta2 : Std.I16 := Aeneas.Std.I16.wrapping_sub (0#i16) zeta2 with hn2_def + set nzeta3 : Std.I16 := Aeneas.Std.I16.wrapping_sub (0#i16) zeta3 with hn3_def + have h_nz0_bnd : nzeta0.val.natAbs ≤ 1664 := by rw [h_n0_val]; omega + have h_nz1_bnd : nzeta1.val.natAbs ≤ 1664 := by rw [h_n1_val]; omega + have h_nz2_bnd : nzeta2.val.natAbs ≤ 1664 := by rw [h_n2_val]; omega + have h_nz3_bnd : nzeta3.val.natAbs ≤ 1664 := by rw [h_n3_val]; omega + have h_n0_fe : lift_fe_mont nzeta0 + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta0) := + L2_8c.lift_fe_mont_neg_pure_eq zeta0 nzeta0 (h_zeta_within _ h_zeta0) h_n0_val + have h_n1_fe : lift_fe_mont nzeta1 + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta1) := + L2_8c.lift_fe_mont_neg_pure_eq zeta1 nzeta1 (h_zeta_within _ h_zeta1) h_n1_val + have h_n2_fe : lift_fe_mont nzeta2 + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta2) := + L2_8c.lift_fe_mont_neg_pure_eq zeta2 nzeta2 (h_zeta_within _ h_zeta2) h_n2_val + have h_n3_fe : lift_fe_mont nzeta3 + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta3) := + L2_8c.lift_fe_mont_neg_pure_eq zeta3 nzeta3 (h_zeta_within _ h_zeta3) h_n3_val + have h_wn0 : core.num.I16.wrapping_neg zeta0 = .ok nzeta0 := + L2_8c.cm_wrapping_neg_i16_ok_eq zeta0 + have h_wn1 : core.num.I16.wrapping_neg zeta1 = .ok nzeta1 := + L2_8c.cm_wrapping_neg_i16_ok_eq zeta1 + have h_wn2 : core.num.I16.wrapping_neg zeta2 = .ok nzeta2 := + L2_8c.cm_wrapping_neg_i16_ok_eq zeta2 + have h_wn3 : core.num.I16.wrapping_neg zeta3 = .ok nzeta3 := + L2_8c.cm_wrapping_neg_i16_ok_eq zeta3 + have h_cz0 : libcrux_secrets.traits.Classify.Blanket.classify zeta0 = .ok zeta0 := + L2_8c.classify_ok_eq zeta0 + have h_cnz0 : libcrux_secrets.traits.Classify.Blanket.classify nzeta0 = .ok nzeta0 := + L2_8c.classify_ok_eq nzeta0 + have h_cz1 : libcrux_secrets.traits.Classify.Blanket.classify zeta1 = .ok zeta1 := + L2_8c.classify_ok_eq zeta1 + have h_cnz1 : libcrux_secrets.traits.Classify.Blanket.classify nzeta1 = .ok nzeta1 := + L2_8c.classify_ok_eq nzeta1 + have h_cz2 : libcrux_secrets.traits.Classify.Blanket.classify zeta2 = .ok zeta2 := + L2_8c.classify_ok_eq zeta2 + have h_cnz2 : libcrux_secrets.traits.Classify.Blanket.classify nzeta2 = .ok nzeta2 := + L2_8c.classify_ok_eq nzeta2 + have h_cz3 : libcrux_secrets.traits.Classify.Blanket.classify zeta3 = .ok zeta3 := + L2_8c.classify_ok_eq zeta3 + have h_cnz3 : libcrux_secrets.traits.Classify.Blanket.classify nzeta3 = .ok nzeta3 := + L2_8c.classify_ok_eq nzeta3 + have h_out_bnd_universal : ∀ k : Fin 16, (out.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := by + intro k; have := h_out_bnd k; omega + -- Call 0: pair 0 with zeta0 (touches lanes 0, 1). + obtain ⟨r0, h_r0_eq, h_r0_len, h_r0_unc, h_r0_bnd_e, h_r0_bnd_o, + h_r0_fe_e, h_r0_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fc lhs rhs zeta0 0#usize out + (by decide) h_out_len h_lhs h_rhs h_zeta0 h_out_bnd_universal) + have h_src_at_even : out.val[0]! = out.val[0]! := rfl + have h_src_at_odd : out.val[1]! = out.val[1]! := rfl + have h_r0_at_even : (r0.val[0]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (0#usize : Std.Usize).val : Nat) = 0 := by decide + have h_b := h_r0_bnd_e + rw [h_eq] at h_b + rw [h_src_at_even] at h_b + have h_out_le := h_out_bnd ⟨0, by decide⟩ + simp only at h_out_le; omega + have h_r0_at_odd : (r0.val[1]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (0#usize : Std.Usize).val + 1 : Nat) = 1 := by decide + have h_b := h_r0_bnd_o + rw [h_eq] at h_b + rw [h_src_at_odd] at h_b + have h_out_le := h_out_bnd ⟨1, by decide⟩ + simp only at h_out_le; omega + have h_r0_unc' : ∀ k : Nat, k < 16 → k ≠ 0 → k ≠ 1 → + r0.val[k]! = out.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (0#usize : Std.Usize).val : Nat) = 0 := by decide + have h_eq_o : (2 * (0#usize : Std.Usize).val + 1 : Nat) = 1 := by decide + apply h_r0_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_r0_bnd_universal : ∀ k : Fin 16, (r0.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step out r0 0 (by decide) h_out_bnd_universal + h_r0_unc' h_r0_at_even h_r0_at_odd + + -- Call 1: pair 1 with nzeta0 (touches lanes 2, 3). + obtain ⟨r1, h_r1_eq, h_r1_len, h_r1_unc, h_r1_bnd_e, h_r1_bnd_o, + h_r1_fe_e, h_r1_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fc lhs rhs nzeta0 1#usize r0 + (by decide) h_r0_len h_lhs h_rhs h_nz0_bnd h_r0_bnd_universal) + have h_src_at_even : r0.val[2]! = out.val[2]! := by + rw [h_r0_unc' 2 (by decide) (by decide) (by decide)] + have h_src_at_odd : r0.val[3]! = out.val[3]! := by + rw [h_r0_unc' 3 (by decide) (by decide) (by decide)] + have h_r1_at_even : (r1.val[2]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (1#usize : Std.Usize).val : Nat) = 2 := by decide + have h_b := h_r1_bnd_e + rw [h_eq] at h_b + rw [h_src_at_even] at h_b + have h_out_le := h_out_bnd ⟨2, by decide⟩ + simp only at h_out_le; omega + have h_r1_at_odd : (r1.val[3]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (1#usize : Std.Usize).val + 1 : Nat) = 3 := by decide + have h_b := h_r1_bnd_o + rw [h_eq] at h_b + rw [h_src_at_odd] at h_b + have h_out_le := h_out_bnd ⟨3, by decide⟩ + simp only at h_out_le; omega + have h_r1_unc' : ∀ k : Nat, k < 16 → k ≠ 2 → k ≠ 3 → + r1.val[k]! = r0.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (1#usize : Std.Usize).val : Nat) = 2 := by decide + have h_eq_o : (2 * (1#usize : Std.Usize).val + 1 : Nat) = 3 := by decide + apply h_r1_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_r1_bnd_universal : ∀ k : Fin 16, (r1.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r0 r1 1 (by decide) h_r0_bnd_universal + h_r1_unc' h_r1_at_even h_r1_at_odd + + -- Call 2: pair 2 with zeta1 (touches lanes 4, 5). + obtain ⟨r2, h_r2_eq, h_r2_len, h_r2_unc, h_r2_bnd_e, h_r2_bnd_o, + h_r2_fe_e, h_r2_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fc lhs rhs zeta1 2#usize r1 + (by decide) h_r1_len h_lhs h_rhs h_zeta1 h_r1_bnd_universal) + have h_src_at_even : r1.val[4]! = out.val[4]! := by + rw [h_r1_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 4 (by decide) (by decide) (by decide)] + have h_src_at_odd : r1.val[5]! = out.val[5]! := by + rw [h_r1_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 5 (by decide) (by decide) (by decide)] + have h_r2_at_even : (r2.val[4]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (2#usize : Std.Usize).val : Nat) = 4 := by decide + have h_b := h_r2_bnd_e + rw [h_eq] at h_b + rw [h_src_at_even] at h_b + have h_out_le := h_out_bnd ⟨4, by decide⟩ + simp only at h_out_le; omega + have h_r2_at_odd : (r2.val[5]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (2#usize : Std.Usize).val + 1 : Nat) = 5 := by decide + have h_b := h_r2_bnd_o + rw [h_eq] at h_b + rw [h_src_at_odd] at h_b + have h_out_le := h_out_bnd ⟨5, by decide⟩ + simp only at h_out_le; omega + have h_r2_unc' : ∀ k : Nat, k < 16 → k ≠ 4 → k ≠ 5 → + r2.val[k]! = r1.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (2#usize : Std.Usize).val : Nat) = 4 := by decide + have h_eq_o : (2 * (2#usize : Std.Usize).val + 1 : Nat) = 5 := by decide + apply h_r2_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_r2_bnd_universal : ∀ k : Fin 16, (r2.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r1 r2 2 (by decide) h_r1_bnd_universal + h_r2_unc' h_r2_at_even h_r2_at_odd + + -- Call 3: pair 3 with nzeta1 (touches lanes 6, 7). + obtain ⟨r3, h_r3_eq, h_r3_len, h_r3_unc, h_r3_bnd_e, h_r3_bnd_o, + h_r3_fe_e, h_r3_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fc lhs rhs nzeta1 3#usize r2 + (by decide) h_r2_len h_lhs h_rhs h_nz1_bnd h_r2_bnd_universal) + have h_src_at_even : r2.val[6]! = out.val[6]! := by + rw [h_r2_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 6 (by decide) (by decide) (by decide)] + have h_src_at_odd : r2.val[7]! = out.val[7]! := by + rw [h_r2_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 7 (by decide) (by decide) (by decide)] + have h_r3_at_even : (r3.val[6]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (3#usize : Std.Usize).val : Nat) = 6 := by decide + have h_b := h_r3_bnd_e + rw [h_eq] at h_b + rw [h_src_at_even] at h_b + have h_out_le := h_out_bnd ⟨6, by decide⟩ + simp only at h_out_le; omega + have h_r3_at_odd : (r3.val[7]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (3#usize : Std.Usize).val + 1 : Nat) = 7 := by decide + have h_b := h_r3_bnd_o + rw [h_eq] at h_b + rw [h_src_at_odd] at h_b + have h_out_le := h_out_bnd ⟨7, by decide⟩ + simp only at h_out_le; omega + have h_r3_unc' : ∀ k : Nat, k < 16 → k ≠ 6 → k ≠ 7 → + r3.val[k]! = r2.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (3#usize : Std.Usize).val : Nat) = 6 := by decide + have h_eq_o : (2 * (3#usize : Std.Usize).val + 1 : Nat) = 7 := by decide + apply h_r3_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_r3_bnd_universal : ∀ k : Fin 16, (r3.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r2 r3 3 (by decide) h_r2_bnd_universal + h_r3_unc' h_r3_at_even h_r3_at_odd + + -- Call 4: pair 4 with zeta2 (touches lanes 8, 9). + obtain ⟨r4, h_r4_eq, h_r4_len, h_r4_unc, h_r4_bnd_e, h_r4_bnd_o, + h_r4_fe_e, h_r4_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fc lhs rhs zeta2 4#usize r3 + (by decide) h_r3_len h_lhs h_rhs h_zeta2 h_r3_bnd_universal) + have h_src_at_even : r3.val[8]! = out.val[8]! := by + rw [h_r3_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 8 (by decide) (by decide) (by decide)] + have h_src_at_odd : r3.val[9]! = out.val[9]! := by + rw [h_r3_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 9 (by decide) (by decide) (by decide)] + have h_r4_at_even : (r4.val[8]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (4#usize : Std.Usize).val : Nat) = 8 := by decide + have h_b := h_r4_bnd_e + rw [h_eq] at h_b + rw [h_src_at_even] at h_b + have h_out_le := h_out_bnd ⟨8, by decide⟩ + simp only at h_out_le; omega + have h_r4_at_odd : (r4.val[9]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (4#usize : Std.Usize).val + 1 : Nat) = 9 := by decide + have h_b := h_r4_bnd_o + rw [h_eq] at h_b + rw [h_src_at_odd] at h_b + have h_out_le := h_out_bnd ⟨9, by decide⟩ + simp only at h_out_le; omega + have h_r4_unc' : ∀ k : Nat, k < 16 → k ≠ 8 → k ≠ 9 → + r4.val[k]! = r3.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (4#usize : Std.Usize).val : Nat) = 8 := by decide + have h_eq_o : (2 * (4#usize : Std.Usize).val + 1 : Nat) = 9 := by decide + apply h_r4_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_r4_bnd_universal : ∀ k : Fin 16, (r4.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r3 r4 4 (by decide) h_r3_bnd_universal + h_r4_unc' h_r4_at_even h_r4_at_odd + + -- Call 5: pair 5 with nzeta2 (touches lanes 10, 11). + obtain ⟨r5, h_r5_eq, h_r5_len, h_r5_unc, h_r5_bnd_e, h_r5_bnd_o, + h_r5_fe_e, h_r5_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fc lhs rhs nzeta2 5#usize r4 + (by decide) h_r4_len h_lhs h_rhs h_nz2_bnd h_r4_bnd_universal) + have h_src_at_even : r4.val[10]! = out.val[10]! := by + rw [h_r4_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 10 (by decide) (by decide) (by decide)] + have h_src_at_odd : r4.val[11]! = out.val[11]! := by + rw [h_r4_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 11 (by decide) (by decide) (by decide)] + have h_r5_at_even : (r5.val[10]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (5#usize : Std.Usize).val : Nat) = 10 := by decide + have h_b := h_r5_bnd_e + rw [h_eq] at h_b + rw [h_src_at_even] at h_b + have h_out_le := h_out_bnd ⟨10, by decide⟩ + simp only at h_out_le; omega + have h_r5_at_odd : (r5.val[11]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (5#usize : Std.Usize).val + 1 : Nat) = 11 := by decide + have h_b := h_r5_bnd_o + rw [h_eq] at h_b + rw [h_src_at_odd] at h_b + have h_out_le := h_out_bnd ⟨11, by decide⟩ + simp only at h_out_le; omega + have h_r5_unc' : ∀ k : Nat, k < 16 → k ≠ 10 → k ≠ 11 → + r5.val[k]! = r4.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (5#usize : Std.Usize).val : Nat) = 10 := by decide + have h_eq_o : (2 * (5#usize : Std.Usize).val + 1 : Nat) = 11 := by decide + apply h_r5_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_r5_bnd_universal : ∀ k : Fin 16, (r5.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r4 r5 5 (by decide) h_r4_bnd_universal + h_r5_unc' h_r5_at_even h_r5_at_odd + + -- Call 6: pair 6 with zeta3 (touches lanes 12, 13). + obtain ⟨r6, h_r6_eq, h_r6_len, h_r6_unc, h_r6_bnd_e, h_r6_bnd_o, + h_r6_fe_e, h_r6_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fc lhs rhs zeta3 6#usize r5 + (by decide) h_r5_len h_lhs h_rhs h_zeta3 h_r5_bnd_universal) + have h_src_at_even : r5.val[12]! = out.val[12]! := by + rw [h_r5_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 12 (by decide) (by decide) (by decide)] + have h_src_at_odd : r5.val[13]! = out.val[13]! := by + rw [h_r5_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 13 (by decide) (by decide) (by decide)] + have h_r6_at_even : (r6.val[12]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (6#usize : Std.Usize).val : Nat) = 12 := by decide + have h_b := h_r6_bnd_e + rw [h_eq] at h_b + rw [h_src_at_even] at h_b + have h_out_le := h_out_bnd ⟨12, by decide⟩ + simp only at h_out_le; omega + have h_r6_at_odd : (r6.val[13]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (6#usize : Std.Usize).val + 1 : Nat) = 13 := by decide + have h_b := h_r6_bnd_o + rw [h_eq] at h_b + rw [h_src_at_odd] at h_b + have h_out_le := h_out_bnd ⟨13, by decide⟩ + simp only at h_out_le; omega + have h_r6_unc' : ∀ k : Nat, k < 16 → k ≠ 12 → k ≠ 13 → + r6.val[k]! = r5.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (6#usize : Std.Usize).val : Nat) = 12 := by decide + have h_eq_o : (2 * (6#usize : Std.Usize).val + 1 : Nat) = 13 := by decide + apply h_r6_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_r6_bnd_universal : ∀ k : Fin 16, (r6.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r5 r6 6 (by decide) h_r5_bnd_universal + h_r6_unc' h_r6_at_even h_r6_at_odd + + -- Call 7: pair 7 with nzeta3 (touches lanes 14, 15). + obtain ⟨r7, h_r7_eq, h_r7_len, h_r7_unc, h_r7_bnd_e, h_r7_bnd_o, + h_r7_fe_e, h_r7_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fc lhs rhs nzeta3 7#usize r6 + (by decide) h_r6_len h_lhs h_rhs h_nz3_bnd h_r6_bnd_universal) + have h_src_at_even : r6.val[14]! = out.val[14]! := by + rw [h_r6_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 14 (by decide) (by decide) (by decide)] + have h_src_at_odd : r6.val[15]! = out.val[15]! := by + rw [h_r6_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 15 (by decide) (by decide) (by decide)] + have h_r7_at_even : (r7.val[14]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (7#usize : Std.Usize).val : Nat) = 14 := by decide + have h_b := h_r7_bnd_e + rw [h_eq] at h_b + rw [h_src_at_even] at h_b + have h_out_le := h_out_bnd ⟨14, by decide⟩ + simp only at h_out_le; omega + have h_r7_at_odd : (r7.val[15]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (7#usize : Std.Usize).val + 1 : Nat) = 15 := by decide + have h_b := h_r7_bnd_o + rw [h_eq] at h_b + rw [h_src_at_odd] at h_b + have h_out_le := h_out_bnd ⟨15, by decide⟩ + simp only at h_out_le; omega + have h_r7_unc' : ∀ k : Nat, k < 16 → k ≠ 14 → k ≠ 15 → + r7.val[k]! = r6.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (7#usize : Std.Usize).val : Nat) = 14 := by decide + have h_eq_o : (2 * (7#usize : Std.Usize).val + 1 : Nat) = 15 := by decide + apply h_r7_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_r7_bnd_universal : ∀ k : Fin 16, (r7.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r6 r7 7 (by decide) h_r6_bnd_universal + h_r7_unc' h_r7_at_even h_r7_at_odd + + -- Compose the monadic body. + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply + lhs rhs out zeta0 zeta1 zeta2 zeta3 = .ok r7 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply + simp only [h_wn0, h_wn1, h_wn2, h_wn3, + h_cz0, h_cnz0, h_cz1, h_cnz1, h_cz2, h_cnz2, h_cz3, h_cnz3, + h_r0_eq, h_r1_eq, h_r2_eq, h_r3_eq, + h_r4_eq, h_r5_eq, h_r6_eq, h_r7_eq, + Aeneas.Std.bind_tc_ok] + apply triple_of_ok_fc h_body + -- POST: 3-conjunct. + refine ⟨h_r7_len, ?_, ?_⟩ + · -- Relative bound: ∀ k, r7.val[k]!.natAbs ≤ out.val[k]!.natAbs + 2^25. + -- Each lane is touched at most once → bound is +2^25 above out. + -- Build via unc-chains + per-pair touched bounds. + -- Strategy: 16-way case split. + intro k + rcases k with ⟨k, hk⟩ + -- Walk back to find which call (if any) touched this lane. + interval_cases k + -- Lane 0: touched by call 0 (i=0, even). + · have h_r7_at_0 : r7.val[0]! = r0.val[0]! := by + rw [h_r7_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r7_at_0] + have h_eq : (2 * (0#usize : Std.Usize).val : Nat) = 0 := by decide + have h_b := h_r0_bnd_e + rw [h_eq] at h_b + -- Source for call 0 is `out`, so h_src_at_even = rfl, no rewrite needed. + exact h_b + -- Lane 1: touched by call 0 (i=0, odd). + · have h_r7_at_1 : r7.val[1]! = r0.val[1]! := by + rw [h_r7_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r7_at_1] + have h_eq : (2 * (0#usize : Std.Usize).val + 1 : Nat) = 1 := by decide + have h_b := h_r0_bnd_o + rw [h_eq] at h_b + exact h_b + -- Lane 2: touched by call 1 (i=1, even). Source for call 1 was r0, but r0.val[2]! = out.val[2]!. + · have h_r7_at_2 : r7.val[2]! = r1.val[2]! := by + rw [h_r7_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r7_at_2] + have h_eq : (2 * (1#usize : Std.Usize).val : Nat) = 2 := by decide + have h_b := h_r1_bnd_e + rw [h_eq] at h_b + -- h_b : r1.val[2]!.natAbs ≤ r0.val[2]!.natAbs + 2^25. + -- Need: r1.val[2]!.natAbs ≤ out.val[2]!.natAbs + 2^25. + -- r0.val[2]! = out.val[2]! (lane 2 fresh for call 0). + have h_r0_at_2 : r0.val[2]! = out.val[2]! := by + rw [h_r0_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r0_at_2] at h_b + exact h_b + -- Lane 3: touched by call 1 (i=1, odd). + · have h_r7_at_3 : r7.val[3]! = r1.val[3]! := by + rw [h_r7_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r7_at_3] + have h_eq : (2 * (1#usize : Std.Usize).val + 1 : Nat) = 3 := by decide + have h_b := h_r1_bnd_o + rw [h_eq] at h_b + have h_r0_at_3 : r0.val[3]! = out.val[3]! := by + rw [h_r0_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r0_at_3] at h_b + exact h_b + -- Lane 4: touched by call 2 (i=2, even). + · have h_r7_at_4 : r7.val[4]! = r2.val[4]! := by + rw [h_r7_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r7_at_4] + have h_eq : (2 * (2#usize : Std.Usize).val : Nat) = 4 := by decide + have h_b := h_r2_bnd_e + rw [h_eq] at h_b + have h_r1_at_4 : r1.val[4]! = out.val[4]! := by + rw [h_r1_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r1_at_4] at h_b + exact h_b + -- Lane 5: touched by call 2 (i=2, odd). + · have h_r7_at_5 : r7.val[5]! = r2.val[5]! := by + rw [h_r7_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r7_at_5] + have h_eq : (2 * (2#usize : Std.Usize).val + 1 : Nat) = 5 := by decide + have h_b := h_r2_bnd_o + rw [h_eq] at h_b + have h_r1_at_5 : r1.val[5]! = out.val[5]! := by + rw [h_r1_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r1_at_5] at h_b + exact h_b + -- Lane 6: touched by call 3 (i=3, even). + · have h_r7_at_6 : r7.val[6]! = r3.val[6]! := by + rw [h_r7_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r7_at_6] + have h_eq : (2 * (3#usize : Std.Usize).val : Nat) = 6 := by decide + have h_b := h_r3_bnd_e + rw [h_eq] at h_b + have h_r2_at_6 : r2.val[6]! = out.val[6]! := by + rw [h_r2_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r2_at_6] at h_b + exact h_b + -- Lane 7: touched by call 3 (i=3, odd). + · have h_r7_at_7 : r7.val[7]! = r3.val[7]! := by + rw [h_r7_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r7_at_7] + have h_eq : (2 * (3#usize : Std.Usize).val + 1 : Nat) = 7 := by decide + have h_b := h_r3_bnd_o + rw [h_eq] at h_b + have h_r2_at_7 : r2.val[7]! = out.val[7]! := by + rw [h_r2_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r2_at_7] at h_b + exact h_b + -- Lane 8. + · have h_r7_at_8 : r7.val[8]! = r4.val[8]! := by + rw [h_r7_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r7_at_8] + have h_eq : (2 * (4#usize : Std.Usize).val : Nat) = 8 := by decide + have h_b := h_r4_bnd_e + rw [h_eq] at h_b + have h_r3_at_8 : r3.val[8]! = out.val[8]! := by + rw [h_r3_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r3_at_8] at h_b + exact h_b + -- Lane 9. + · have h_r7_at_9 : r7.val[9]! = r4.val[9]! := by + rw [h_r7_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r7_at_9] + have h_eq : (2 * (4#usize : Std.Usize).val + 1 : Nat) = 9 := by decide + have h_b := h_r4_bnd_o + rw [h_eq] at h_b + have h_r3_at_9 : r3.val[9]! = out.val[9]! := by + rw [h_r3_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r3_at_9] at h_b + exact h_b + -- Lane 10. + · have h_r7_at_10 : r7.val[10]! = r5.val[10]! := by + rw [h_r7_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r7_at_10] + have h_eq : (2 * (5#usize : Std.Usize).val : Nat) = 10 := by decide + have h_b := h_r5_bnd_e + rw [h_eq] at h_b + have h_r4_at_10 : r4.val[10]! = out.val[10]! := by + rw [h_r4_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r4_at_10] at h_b + exact h_b + -- Lane 11. + · have h_r7_at_11 : r7.val[11]! = r5.val[11]! := by + rw [h_r7_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r7_at_11] + have h_eq : (2 * (5#usize : Std.Usize).val + 1 : Nat) = 11 := by decide + have h_b := h_r5_bnd_o + rw [h_eq] at h_b + have h_r4_at_11 : r4.val[11]! = out.val[11]! := by + rw [h_r4_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r4_at_11] at h_b + exact h_b + -- Lane 12. + · have h_r7_at_12 : r7.val[12]! = r6.val[12]! := by + rw [h_r7_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r7_at_12] + have h_eq : (2 * (6#usize : Std.Usize).val : Nat) = 12 := by decide + have h_b := h_r6_bnd_e + rw [h_eq] at h_b + have h_r5_at_12 : r5.val[12]! = out.val[12]! := by + rw [h_r5_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r5_at_12] at h_b + exact h_b + -- Lane 13. + · have h_r7_at_13 : r7.val[13]! = r6.val[13]! := by + rw [h_r7_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r7_at_13] + have h_eq : (2 * (6#usize : Std.Usize).val + 1 : Nat) = 13 := by decide + have h_b := h_r6_bnd_o + rw [h_eq] at h_b + have h_r5_at_13 : r5.val[13]! = out.val[13]! := by + rw [h_r5_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r5_at_13] at h_b + exact h_b + -- Lane 14: touched by call 7 (i=7, even). + · have h_eq : (2 * (7#usize : Std.Usize).val : Nat) = 14 := by decide + have h_b := h_r7_bnd_e + rw [h_eq] at h_b + have h_r6_at_14 : r6.val[14]! = out.val[14]! := by + rw [h_r6_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r6_at_14] at h_b + exact h_b + -- Lane 15. + · have h_eq : (2 * (7#usize : Std.Usize).val + 1 : Nat) = 15 := by decide + have h_b := h_r7_bnd_o + rw [h_eq] at h_b + have h_r6_at_15 : r6.val[15]! = out.val[15]! := by + rw [h_r6_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r6_at_15] at h_b + exact h_b + · -- ntt_multiply_base_case_post: per-lane FE equation. + unfold ntt_multiply_base_case_post ntt_multiply_base_case_alg + apply Subtype.ext + have h_lhs_val : (Spec.chunk_reducing_from_i32_array_pure r7).val + = (List.range 16).map (fun i => Spec.mont_reduce_pure (lift_fe_int (r7.val[i]!).val)) := by + unfold Spec.chunk_reducing_from_i32_array_pure; rfl + have h_rhs_val : (Spec.chunk_add_pure + (Spec.chunk_reducing_from_i32_array_pure out) + (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3))).val + = (List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Spec.chunk_reducing_from_i32_array_pure out).val[i]!) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[i]!)) := by + unfold Spec.chunk_add_pure; rfl + rw [h_lhs_val, h_rhs_val] + apply List.ext_getElem + · simp + · intro k hk1 hk2 + have hk : k < 16 := by simp at hk1; exact hk1 + rw [List.getElem_map, List.getElem_map, List.getElem_range] + interval_cases k + · -- Lane 0: touched by call 0 (zeta0, even). + have h_r7_at_lane : r7.val[0]! = r0.val[0]! := by + rw [h_r7_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_fe := h_r0_fe_e + simp only [ + show (2 * (0#usize : Std.Usize).val : Nat) = 0 from by decide] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[0]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[0]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[0]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[0]!) + ((lift_chunk_mont rhs).val[0]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[1]!) + ((lift_chunk_mont rhs).val[1]!)) + (lift_fe_mont zeta0)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_0 : (lift_chunk_mont lhs).val[0]! + = lift_fe_mont (lhs.elements.val[0]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[0]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 0 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 0 (by rw [h_l]; decide)] + have h_lcm_lhs_1 : (lift_chunk_mont lhs).val[1]! + = lift_fe_mont (lhs.elements.val[1]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[1]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 1 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 1 (by rw [h_l]; decide)] + have h_lcm_rhs_0 : (lift_chunk_mont rhs).val[0]! + = lift_fe_mont (rhs.elements.val[0]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[0]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 0 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 0 (by rw [h_l]; decide)] + have h_lcm_rhs_1 : (lift_chunk_mont rhs).val[1]! + = lift_fe_mont (rhs.elements.val[1]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[1]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 1 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 1 (by rw [h_l]; decide)] + rw [h_lcm_lhs_0, h_lcm_lhs_1, h_lcm_rhs_0, h_lcm_rhs_1] + · -- Lane 1: touched by call 0 (zeta0, odd). + have h_r7_at_lane : r7.val[1]! = r0.val[1]! := by + rw [h_r7_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_fe := h_r0_fe_o + simp only [ + show (2 * (0#usize : Std.Usize).val : Nat) = 0 from by decide] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[1]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[1]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[1]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[0]!) + ((lift_chunk_mont rhs).val[1]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[1]!) + ((lift_chunk_mont rhs).val[0]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_0 : (lift_chunk_mont lhs).val[0]! + = lift_fe_mont (lhs.elements.val[0]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[0]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 0 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 0 (by rw [h_l]; decide)] + have h_lcm_lhs_1 : (lift_chunk_mont lhs).val[1]! + = lift_fe_mont (lhs.elements.val[1]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[1]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 1 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 1 (by rw [h_l]; decide)] + have h_lcm_rhs_0 : (lift_chunk_mont rhs).val[0]! + = lift_fe_mont (rhs.elements.val[0]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[0]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 0 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 0 (by rw [h_l]; decide)] + have h_lcm_rhs_1 : (lift_chunk_mont rhs).val[1]! + = lift_fe_mont (rhs.elements.val[1]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[1]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 1 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 1 (by rw [h_l]; decide)] + rw [h_lcm_lhs_0, h_lcm_lhs_1, h_lcm_rhs_0, h_lcm_rhs_1] + · -- Lane 2: touched by call 1 (nzeta0, even). + have h_r7_at_lane : r7.val[2]! = r1.val[2]! := by + rw [h_r7_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r0.val[2]! = out.val[2]! := by + rw [h_r0_unc' 2 (by decide) (by decide) (by decide)] + have h_src_at_odd : r0.val[3]! = out.val[3]! := by + rw [h_r0_unc' 3 (by decide) (by decide) (by decide)] + have h_fe := h_r1_fe_e + simp only [ + show (2 * (1#usize : Std.Usize).val : Nat) = 2 from by decide] at h_fe + rw [h_src_at_even] at h_fe + rw [h_n0_fe] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[2]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[2]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[2]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[2]!) + ((lift_chunk_mont rhs).val[2]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[3]!) + ((lift_chunk_mont rhs).val[3]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta0))) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_2 : (lift_chunk_mont lhs).val[2]! + = lift_fe_mont (lhs.elements.val[2]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[2]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 2 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 2 (by rw [h_l]; decide)] + have h_lcm_lhs_3 : (lift_chunk_mont lhs).val[3]! + = lift_fe_mont (lhs.elements.val[3]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[3]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 3 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 3 (by rw [h_l]; decide)] + have h_lcm_rhs_2 : (lift_chunk_mont rhs).val[2]! + = lift_fe_mont (rhs.elements.val[2]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[2]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 2 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 2 (by rw [h_l]; decide)] + have h_lcm_rhs_3 : (lift_chunk_mont rhs).val[3]! + = lift_fe_mont (rhs.elements.val[3]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[3]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 3 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 3 (by rw [h_l]; decide)] + rw [h_lcm_lhs_2, h_lcm_lhs_3, h_lcm_rhs_2, h_lcm_rhs_3] + · -- Lane 3: touched by call 1 (nzeta0, odd). + have h_r7_at_lane : r7.val[3]! = r1.val[3]! := by + rw [h_r7_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r0.val[2]! = out.val[2]! := by + rw [h_r0_unc' 2 (by decide) (by decide) (by decide)] + have h_src_at_odd : r0.val[3]! = out.val[3]! := by + rw [h_r0_unc' 3 (by decide) (by decide) (by decide)] + have h_fe := h_r1_fe_o + simp only [ + show (2 * (1#usize : Std.Usize).val : Nat) = 2 from by decide] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[3]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[3]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[3]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[2]!) + ((lift_chunk_mont rhs).val[3]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[3]!) + ((lift_chunk_mont rhs).val[2]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_2 : (lift_chunk_mont lhs).val[2]! + = lift_fe_mont (lhs.elements.val[2]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[2]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 2 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 2 (by rw [h_l]; decide)] + have h_lcm_lhs_3 : (lift_chunk_mont lhs).val[3]! + = lift_fe_mont (lhs.elements.val[3]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[3]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 3 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 3 (by rw [h_l]; decide)] + have h_lcm_rhs_2 : (lift_chunk_mont rhs).val[2]! + = lift_fe_mont (rhs.elements.val[2]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[2]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 2 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 2 (by rw [h_l]; decide)] + have h_lcm_rhs_3 : (lift_chunk_mont rhs).val[3]! + = lift_fe_mont (rhs.elements.val[3]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[3]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 3 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 3 (by rw [h_l]; decide)] + rw [h_lcm_lhs_2, h_lcm_lhs_3, h_lcm_rhs_2, h_lcm_rhs_3] + · -- Lane 4: touched by call 2 (zeta1, even). + have h_r7_at_lane : r7.val[4]! = r2.val[4]! := by + rw [h_r7_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r1.val[4]! = out.val[4]! := by + rw [h_r1_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 4 (by decide) (by decide) (by decide)] + have h_src_at_odd : r1.val[5]! = out.val[5]! := by + rw [h_r1_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 5 (by decide) (by decide) (by decide)] + have h_fe := h_r2_fe_e + simp only [ + show (2 * (2#usize : Std.Usize).val : Nat) = 4 from by decide] at h_fe + rw [h_src_at_even] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[4]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[4]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[4]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[4]!) + ((lift_chunk_mont rhs).val[4]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[5]!) + ((lift_chunk_mont rhs).val[5]!)) + (lift_fe_mont zeta1)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_4 : (lift_chunk_mont lhs).val[4]! + = lift_fe_mont (lhs.elements.val[4]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[4]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 4 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 4 (by rw [h_l]; decide)] + have h_lcm_lhs_5 : (lift_chunk_mont lhs).val[5]! + = lift_fe_mont (lhs.elements.val[5]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[5]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 5 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 5 (by rw [h_l]; decide)] + have h_lcm_rhs_4 : (lift_chunk_mont rhs).val[4]! + = lift_fe_mont (rhs.elements.val[4]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[4]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 4 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 4 (by rw [h_l]; decide)] + have h_lcm_rhs_5 : (lift_chunk_mont rhs).val[5]! + = lift_fe_mont (rhs.elements.val[5]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[5]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 5 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 5 (by rw [h_l]; decide)] + rw [h_lcm_lhs_4, h_lcm_lhs_5, h_lcm_rhs_4, h_lcm_rhs_5] + · -- Lane 5: touched by call 2 (zeta1, odd). + have h_r7_at_lane : r7.val[5]! = r2.val[5]! := by + rw [h_r7_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r1.val[4]! = out.val[4]! := by + rw [h_r1_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 4 (by decide) (by decide) (by decide)] + have h_src_at_odd : r1.val[5]! = out.val[5]! := by + rw [h_r1_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 5 (by decide) (by decide) (by decide)] + have h_fe := h_r2_fe_o + simp only [ + show (2 * (2#usize : Std.Usize).val : Nat) = 4 from by decide] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[5]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[5]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[5]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[4]!) + ((lift_chunk_mont rhs).val[5]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[5]!) + ((lift_chunk_mont rhs).val[4]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_4 : (lift_chunk_mont lhs).val[4]! + = lift_fe_mont (lhs.elements.val[4]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[4]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 4 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 4 (by rw [h_l]; decide)] + have h_lcm_lhs_5 : (lift_chunk_mont lhs).val[5]! + = lift_fe_mont (lhs.elements.val[5]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[5]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 5 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 5 (by rw [h_l]; decide)] + have h_lcm_rhs_4 : (lift_chunk_mont rhs).val[4]! + = lift_fe_mont (rhs.elements.val[4]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[4]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 4 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 4 (by rw [h_l]; decide)] + have h_lcm_rhs_5 : (lift_chunk_mont rhs).val[5]! + = lift_fe_mont (rhs.elements.val[5]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[5]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 5 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 5 (by rw [h_l]; decide)] + rw [h_lcm_lhs_4, h_lcm_lhs_5, h_lcm_rhs_4, h_lcm_rhs_5] + · -- Lane 6: touched by call 3 (nzeta1, even). + have h_r7_at_lane : r7.val[6]! = r3.val[6]! := by + rw [h_r7_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r2.val[6]! = out.val[6]! := by + rw [h_r2_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 6 (by decide) (by decide) (by decide)] + have h_src_at_odd : r2.val[7]! = out.val[7]! := by + rw [h_r2_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 7 (by decide) (by decide) (by decide)] + have h_fe := h_r3_fe_e + simp only [ + show (2 * (3#usize : Std.Usize).val : Nat) = 6 from by decide] at h_fe + rw [h_src_at_even] at h_fe + rw [h_n1_fe] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[6]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[6]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[6]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[6]!) + ((lift_chunk_mont rhs).val[6]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[7]!) + ((lift_chunk_mont rhs).val[7]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta1))) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_6 : (lift_chunk_mont lhs).val[6]! + = lift_fe_mont (lhs.elements.val[6]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[6]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 6 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 6 (by rw [h_l]; decide)] + have h_lcm_lhs_7 : (lift_chunk_mont lhs).val[7]! + = lift_fe_mont (lhs.elements.val[7]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[7]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 7 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 7 (by rw [h_l]; decide)] + have h_lcm_rhs_6 : (lift_chunk_mont rhs).val[6]! + = lift_fe_mont (rhs.elements.val[6]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[6]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 6 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 6 (by rw [h_l]; decide)] + have h_lcm_rhs_7 : (lift_chunk_mont rhs).val[7]! + = lift_fe_mont (rhs.elements.val[7]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[7]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 7 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 7 (by rw [h_l]; decide)] + rw [h_lcm_lhs_6, h_lcm_lhs_7, h_lcm_rhs_6, h_lcm_rhs_7] + · -- Lane 7: touched by call 3 (nzeta1, odd). + have h_r7_at_lane : r7.val[7]! = r3.val[7]! := by + rw [h_r7_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r2.val[6]! = out.val[6]! := by + rw [h_r2_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 6 (by decide) (by decide) (by decide)] + have h_src_at_odd : r2.val[7]! = out.val[7]! := by + rw [h_r2_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 7 (by decide) (by decide) (by decide)] + have h_fe := h_r3_fe_o + simp only [ + show (2 * (3#usize : Std.Usize).val : Nat) = 6 from by decide] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[7]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[7]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[7]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[6]!) + ((lift_chunk_mont rhs).val[7]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[7]!) + ((lift_chunk_mont rhs).val[6]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_6 : (lift_chunk_mont lhs).val[6]! + = lift_fe_mont (lhs.elements.val[6]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[6]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 6 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 6 (by rw [h_l]; decide)] + have h_lcm_lhs_7 : (lift_chunk_mont lhs).val[7]! + = lift_fe_mont (lhs.elements.val[7]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[7]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 7 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 7 (by rw [h_l]; decide)] + have h_lcm_rhs_6 : (lift_chunk_mont rhs).val[6]! + = lift_fe_mont (rhs.elements.val[6]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[6]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 6 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 6 (by rw [h_l]; decide)] + have h_lcm_rhs_7 : (lift_chunk_mont rhs).val[7]! + = lift_fe_mont (rhs.elements.val[7]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[7]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 7 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 7 (by rw [h_l]; decide)] + rw [h_lcm_lhs_6, h_lcm_lhs_7, h_lcm_rhs_6, h_lcm_rhs_7] + · -- Lane 8: touched by call 4 (zeta2, even). + have h_r7_at_lane : r7.val[8]! = r4.val[8]! := by + rw [h_r7_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r3.val[8]! = out.val[8]! := by + rw [h_r3_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 8 (by decide) (by decide) (by decide)] + have h_src_at_odd : r3.val[9]! = out.val[9]! := by + rw [h_r3_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 9 (by decide) (by decide) (by decide)] + have h_fe := h_r4_fe_e + simp only [ + show (2 * (4#usize : Std.Usize).val : Nat) = 8 from by decide] at h_fe + rw [h_src_at_even] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[8]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[8]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[8]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[8]!) + ((lift_chunk_mont rhs).val[8]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[9]!) + ((lift_chunk_mont rhs).val[9]!)) + (lift_fe_mont zeta2)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_8 : (lift_chunk_mont lhs).val[8]! + = lift_fe_mont (lhs.elements.val[8]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[8]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 8 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 8 (by rw [h_l]; decide)] + have h_lcm_lhs_9 : (lift_chunk_mont lhs).val[9]! + = lift_fe_mont (lhs.elements.val[9]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[9]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 9 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 9 (by rw [h_l]; decide)] + have h_lcm_rhs_8 : (lift_chunk_mont rhs).val[8]! + = lift_fe_mont (rhs.elements.val[8]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[8]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 8 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 8 (by rw [h_l]; decide)] + have h_lcm_rhs_9 : (lift_chunk_mont rhs).val[9]! + = lift_fe_mont (rhs.elements.val[9]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[9]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 9 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 9 (by rw [h_l]; decide)] + rw [h_lcm_lhs_8, h_lcm_lhs_9, h_lcm_rhs_8, h_lcm_rhs_9] + · -- Lane 9: touched by call 4 (zeta2, odd). + have h_r7_at_lane : r7.val[9]! = r4.val[9]! := by + rw [h_r7_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r3.val[8]! = out.val[8]! := by + rw [h_r3_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 8 (by decide) (by decide) (by decide)] + have h_src_at_odd : r3.val[9]! = out.val[9]! := by + rw [h_r3_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 9 (by decide) (by decide) (by decide)] + have h_fe := h_r4_fe_o + simp only [ + show (2 * (4#usize : Std.Usize).val : Nat) = 8 from by decide] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[9]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[9]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[9]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[8]!) + ((lift_chunk_mont rhs).val[9]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[9]!) + ((lift_chunk_mont rhs).val[8]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_8 : (lift_chunk_mont lhs).val[8]! + = lift_fe_mont (lhs.elements.val[8]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[8]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 8 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 8 (by rw [h_l]; decide)] + have h_lcm_lhs_9 : (lift_chunk_mont lhs).val[9]! + = lift_fe_mont (lhs.elements.val[9]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[9]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 9 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 9 (by rw [h_l]; decide)] + have h_lcm_rhs_8 : (lift_chunk_mont rhs).val[8]! + = lift_fe_mont (rhs.elements.val[8]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[8]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 8 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 8 (by rw [h_l]; decide)] + have h_lcm_rhs_9 : (lift_chunk_mont rhs).val[9]! + = lift_fe_mont (rhs.elements.val[9]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[9]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 9 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 9 (by rw [h_l]; decide)] + rw [h_lcm_lhs_8, h_lcm_lhs_9, h_lcm_rhs_8, h_lcm_rhs_9] + · -- Lane 10: touched by call 5 (nzeta2, even). + have h_r7_at_lane : r7.val[10]! = r5.val[10]! := by + rw [h_r7_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r4.val[10]! = out.val[10]! := by + rw [h_r4_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 10 (by decide) (by decide) (by decide)] + have h_src_at_odd : r4.val[11]! = out.val[11]! := by + rw [h_r4_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 11 (by decide) (by decide) (by decide)] + have h_fe := h_r5_fe_e + simp only [ + show (2 * (5#usize : Std.Usize).val : Nat) = 10 from by decide] at h_fe + rw [h_src_at_even] at h_fe + rw [h_n2_fe] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[10]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[10]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[10]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[10]!) + ((lift_chunk_mont rhs).val[10]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[11]!) + ((lift_chunk_mont rhs).val[11]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta2))) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_10 : (lift_chunk_mont lhs).val[10]! + = lift_fe_mont (lhs.elements.val[10]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[10]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 10 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 10 (by rw [h_l]; decide)] + have h_lcm_lhs_11 : (lift_chunk_mont lhs).val[11]! + = lift_fe_mont (lhs.elements.val[11]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[11]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 11 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 11 (by rw [h_l]; decide)] + have h_lcm_rhs_10 : (lift_chunk_mont rhs).val[10]! + = lift_fe_mont (rhs.elements.val[10]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[10]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 10 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 10 (by rw [h_l]; decide)] + have h_lcm_rhs_11 : (lift_chunk_mont rhs).val[11]! + = lift_fe_mont (rhs.elements.val[11]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[11]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 11 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 11 (by rw [h_l]; decide)] + rw [h_lcm_lhs_10, h_lcm_lhs_11, h_lcm_rhs_10, h_lcm_rhs_11] + · -- Lane 11: touched by call 5 (nzeta2, odd). + have h_r7_at_lane : r7.val[11]! = r5.val[11]! := by + rw [h_r7_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r4.val[10]! = out.val[10]! := by + rw [h_r4_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 10 (by decide) (by decide) (by decide)] + have h_src_at_odd : r4.val[11]! = out.val[11]! := by + rw [h_r4_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 11 (by decide) (by decide) (by decide)] + have h_fe := h_r5_fe_o + simp only [ + show (2 * (5#usize : Std.Usize).val : Nat) = 10 from by decide] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[11]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[11]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[11]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[10]!) + ((lift_chunk_mont rhs).val[11]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[11]!) + ((lift_chunk_mont rhs).val[10]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_10 : (lift_chunk_mont lhs).val[10]! + = lift_fe_mont (lhs.elements.val[10]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[10]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 10 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 10 (by rw [h_l]; decide)] + have h_lcm_lhs_11 : (lift_chunk_mont lhs).val[11]! + = lift_fe_mont (lhs.elements.val[11]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[11]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 11 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 11 (by rw [h_l]; decide)] + have h_lcm_rhs_10 : (lift_chunk_mont rhs).val[10]! + = lift_fe_mont (rhs.elements.val[10]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[10]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 10 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 10 (by rw [h_l]; decide)] + have h_lcm_rhs_11 : (lift_chunk_mont rhs).val[11]! + = lift_fe_mont (rhs.elements.val[11]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[11]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 11 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 11 (by rw [h_l]; decide)] + rw [h_lcm_lhs_10, h_lcm_lhs_11, h_lcm_rhs_10, h_lcm_rhs_11] + · -- Lane 12: touched by call 6 (zeta3, even). + have h_r7_at_lane : r7.val[12]! = r6.val[12]! := by + rw [h_r7_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r5.val[12]! = out.val[12]! := by + rw [h_r5_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 12 (by decide) (by decide) (by decide)] + have h_src_at_odd : r5.val[13]! = out.val[13]! := by + rw [h_r5_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 13 (by decide) (by decide) (by decide)] + have h_fe := h_r6_fe_e + simp only [ + show (2 * (6#usize : Std.Usize).val : Nat) = 12 from by decide] at h_fe + rw [h_src_at_even] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[12]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[12]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[12]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[12]!) + ((lift_chunk_mont rhs).val[12]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[13]!) + ((lift_chunk_mont rhs).val[13]!)) + (lift_fe_mont zeta3)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_12 : (lift_chunk_mont lhs).val[12]! + = lift_fe_mont (lhs.elements.val[12]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[12]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 12 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 12 (by rw [h_l]; decide)] + have h_lcm_lhs_13 : (lift_chunk_mont lhs).val[13]! + = lift_fe_mont (lhs.elements.val[13]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[13]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 13 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 13 (by rw [h_l]; decide)] + have h_lcm_rhs_12 : (lift_chunk_mont rhs).val[12]! + = lift_fe_mont (rhs.elements.val[12]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[12]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 12 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 12 (by rw [h_l]; decide)] + have h_lcm_rhs_13 : (lift_chunk_mont rhs).val[13]! + = lift_fe_mont (rhs.elements.val[13]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[13]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 13 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 13 (by rw [h_l]; decide)] + rw [h_lcm_lhs_12, h_lcm_lhs_13, h_lcm_rhs_12, h_lcm_rhs_13] + · -- Lane 13: touched by call 6 (zeta3, odd). + have h_r7_at_lane : r7.val[13]! = r6.val[13]! := by + rw [h_r7_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r5.val[12]! = out.val[12]! := by + rw [h_r5_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 12 (by decide) (by decide) (by decide)] + have h_src_at_odd : r5.val[13]! = out.val[13]! := by + rw [h_r5_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 13 (by decide) (by decide) (by decide)] + have h_fe := h_r6_fe_o + simp only [ + show (2 * (6#usize : Std.Usize).val : Nat) = 12 from by decide] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[13]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[13]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[13]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[12]!) + ((lift_chunk_mont rhs).val[13]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[13]!) + ((lift_chunk_mont rhs).val[12]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_12 : (lift_chunk_mont lhs).val[12]! + = lift_fe_mont (lhs.elements.val[12]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[12]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 12 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 12 (by rw [h_l]; decide)] + have h_lcm_lhs_13 : (lift_chunk_mont lhs).val[13]! + = lift_fe_mont (lhs.elements.val[13]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[13]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 13 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 13 (by rw [h_l]; decide)] + have h_lcm_rhs_12 : (lift_chunk_mont rhs).val[12]! + = lift_fe_mont (rhs.elements.val[12]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[12]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 12 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 12 (by rw [h_l]; decide)] + have h_lcm_rhs_13 : (lift_chunk_mont rhs).val[13]! + = lift_fe_mont (rhs.elements.val[13]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[13]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 13 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 13 (by rw [h_l]; decide)] + rw [h_lcm_lhs_12, h_lcm_lhs_13, h_lcm_rhs_12, h_lcm_rhs_13] + · -- Lane 14: touched by call 7 (nzeta3, even). + have h_src_at_even : r6.val[14]! = out.val[14]! := by + rw [h_r6_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 14 (by decide) (by decide) (by decide)] + have h_src_at_odd : r6.val[15]! = out.val[15]! := by + rw [h_r6_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 15 (by decide) (by decide) (by decide)] + have h_fe := h_r7_fe_e + simp only [ + show (2 * (7#usize : Std.Usize).val : Nat) = 14 from by decide] at h_fe + rw [h_src_at_even] at h_fe + rw [h_n3_fe] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[14]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[14]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[14]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[14]!) + ((lift_chunk_mont rhs).val[14]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[15]!) + ((lift_chunk_mont rhs).val[15]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta3))) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_14 : (lift_chunk_mont lhs).val[14]! + = lift_fe_mont (lhs.elements.val[14]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[14]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 14 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 14 (by rw [h_l]; decide)] + have h_lcm_lhs_15 : (lift_chunk_mont lhs).val[15]! + = lift_fe_mont (lhs.elements.val[15]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[15]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 15 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 15 (by rw [h_l]; decide)] + have h_lcm_rhs_14 : (lift_chunk_mont rhs).val[14]! + = lift_fe_mont (rhs.elements.val[14]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[14]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 14 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 14 (by rw [h_l]; decide)] + have h_lcm_rhs_15 : (lift_chunk_mont rhs).val[15]! + = lift_fe_mont (rhs.elements.val[15]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[15]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 15 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 15 (by rw [h_l]; decide)] + rw [h_lcm_lhs_14, h_lcm_lhs_15, h_lcm_rhs_14, h_lcm_rhs_15] + · -- Lane 15: touched by call 7 (nzeta3, odd). + have h_src_at_even : r6.val[14]! = out.val[14]! := by + rw [h_r6_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 14 (by decide) (by decide) (by decide)] + have h_src_at_odd : r6.val[15]! = out.val[15]! := by + rw [h_r6_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 15 (by decide) (by decide) (by decide)] + have h_fe := h_r7_fe_o + simp only [ + show (2 * (7#usize : Std.Usize).val : Nat) = 14 from by decide] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[15]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[15]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[15]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[14]!) + ((lift_chunk_mont rhs).val[15]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[15]!) + ((lift_chunk_mont rhs).val[14]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_14 : (lift_chunk_mont lhs).val[14]! + = lift_fe_mont (lhs.elements.val[14]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[14]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 14 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 14 (by rw [h_l]; decide)] + have h_lcm_lhs_15 : (lift_chunk_mont lhs).val[15]! + = lift_fe_mont (lhs.elements.val[15]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[15]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 15 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 15 (by rw [h_l]; decide)] + have h_lcm_rhs_14 : (lift_chunk_mont rhs).val[14]! + = lift_fe_mont (rhs.elements.val[14]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[14]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 14 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 14 (by rw [h_l]; decide)] + have h_lcm_rhs_15 : (lift_chunk_mont rhs).val[15]! + = lift_fe_mont (rhs.elements.val[15]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[15]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 15 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 15 (by rw [h_l]; decide)] + rw [h_lcm_lhs_14, h_lcm_lhs_15, h_lcm_rhs_14, h_lcm_rhs_15] + + +/-! ## §L2.8d — Cache-variant Triple statements (fill_cache + use_cache). + + The impl provides two siblings of `accumulating_ntt_multiply` that + factor out the per-pair Mont-reduced `b·zeta` products into a + 16-lane cache vector: + • `_fill_cache`: behaves identically to `accumulating_ntt_multiply` + on the accumulator slice AND writes `mont_reduce(b[2i+1]·zeta_i)` + into `cache[i]` for each pair i ∈ Fin 8 (cache slots 8..15 + untouched). + • `_use_cache`: skips the per-pair Mont reduction by reading the + cached I16 directly. Requires a cache pre-condition asserting + each cache slot equals the Mont-reduced `b·zeta` product for + the corresponding effective zeta. + + Composition pattern (matrix-row reuse): `_fill_cache(A, B, _, _, zetas)` + sets the cache, then multiple `_use_cache(A', B, _, cache)` calls reuse + it with different first operands and the same `B`/zeta structure. -/ + +/-- Effective per-pair zeta for the 8 binomial calls in a chunk: pair + `2j` uses `zetaJ`, pair `2j+1` uses `neg_pure zetaJ` (the bit-side + `wrapping_neg` projected through `lift_fe_mont`). Used to express + the cache POST predicate at the FE-projection level. -/ +noncomputable def Spec.effective_zeta_fe + (i : Fin 8) + (z0 z1 z2 z3 : hacspec_ml_kem.parameters.FieldElement) : + hacspec_ml_kem.parameters.FieldElement := + if i.val = 0 then z0 + else if i.val = 1 then libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure z0 + else if i.val = 2 then z1 + else if i.val = 3 then libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure z1 + else if i.val = 4 then z2 + else if i.val = 5 then libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure z2 + else if i.val = 6 then z3 + else libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure z3 + +/-- Cache POST predicate shared between `_fill_cache` (as output cache + POST) and `_use_cache` (as input cache PRE). For each pair `i ∈ Fin 8`: + • `cache[i]` is canonical (`natAbs ≤ 3328`) — Mont reduction always + produces values in this range; and + • `lift_fe_mont cache[i] = mul_pure (lift_fe_mont rhs[2i+1]) + (effective_zeta_fe i z0 z1 z2 z3)` + — i.e., the cache slot at pair `i` represents the FE product + of `rhs`'s odd-lane operand and the pair's effective zeta. -/ +noncomputable def Spec.ntt_multiply_cache_post + (rhs : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta0 zeta1 zeta2 zeta3 : Std.I16) + (cache : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Prop := + ∀ i : Fin 8, + (cache.elements.val[i.val]!).val.natAbs ≤ 3328 + ∧ lift_fe_mont (cache.elements.val[i.val]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[2 * i.val + 1]!)) + (Spec.effective_zeta_fe i + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)) + +/-! ### L2.8d — helper lemmas (sibling-adapt of L2.8c). + + Three new helpers extend the L2_8c.* infrastructure for the + cache-variant Triples: + • `L2_8d.lift_fe_mont_of_mont_reduce_modq` — translates the modq + relation from `montgomery_reduce_element_spec` (`r * 2^16 ≡ x * y`) + to a `lift_fe_mont r = mul_pure (lift_fe_mont x) (lift_fe_mont y)` + equation (the fill-cache POST cache equation). + • `L2_8d.mont_reduce_even_fe_eq_cache` and + `L2_8d.mont_reduce_odd_fe_eq_cache` — Mont-domain FE equation + builders for the use_cache per-pair Triple. Differ from + `L2_8c.mont_reduce_{even,odd}_fe_eq` by carrying a symbolic + cache-lane I16 in the RHS instead of an explicit `bj * zeta` + product. -/ + +/-- Mont-reduced product → FE-projection bridge. If `r` is the + Montgomery reduction of `x.val * y.val` (so `r.val * 2^16 ≡ + x.val * y.val (mod q)`), then `lift_fe_mont r = mul_pure + (lift_fe_mont x) (lift_fe_mont y)`. Used by `_fill_cache` per-pair + Triple to discharge the cache POST equation. -/ +theorem L2_8d.lift_fe_mont_of_mont_reduce_modq + (r x y : Std.I16) + (h_canon : r.val.natAbs ≤ 3328) + (h_zmod : ((r.val : Int) : ZMod 3329) * (2^16 : Int) + = ((x.val : Int) : ZMod 3329) * ((y.val : Int) : ZMod 3329)) : + lift_fe_mont r + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont x) (lift_fe_mont y) := by + -- LHS: feOfZMod ((r.val : ZMod q) * 169). + -- Set up s := mul_pure (lift_fe_mont x) (lift_fe_mont y); s is canonical, so + -- the goal collapses (after round-trip) to a ZMod q equation. + set s : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont x) (lift_fe_mont y) with hs_def + -- Express s.val.val < 3329 via Canonical_mul_pure. + have h_canon_s : s.val.val < 3329 := by + have h_cm := libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure + (lift_fe_mont x) (lift_fe_mont y) + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cm + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at h_cm + exact h_cm + have h_round_trip : feOfZMod (zmodOfFE s) = s := + feOfZMod_zmodOfFE_of_canonical s h_canon_s + -- zmodOfFE s = (x.val : ZMod q) * 169 * ((y.val : ZMod q) * 169). + have h_zmod_s : zmodOfFE s + = ((x.val : Int) : ZMod 3329) * 169 * (((y.val : Int) : ZMod 3329) * 169) := by + rw [hs_def, L2_8c.zmodOfFE_mul_pure, L2_8c.zmodOfFE_lift_fe_mont, + L2_8c.zmodOfFE_lift_fe_mont] + -- Now reduce LHS to feOfZMod ((r.val : ZMod q) * 169). + have h_lhs : lift_fe_mont r = feOfZMod (((r.val : Int) : ZMod 3329) * 169) := by + unfold lift_fe_mont i16_to_spec_fe_mont; rfl + rw [h_lhs, ← h_round_trip, h_zmod_s] + congr 1 + -- Goal: (r.val : Int : ZMod q) * 169 = (x.val : ZMod q) * 169 * (y.val : ZMod q) * 169. + -- From h_zmod: r * 2^16 = x * y in ZMod q. + -- Multiply by 169² on both sides: r * (2^16 * 169) * 169 = x * y * 169 * 169. + -- 2^16 * 169 ≡ 1 (since 2^16 ≡ 2285 and 2285 * 169 = 1 in ZMod q). + have h_inv : ((2285 : ZMod 3329)) * 169 = 1 := by decide + have h_2_16 : ((2^16 : Int) : ZMod 3329) = 2285 := by decide + -- Multiply h_zmod by 169 on both sides. + have h_mul_169 : + ((r.val : Int) : ZMod 3329) * (2^16 : Int) * 169 + = ((x.val : Int) : ZMod 3329) * ((y.val : Int) : ZMod 3329) * 169 := by + rw [h_zmod] + rw [h_2_16] at h_mul_169 + -- LHS rewrite: r * 169 = r * 2285 * 169 * 169 (since 2285 * 169 = 1). + have h_lhs : + ((r.val : Int) : ZMod 3329) * 169 + = ((r.val : Int) : ZMod 3329) * 2285 * 169 * 169 := by + have : ((r.val : Int) : ZMod 3329) * 169 = ((r.val : Int) : ZMod 3329) * (2285 * 169) * 169 := by + rw [h_inv]; ring + rw [this]; ring + rw [h_lhs] + -- Now: r * 2285 * 169 * 169 = (x * y) * 169 * 169 from h_mul_169 (multiplied by 169 again). + have h_mul_169_squared : + ((r.val : Int) : ZMod 3329) * 2285 * 169 * 169 + = ((x.val : Int) : ZMod 3329) * ((y.val : Int) : ZMod 3329) * 169 * 169 := by + have step : ((r.val : Int) : ZMod 3329) * 2285 * 169 * 169 + = (((r.val : Int) : ZMod 3329) * 2285 * 169) * 169 := by ring + rw [step, h_mul_169] + rw [h_mul_169_squared] + ring + +/-- Associativity of `Spec.Pure.FieldElement.mul_pure` (Mont-domain product + in ZMod q). Used to reshape use_cache per-pair FE equations from + `mul a (mul b c)` form (cache lane = mul rhs zeta) to `mul (mul a b) c` + form to match the L2.8c per-pair FE shape. -/ +theorem L2_8d.mul_pure_assoc + (a b c : hacspec_ml_kem.parameters.FieldElement) : + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure b c) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a b) c := by + -- Both sides are canonical; use round-trip + ZMod commutativity. + set lhs : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure b c) with hlhs + set rhs : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a b) c with hrhs + have h_canon_lhs : lhs.val.val < 3329 := by + have h_cm := libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure a + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure b c) + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cm + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at h_cm; exact h_cm + have h_canon_rhs : rhs.val.val < 3329 := by + have h_cm := libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a b) c + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cm + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at h_cm; exact h_cm + have h_rt_lhs : feOfZMod (zmodOfFE lhs) = lhs := + feOfZMod_zmodOfFE_of_canonical lhs h_canon_lhs + have h_rt_rhs : feOfZMod (zmodOfFE rhs) = rhs := + feOfZMod_zmodOfFE_of_canonical rhs h_canon_rhs + have h_zmod_eq : zmodOfFE lhs = zmodOfFE rhs := by + rw [hlhs, hrhs] + rw [L2_8c.zmodOfFE_mul_pure, L2_8c.zmodOfFE_mul_pure, + L2_8c.zmodOfFE_mul_pure, L2_8c.zmodOfFE_mul_pure] + ring + rw [← h_rt_lhs, ← h_rt_rhs, h_zmod_eq] + +set_option maxHeartbeats 400000 in +/-- Use-cache variant of `L2_8c.mont_reduce_even_fe_eq`: the per-pair + cache lane appears symbolically (as an I16 `c`) in the RHS via + `lift_fe_mont c`, in place of the `bj * zeta` product. Same Mont- + inversion algebra (`2285 * 169 ≡ 1 (mod q)`). -/ +theorem L2_8d.mont_reduce_even_fe_eq_cache + (out r : Std.I32) (ai bi aj c : Std.I16) + (h_zmod : ((r.val * (2 ^ 16 : Int)) : ZMod 3329) + = ((out.val * (2 ^ 16 : Int) + ai.val * bi.val * (2 ^ 16 : Int) + + aj.val * c.val * (2 ^ 16 : Int)) : ZMod 3329)) : + Spec.mont_reduce_pure (lift_fe_int r.val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int out.val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont ai) (lift_fe_mont bi)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont aj) (lift_fe_mont c))) := by + rw [mont_reduce_pure_lift_fe_int] + set s : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int out.val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont ai) (lift_fe_mont bi)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont aj) (lift_fe_mont c))) with hs_def + have h_canon : s.val.val < 3329 := by + have h_cs := libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure + (Spec.mont_reduce_pure (lift_fe_int out.val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont ai) (lift_fe_mont bi)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont aj) (lift_fe_mont c))) + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cs + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at h_cs + exact h_cs + have h_round_trip : feOfZMod (zmodOfFE s) = s := + feOfZMod_zmodOfFE_of_canonical s h_canon + have h_zmod_s : zmodOfFE s + = (out.val : ZMod 3329) * 169 * 169 + + ((ai.val : ZMod 3329) * 169 * ((bi.val : ZMod 3329) * 169) + + (aj.val : ZMod 3329) * 169 * ((c.val : ZMod 3329) * 169)) := by + simp only [hs_def, + L2_8c.zmodOfFE_add_pure, + L2_8c.zmodOfFE_mont_reduce_lift_fe_int, + L2_8c.zmodOfFE_mul_pure, + L2_8c.zmodOfFE_lift_fe_mont] + rw [← h_round_trip, h_zmod_s] + congr 1 + have h_inv : ((2285 : ZMod 3329)) * 169 = 1 := by decide + have h_mul_169_cubed : + (r.val : ZMod 3329) * (2^16 : Int) * 169 * 169 * 169 + = ((out.val : ZMod 3329) * (2^16 : Int) + + (ai.val : ZMod 3329) * (bi.val : ZMod 3329) * (2^16 : Int) + + (aj.val : ZMod 3329) * (c.val : ZMod 3329) * (2^16 : Int)) * 169 * 169 * 169 := by + have := h_zmod + push_cast at this ⊢ + rw [this] + have h_2_16 : ((2^16 : Int) : ZMod 3329) = 2285 := by decide + rw [h_2_16] at h_mul_169_cubed + have h_lhs : + (r.val : ZMod 3329) * 169 * 169 + = (r.val : ZMod 3329) * 2285 * 169 * 169 * 169 := by + have : (r.val : ZMod 3329) * 169 * 169 = (r.val : ZMod 3329) * (2285 * 169) * 169 * 169 := by + rw [h_inv]; ring + rw [this]; ring + rw [h_lhs, h_mul_169_cubed] + have h_expand : ((out.val : ZMod 3329) * 2285 + + (ai.val : ZMod 3329) * (bi.val : ZMod 3329) * 2285 + + (aj.val : ZMod 3329) * (c.val : ZMod 3329) * 2285) + * 169 * 169 * 169 + = (out.val : ZMod 3329) * (2285 * (169 * 169 * 169)) + + (ai.val : ZMod 3329) * (bi.val : ZMod 3329) * (2285 * (169 * 169 * 169)) + + (aj.val : ZMod 3329) * (c.val : ZMod 3329) * (2285 * (169 * 169 * 169)) := by + ring + have h_collapse : ((2285 : ZMod 3329)) * (169 * 169 * 169) = 169 * 169 := by decide + rw [h_expand, h_collapse] + ring + +set_option maxHeartbeats 400000 in +/-- Odd-half analog of `L2_8d.mont_reduce_even_fe_eq_cache`. -/ +theorem L2_8d.mont_reduce_odd_fe_eq_cache + (out r : Std.I32) (ai bi aj bj : Std.I16) + (h_zmod : ((r.val * (2 ^ 16 : Int)) : ZMod 3329) + = ((out.val * (2 ^ 16 : Int) + + ai.val * bj.val * (2 ^ 16 : Int) + + aj.val * bi.val * (2 ^ 16 : Int)) : ZMod 3329)) : + Spec.mont_reduce_pure (lift_fe_int r.val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int out.val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont ai) (lift_fe_mont bj)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont aj) (lift_fe_mont bi))) := + -- The odd-half equation has no cache lane (the cached + -- `mont_reduce(b·zeta)` only enters the even-half product); it is + -- identical to `L2_8c.mont_reduce_odd_fe_eq`. + L2_8c.mont_reduce_odd_fe_eq out r ai bi aj bj h_zmod + +set_option maxHeartbeats 8000000 in +/-- Per-pair Triple for `accumulating_ntt_multiply_binomials_fill_cache`. + Sibling of `accumulating_ntt_multiply_binomials_fc`: same 7 POST + conjuncts on the slice output, plus 3 POST conjuncts on the cache + output describing the per-pair Mont-reduced `b[2i+1]·zeta` write + at slot `i`. -/ +theorem accumulating_ntt_multiply_binomials_fill_cache_fc + (a b : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta : Std.I16) (i : Std.Usize) + (out : Aeneas.Std.Slice Std.I32) + (cache : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_i : i.val < 8) + (h_out_len : out.length = 16) + (h_a : ∀ j : Fin 16, (a.elements.val[j.val]!).val.natAbs ≤ 3328) + (h_b : ∀ j : Fin 16, (b.elements.val[j.val]!).val.natAbs ≤ 3328) + (h_zeta : zeta.val.natAbs ≤ 1664) + (h_out_bnd : ∀ k : Fin 16, (out.val[k.val]!).val.natAbs ≤ 2^30 + 2^25) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + a b zeta i out cache + ⦃ ⇓ p => ⌜ p.1.length = 16 + ∧ (∀ k : Nat, k < 16 → k ≠ 2 * i.val → k ≠ 2 * i.val + 1 → + p.1.val[k]! = out.val[k]!) + ∧ (p.1.val[2 * i.val]!).val.natAbs + ≤ (out.val[2 * i.val]!).val.natAbs + 2^25 + ∧ (p.1.val[2 * i.val + 1]!).val.natAbs + ≤ (out.val[2 * i.val + 1]!).val.natAbs + 2^25 + ∧ Spec.mont_reduce_pure (lift_fe_int (p.1.val[2 * i.val]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (out.val[2 * i.val]!).val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (a.elements.val[2 * i.val]!)) + (lift_fe_mont (b.elements.val[2 * i.val]!))) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (a.elements.val[2 * i.val + 1]!)) + (lift_fe_mont (b.elements.val[2 * i.val + 1]!))) + (lift_fe_mont zeta))) + ∧ Spec.mont_reduce_pure (lift_fe_int (p.1.val[2 * i.val + 1]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (out.val[2 * i.val + 1]!).val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (a.elements.val[2 * i.val]!)) + (lift_fe_mont (b.elements.val[2 * i.val + 1]!))) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (a.elements.val[2 * i.val + 1]!)) + (lift_fe_mont (b.elements.val[2 * i.val]!)))) + ∧ (p.2.elements.val[i.val]!).val.natAbs ≤ 3328 + ∧ lift_fe_mont (p.2.elements.val[i.val]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (b.elements.val[2 * i.val + 1]!)) + (lift_fe_mont zeta) + ∧ (∀ k : Nat, k < 16 → k ≠ i.val → + p.2.elements.val[k]! = cache.elements.val[k]!) ⌝ ⦄ := by + -- ===== Setup (identical to L2.8c binomials_fc) ===== + have h_2i_lt : 2 * i.val < 16 := by omega + have h_2i1_lt : 2 * i.val + 1 < 16 := by omega + have h_a_len : a.elements.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length a + have h_b_len : b.elements.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length b + have h_cache_len : cache.elements.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length cache + have h_out_val_len : out.val.length = 16 := h_out_len + set ai_v : Std.I16 := a.elements.val[2 * i.val]! with hai_def + set bi_v : Std.I16 := b.elements.val[2 * i.val]! with hbi_def + set aj_v : Std.I16 := a.elements.val[2 * i.val + 1]! with haj_def + set bj_v : Std.I16 := b.elements.val[2 * i.val + 1]! with hbj_def + have h_ai : ai_v.val.natAbs ≤ 3328 := h_a ⟨2 * i.val, h_2i_lt⟩ + have h_bi : bi_v.val.natAbs ≤ 3328 := h_b ⟨2 * i.val, h_2i_lt⟩ + have h_aj : aj_v.val.natAbs ≤ 3328 := h_a ⟨2 * i.val + 1, h_2i1_lt⟩ + have h_bj : bj_v.val.natAbs ≤ 3328 := h_b ⟨2 * i.val + 1, h_2i1_lt⟩ + set old_e : Std.I32 := out.val[2 * i.val]! with hoe_def + set old_o : Std.I32 := out.val[2 * i.val + 1]! with hoo_def + have h_old_e_bnd : old_e.val.natAbs ≤ 2^30 + 2^25 := h_out_bnd ⟨2 * i.val, h_2i_lt⟩ + have h_old_o_bnd : old_o.val.natAbs ≤ 2^30 + 2^25 := h_out_bnd ⟨2 * i.val + 1, h_2i1_lt⟩ + -- ===== Index arithmetic ===== + obtain ⟨i1, h_i1_eq, h_i1_val⟩ := + usize_mul_ok_eq_fc 2#usize i (by scalar_tac) + have h_i1_val' : i1.val = 2 * i.val := by + rw [h_i1_val]; rfl + obtain ⟨i2, h_i2_eq, h_i2_val⟩ := + usize_add_ok_eq_fc i1 1#usize (by scalar_tac) + have h_i2_val' : i2.val = 2 * i.val + 1 := by + rw [h_i2_val, h_i1_val']; rfl + -- ===== Reads (with index_usize_ok_eq) ===== + have h_read_ai : + Aeneas.Std.Array.index_usize a.elements i1 = .ok ai_v := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a.elements i1 + (by rw [h_a_len, h_i1_val']; exact h_2i_lt) + rw [h, h_i1_val'] + have h_read_bi : + Aeneas.Std.Array.index_usize b.elements i1 = .ok bi_v := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq b.elements i1 + (by rw [h_b_len, h_i1_val']; exact h_2i_lt) + rw [h, h_i1_val'] + have h_read_aj : + Aeneas.Std.Array.index_usize a.elements i2 = .ok aj_v := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a.elements i2 + (by rw [h_a_len, h_i2_val']; exact h_2i1_lt) + rw [h, h_i2_val'] + have h_read_bj : + Aeneas.Std.Array.index_usize b.elements i2 = .ok bj_v := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq b.elements i2 + (by rw [h_b_len, h_i2_val']; exact h_2i1_lt) + rw [h, h_i2_val'] + -- ===== as_i32 casts ===== + set ai32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 ai_v with hai32_def + set bi32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 bi_v with hbi32_def + set aj32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 aj_v with haj32_def + set bj32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 bj_v with hbj32_def + set zeta32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 zeta with hzeta32_def + have h_ai32_val : ai32.val = ai_v.val := L2_8c.cast_I32_val ai_v + have h_bi32_val : bi32.val = bi_v.val := L2_8c.cast_I32_val bi_v + have h_aj32_val : aj32.val = aj_v.val := L2_8c.cast_I32_val aj_v + have h_bj32_val : bj32.val = bj_v.val := L2_8c.cast_I32_val bj_v + have h_zeta32_val : zeta32.val = zeta.val := L2_8c.cast_I32_val zeta + have h_as_ai : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 ai_v = .ok ai32 := + L2_8c.as_i32_val_eq ai_v + have h_as_bi : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 bi_v = .ok bi32 := + L2_8c.as_i32_val_eq bi_v + have h_as_aj : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 aj_v = .ok aj32 := + L2_8c.as_i32_val_eq aj_v + have h_as_bj : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 bj_v = .ok bj32 := + L2_8c.as_i32_val_eq bj_v + have h_as_zeta : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 zeta = .ok zeta32 := + L2_8c.as_i32_val_eq zeta + -- ===== Step: ai_bi = wrapping_mul ai32 bi32 ===== + set ai_bi : Std.I32 := Aeneas.Std.I32.wrapping_mul ai32 bi32 with habi_def + have h_ai_bi_eq : CoreModels.core.num.I32.wrapping_mul ai32 bi32 = .ok ai_bi := + L2_8c.cm_wrapping_mul_i32_ok_eq ai32 bi32 + have h_ai_bi_val : ai_bi.val = ai_v.val * bi_v.val := by + have h_bnd : (ai32.val * bi32.val).natAbs < 2^31 := by + rw [h_ai32_val, h_bi32_val] + have h := Int.natAbs_mul ai_v.val bi_v.val + have : ai_v.val.natAbs * bi_v.val.natAbs ≤ 3328 * 3328 := by + exact Nat.mul_le_mul h_ai h_bi + rw [h] + have : (3328 * 3328 : Nat) < 2^31 := by decide + omega + have := L2_8c.wrapping_mul_i32_no_overflow ai32 bi32 h_bnd + rw [this, h_ai32_val, h_bi32_val] + -- ===== Step: bj_zeta_ = wrapping_mul bj32 zeta32 ===== + set bj_zeta_ : Std.I32 := Aeneas.Std.I32.wrapping_mul bj32 zeta32 with hbjz_def + have h_bj_zeta_eq : CoreModels.core.num.I32.wrapping_mul bj32 zeta32 = .ok bj_zeta_ := + L2_8c.cm_wrapping_mul_i32_ok_eq bj32 zeta32 + have h_bj_zeta_val : bj_zeta_.val = bj_v.val * zeta.val := by + have h_bnd : (bj32.val * zeta32.val).natAbs < 2^31 := by + rw [h_bj32_val, h_zeta32_val] + rw [Int.natAbs_mul] + have h_mul : bj_v.val.natAbs * zeta.val.natAbs ≤ 3328 * 1664 := + Nat.mul_le_mul h_bj h_zeta + have : (3328 * 1664 : Nat) < 2^31 := by decide + omega + have := L2_8c.wrapping_mul_i32_no_overflow bj32 zeta32 h_bnd + rw [this, h_bj32_val, h_zeta32_val] + -- ===== Step: bj_zeta = montgomery_reduce_element bj_zeta_ ===== + have h_bj_zeta_pre : bj_zeta_.val.natAbs ≤ 2^16 * 3328 := by + rw [h_bj_zeta_val] + rw [Int.natAbs_mul] + have h_mul : bj_v.val.natAbs * zeta.val.natAbs ≤ 3328 * 1664 := + Nat.mul_le_mul h_bj h_zeta + have : (3328 * 1664 : Nat) ≤ 2^16 * 3328 := by decide + omega + obtain ⟨bj_zeta, h_bj_zeta_ok, h_bj_zeta_bnd, h_bj_zeta_lift⟩ := + triple_exists_ok_fc (montgomery_reduce_element_fc bj_zeta_ h_bj_zeta_pre) + have h_bj_zeta_pre' : bj_zeta_.val.natAbs ≤ 3328 * 2^16 := by + rw [show (3328 * 2^16 : Nat) = 2^16 * 3328 from by decide]; exact h_bj_zeta_pre + obtain ⟨bj_zeta', h_bj_zeta_ok', _h_bnd', h_tight_imp, h_bj_zeta_modq⟩ := + triple_exists_ok_fc + (libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement.montgomery_reduce_element_spec bj_zeta_ h_bj_zeta_pre') + have h_bj_zeta_eq2 : bj_zeta = bj_zeta' := by + have h_both : (Result.ok bj_zeta : Result _) = Result.ok bj_zeta' := by + rw [← h_bj_zeta_ok, h_bj_zeta_ok'] + cases h_both; rfl + -- Tight canonical bound for the cache POST: |bj * zeta| ≤ 3328 * 1664 ≤ 3328 * 2^15 + -- discharges the conditional in `montgomery_reduce_element_spec`. + have h_bj_zeta_tight_pre : bj_zeta_.val.natAbs ≤ 3328 * 2^15 := by + rw [h_bj_zeta_val, Int.natAbs_mul] + have h_mul : bj_v.val.natAbs * zeta.val.natAbs ≤ 3328 * 1664 := + Nat.mul_le_mul h_bj h_zeta + have h_le : (3328 * 1664 : Nat) ≤ 3328 * 2^15 := by decide + omega + have h_bj_zeta_canon : bj_zeta.val.natAbs ≤ 3328 := by + rw [h_bj_zeta_eq2]; exact h_tight_imp h_bj_zeta_tight_pre + -- ===== Cache write step (NEW): Array.update cache.elements i bj_zeta = .ok (cache.set i bj_zeta) ===== + have h_upd_cache : + Aeneas.Std.Array.update cache.elements i bj_zeta + = .ok (cache.elements.set i bj_zeta) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_update_ok_eq cache.elements i bj_zeta + (by rw [h_cache_len]; exact (by omega : i.val < 16)) + set cache_new : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + { elements := cache.elements.set i bj_zeta } with hcn_def + -- ===== Step: aj_bj_zeta = wrapping_mul aj32 (as_i32 bj_zeta) ===== + set bj_zeta32 : Std.I32 := + Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 bj_zeta with hbjz32_def + have h_bj_zeta32_val : bj_zeta32.val = bj_zeta.val := L2_8c.cast_I32_val bj_zeta + have h_as_bj_zeta : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 bj_zeta + = .ok bj_zeta32 := L2_8c.as_i32_val_eq bj_zeta + set aj_bj_zeta : Std.I32 := Aeneas.Std.I32.wrapping_mul aj32 bj_zeta32 with habjz_def + have h_aj_bj_zeta_eq : CoreModels.core.num.I32.wrapping_mul aj32 bj_zeta32 = .ok aj_bj_zeta := + L2_8c.cm_wrapping_mul_i32_ok_eq aj32 bj_zeta32 + have h_aj_bj_zeta_val : aj_bj_zeta.val = aj_v.val * bj_zeta.val := by + have h_bnd : (aj32.val * bj_zeta32.val).natAbs < 2^31 := by + rw [h_aj32_val, h_bj_zeta32_val, Int.natAbs_mul] + have h_mul : aj_v.val.natAbs * bj_zeta.val.natAbs ≤ 3328 * (3328 + 1665) := + Nat.mul_le_mul h_aj h_bj_zeta_bnd + have : (3328 * (3328 + 1665) : Nat) < 2^31 := by decide + omega + have := L2_8c.wrapping_mul_i32_no_overflow aj32 bj_zeta32 h_bnd + rw [this, h_aj32_val, h_bj_zeta32_val] + -- ===== Step: ai_bi_aj_bj = wrapping_add ai_bi aj_bj_zeta ===== + set ai_bi_aj_bj : Std.I32 := Aeneas.Std.I32.wrapping_add ai_bi aj_bj_zeta with hsum_e_def + have h_sum_e_eq : CoreModels.core.num.I32.wrapping_add ai_bi aj_bj_zeta = .ok ai_bi_aj_bj := + L2_8c.cm_wrapping_add_i32_ok_eq ai_bi aj_bj_zeta + have h_sum_e_bnd : (ai_bi.val + aj_bj_zeta.val).natAbs ≤ 3328 * 3328 + 3328 * (3328 + 1665) := by + rw [h_ai_bi_val, h_aj_bj_zeta_val] + have h_e1 : (ai_v.val * bi_v.val).natAbs ≤ 3328 * 3328 := by + rw [Int.natAbs_mul]; exact Nat.mul_le_mul h_ai h_bi + have h_e2 : (aj_v.val * bj_zeta.val).natAbs ≤ 3328 * (3328 + 1665) := by + rw [Int.natAbs_mul]; exact Nat.mul_le_mul h_aj h_bj_zeta_bnd + have h_tri : ((ai_v.val * bi_v.val) + (aj_v.val * bj_zeta.val)).natAbs + ≤ (ai_v.val * bi_v.val).natAbs + (aj_v.val * bj_zeta.val).natAbs := + Int.natAbs_add_le _ _ + omega + have h_sum_e_val : ai_bi_aj_bj.val = ai_bi.val + aj_bj_zeta.val := by + have h_bnd : (ai_bi.val + aj_bj_zeta.val).natAbs < 2^31 := by + have h_le : (3328 * 3328 + 3328 * (3328 + 1665) : Nat) < 2^31 := by decide + omega + exact L2_8c.wrapping_add_i32_no_overflow ai_bi aj_bj_zeta h_bnd + have h_delta_e_bnd : ai_bi_aj_bj.val.natAbs ≤ 2^25 := by + rw [h_sum_e_val] + have : (3328 * 3328 + 3328 * (3328 + 1665) : Nat) ≤ 2^25 := by decide + omega + -- ===== Step: ai_bj = wrapping_mul ai32 bj32 ===== + set ai_bj_p : Std.I32 := Aeneas.Std.I32.wrapping_mul ai32 bj32 with haibj_def + have h_ai_bj_eq : CoreModels.core.num.I32.wrapping_mul ai32 bj32 = .ok ai_bj_p := + L2_8c.cm_wrapping_mul_i32_ok_eq ai32 bj32 + have h_ai_bj_val : ai_bj_p.val = ai_v.val * bj_v.val := by + have h_bnd : (ai32.val * bj32.val).natAbs < 2^31 := by + rw [h_ai32_val, h_bj32_val, Int.natAbs_mul] + have h_mul : ai_v.val.natAbs * bj_v.val.natAbs ≤ 3328 * 3328 := + Nat.mul_le_mul h_ai h_bj + have : (3328 * 3328 : Nat) < 2^31 := by decide + omega + have := L2_8c.wrapping_mul_i32_no_overflow ai32 bj32 h_bnd + rw [this, h_ai32_val, h_bj32_val] + -- ===== Step: aj_bi = wrapping_mul aj32 bi32 ===== + set aj_bi_p : Std.I32 := Aeneas.Std.I32.wrapping_mul aj32 bi32 with hajbi_def + have h_aj_bi_eq : CoreModels.core.num.I32.wrapping_mul aj32 bi32 = .ok aj_bi_p := + L2_8c.cm_wrapping_mul_i32_ok_eq aj32 bi32 + have h_aj_bi_val : aj_bi_p.val = aj_v.val * bi_v.val := by + have h_bnd : (aj32.val * bi32.val).natAbs < 2^31 := by + rw [h_aj32_val, h_bi32_val, Int.natAbs_mul] + have h_mul : aj_v.val.natAbs * bi_v.val.natAbs ≤ 3328 * 3328 := + Nat.mul_le_mul h_aj h_bi + have : (3328 * 3328 : Nat) < 2^31 := by decide + omega + have := L2_8c.wrapping_mul_i32_no_overflow aj32 bi32 h_bnd + rw [this, h_aj32_val, h_bi32_val] + -- ===== Step: ai_bj_aj_bi = wrapping_add ai_bj aj_bi ===== + set ai_bj_aj_bi : Std.I32 := Aeneas.Std.I32.wrapping_add ai_bj_p aj_bi_p with hsum_o_def + have h_sum_o_eq : CoreModels.core.num.I32.wrapping_add ai_bj_p aj_bi_p = .ok ai_bj_aj_bi := + L2_8c.cm_wrapping_add_i32_ok_eq ai_bj_p aj_bi_p + have h_sum_o_bnd : (ai_bj_p.val + aj_bi_p.val).natAbs ≤ 2 * 3328 * 3328 := by + rw [h_ai_bj_val, h_aj_bi_val] + have h_e1 : (ai_v.val * bj_v.val).natAbs ≤ 3328 * 3328 := by + rw [Int.natAbs_mul]; exact Nat.mul_le_mul h_ai h_bj + have h_e2 : (aj_v.val * bi_v.val).natAbs ≤ 3328 * 3328 := by + rw [Int.natAbs_mul]; exact Nat.mul_le_mul h_aj h_bi + have h_tri := Int.natAbs_add_le (ai_v.val * bj_v.val) (aj_v.val * bi_v.val) + omega + have h_sum_o_val : ai_bj_aj_bi.val = ai_bj_p.val + aj_bi_p.val := by + have h_bnd : (ai_bj_p.val + aj_bi_p.val).natAbs < 2^31 := by + have : (2 * 3328 * 3328 : Nat) < 2^31 := by decide + omega + exact L2_8c.wrapping_add_i32_no_overflow ai_bj_p aj_bi_p h_bnd + have h_delta_o_bnd : ai_bj_aj_bi.val.natAbs ≤ 2^25 := by + rw [h_sum_o_val] + have : (2 * 3328 * 3328 : Nat) ≤ 2^25 := by decide + omega + -- ===== Slice reads + writes for `out` ===== + have h_read_old_e : Aeneas.Std.Slice.index_usize out i1 = .ok old_e := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq out i1 + (by rw [h_out_val_len, h_i1_val']; exact h_2i_lt) + rw [h, h_i1_val'] + set new_e : Std.I32 := Aeneas.Std.I32.wrapping_add old_e ai_bi_aj_bj with hne_def + have h_new_e_eq : CoreModels.core.num.I32.wrapping_add old_e ai_bi_aj_bj = .ok new_e := + L2_8c.cm_wrapping_add_i32_ok_eq old_e ai_bi_aj_bj + have h_new_e_val : new_e.val = old_e.val + ai_bi_aj_bj.val := by + have h_bnd : (old_e.val + ai_bi_aj_bj.val).natAbs < 2^31 := by + have h_tri := Int.natAbs_add_le old_e.val ai_bi_aj_bj.val + have : (2^30 + 2^25 + 2^25 : Nat) < 2^31 := by decide + omega + exact L2_8c.wrapping_add_i32_no_overflow old_e ai_bi_aj_bj h_bnd + have h_new_e_bnd : new_e.val.natAbs ≤ old_e.val.natAbs + 2^25 := by + rw [h_new_e_val] + have h_tri := Int.natAbs_add_le old_e.val ai_bi_aj_bj.val + omega + have h_upd_e : Aeneas.Std.Slice.update out i1 new_e = .ok (out.set i1 new_e) := by + have hT := Aeneas.Std.Slice.update_spec out i1 new_e (by rw [h_out_len, h_i1_val']; exact h_2i_lt) + obtain ⟨v', h_eq, h_v'⟩ := Aeneas.Std.WP.spec_imp_exists hT + rw [h_eq, h_v'] + set out1 : Aeneas.Std.Slice Std.I32 := out.set i1 new_e with hout1_def + have h_out1_len : out1.length = 16 := by simp [hout1_def]; exact h_out_len + have h_out1_val_len : out1.val.length = 16 := h_out1_len + have h_old_o_in_out1 : out1.val[i2.val]! = old_o := by + have h_set_val : out1.val = out.val.set i1.val new_e := by + simp [hout1_def, Aeneas.Std.Slice.set_val_eq] + have h_ne : 2 * i.val + 1 ≠ i1.val := by rw [h_i1_val']; omega + have h_lt : 2 * i.val + 1 < out.val.length := by rw [h_out_val_len]; exact h_2i1_lt + rw [h_set_val, h_i2_val', hoo_def] + have h_lt_set : 2 * i.val + 1 < (out.val.set i1.val new_e).length := by + rw [List.length_set]; exact h_lt + rw [getElem!_pos (out.val.set i1.val new_e) (2 * i.val + 1) h_lt_set] + rw [getElem!_pos out.val (2 * i.val + 1) h_lt] + rw [List.getElem_set_ne (Ne.symm h_ne)] + have h_read_old_o : Aeneas.Std.Slice.index_usize out1 i2 = .ok old_o := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq out1 i2 + (by rw [h_out1_val_len, h_i2_val']; exact h_2i1_lt) + rw [h, h_old_o_in_out1] + set new_o : Std.I32 := Aeneas.Std.I32.wrapping_add old_o ai_bj_aj_bi with hno_def + have h_new_o_eq : CoreModels.core.num.I32.wrapping_add old_o ai_bj_aj_bi = .ok new_o := + L2_8c.cm_wrapping_add_i32_ok_eq old_o ai_bj_aj_bi + have h_new_o_val : new_o.val = old_o.val + ai_bj_aj_bi.val := by + have h_bnd : (old_o.val + ai_bj_aj_bi.val).natAbs < 2^31 := by + have h_tri := Int.natAbs_add_le old_o.val ai_bj_aj_bi.val + have : (2^30 + 2^25 + 2^25 : Nat) < 2^31 := by decide + omega + exact L2_8c.wrapping_add_i32_no_overflow old_o ai_bj_aj_bi h_bnd + have h_new_o_bnd : new_o.val.natAbs ≤ old_o.val.natAbs + 2^25 := by + rw [h_new_o_val] + have h_tri := Int.natAbs_add_le old_o.val ai_bj_aj_bi.val + omega + have h_upd_o : Aeneas.Std.Slice.update out1 i2 new_o = .ok (out1.set i2 new_o) := by + have hT := Aeneas.Std.Slice.update_spec out1 i2 new_o + (by rw [h_out1_len, h_i2_val']; exact h_2i1_lt) + obtain ⟨v', h_eq, h_v'⟩ := Aeneas.Std.WP.spec_imp_exists hT + rw [h_eq, h_v'] + set out2 : Aeneas.Std.Slice Std.I32 := out1.set i2 new_o with hout2_def + -- ===== Compose the monadic body ===== + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + a b zeta i out cache = .ok (out2, cache_new) := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + simp only [h_i1_eq, h_i2_eq, h_read_ai, h_read_bi, h_read_aj, h_read_bj, + h_as_ai, h_as_bi, h_as_aj, h_as_bj, h_as_zeta, h_as_bj_zeta, + h_ai_bi_eq, h_bj_zeta_eq, h_bj_zeta_ok, h_upd_cache, h_aj_bj_zeta_eq, + h_sum_e_eq, h_ai_bj_eq, h_aj_bi_eq, h_sum_o_eq, + h_read_old_e, h_new_e_eq, h_upd_e, + h_read_old_o, h_new_o_eq, h_upd_o, + Aeneas.Std.bind_tc_ok] + rfl + apply triple_of_ok_fc h_body + -- ===== POST: 10-conjunct ===== + -- out2 unfolding (same as L2.8c). + have h_out2_val : out2.val = (out.val.set i1.val new_e).set i2.val new_o := by + show ((out.set i1 new_e).set i2 new_o).val = _ + rw [Aeneas.Std.Slice.set_val_eq, Aeneas.Std.Slice.set_val_eq] + have h_out2_len : out2.length = 16 := by + show ((out.set i1 new_e).set i2 new_o).length = 16 + rw [Aeneas.Std.Slice.set_length, Aeneas.Std.Slice.set_length]; exact h_out_len + have h_out2_val_len : out2.val.length = 16 := h_out2_len + have h_out2_at_2i : out2.val[2 * i.val]! = new_e := by + rw [h_out2_val, ← h_i1_val'] + have h_lt_out : i1.val < out.val.length := by rw [h_out_val_len, h_i1_val']; exact h_2i_lt + have h_lt1 : i1.val < (out.val.set i1.val new_e).length := by + rw [List.length_set]; exact h_lt_out + have h_lt2 : i1.val < ((out.val.set i1.val new_e).set i2.val new_o).length := by + rw [List.length_set]; exact h_lt1 + rw [getElem!_pos ((out.val.set i1.val new_e).set i2.val new_o) i1.val h_lt2] + rw [List.getElem_set_ne (by rw [h_i2_val', h_i1_val']; omega)] + rw [List.getElem_set_self] + have h_out2_at_2i1 : out2.val[2 * i.val + 1]! = new_o := by + rw [h_out2_val, ← h_i2_val'] + have h_lt_out : i2.val < out.val.length := by rw [h_out_val_len, h_i2_val']; exact h_2i1_lt + have h_lt1 : i2.val < (out.val.set i1.val new_e).length := by + rw [List.length_set]; exact h_lt_out + have h_lt2 : i2.val < ((out.val.set i1.val new_e).set i2.val new_o).length := by + rw [List.length_set]; exact h_lt1 + rw [getElem!_pos ((out.val.set i1.val new_e).set i2.val new_o) i2.val h_lt2] + rw [List.getElem_set_self] + have h_out2_untouched : ∀ k : Nat, k < 16 → k ≠ 2 * i.val → k ≠ 2 * i.val + 1 → + out2.val[k]! = out.val[k]! := by + intro k hk hki hkj + rw [h_out2_val] + have h_lt_out : k < out.val.length := by rw [h_out_val_len]; exact hk + have h_lt1 : k < (out.val.set i1.val new_e).length := by rw [List.length_set]; exact h_lt_out + have h_lt2 : k < ((out.val.set i1.val new_e).set i2.val new_o).length := by + rw [List.length_set]; exact h_lt1 + rw [getElem!_pos ((out.val.set i1.val new_e).set i2.val new_o) k h_lt2] + rw [getElem!_pos out.val k h_lt_out] + rw [List.getElem_set_ne (by rw [h_i2_val']; omega)] + rw [List.getElem_set_ne (by rw [h_i1_val']; omega)] + -- ===== Cache POST conjuncts ===== + -- cache_new.elements.val = cache.elements.val.set i.val bj_zeta. + have h_cache_new_val : + cache_new.elements.val = cache.elements.val.set i.val bj_zeta := by + simp [hcn_def, Aeneas.Std.Array.set_val_eq] + have h_cache_val_len : cache.elements.val.length = 16 := h_cache_len + have h_cache_at_i : cache_new.elements.val[i.val]! = bj_zeta := by + rw [h_cache_new_val] + have h_lt : i.val < cache.elements.val.length := by rw [h_cache_val_len]; omega + have h_lt_set : i.val < (cache.elements.val.set i.val bj_zeta).length := by + rw [List.length_set]; exact h_lt + rw [getElem!_pos (cache.elements.val.set i.val bj_zeta) i.val h_lt_set] + rw [List.getElem_set_self] + have h_cache_untouched : ∀ k : Nat, k < 16 → k ≠ i.val → + cache_new.elements.val[k]! = cache.elements.val[k]! := by + intro k hk hki + rw [h_cache_new_val] + have h_lt : k < cache.elements.val.length := by rw [h_cache_val_len]; exact hk + have h_lt_set : k < (cache.elements.val.set i.val bj_zeta).length := by + rw [List.length_set]; exact h_lt + rw [getElem!_pos (cache.elements.val.set i.val bj_zeta) k h_lt_set] + rw [getElem!_pos cache.elements.val k h_lt] + rw [List.getElem_set_ne (Ne.symm hki)] + -- The cache POST FE-equation conjunct: lift_fe_mont bj_zeta = mul (lift bj) (lift zeta). + have h_cache_fe : + lift_fe_mont bj_zeta + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont bj_v) (lift_fe_mont zeta) := by + -- Convert h_bj_zeta_modq to ZMod equation and invoke the helper. + have h_modq_cast : ((bj_zeta'.val : Int) : ZMod 3329) + = ((bj_zeta_.val * 169 : Int) : ZMod 3329) := + modq_eq_cast_zmod _ _ h_bj_zeta_modq + rw [h_bj_zeta_eq2.symm] at h_modq_cast + rw [h_bj_zeta_val] at h_modq_cast + push_cast at h_modq_cast + -- h_modq_cast : (bj_zeta.val : ZMod q) = (bj.val * zeta.val * 169 : ZMod q). + apply L2_8d.lift_fe_mont_of_mont_reduce_modq bj_zeta bj_v zeta h_bj_zeta_canon + -- Goal: ((bj_zeta.val : Int) : ZMod q) * 2^16 = (bj.val : ZMod q) * (zeta.val : ZMod q). + push_cast + rw [h_modq_cast] + -- Goal: bj * zeta * 169 * 2^16 = bj * zeta. Use 2^16 * 169 = 1. + have h_inv : ((2285 : ZMod 3329)) * 169 = 1 := by decide + rw [show (((bj_v.val : Int) : ZMod 3329) * ((zeta.val : Int) : ZMod 3329) * 169) * 2285 + = ((bj_v.val : Int) : ZMod 3329) * ((zeta.val : Int) : ZMod 3329) * (169 * 2285) + from by ring] + rw [show ((169 : ZMod 3329) * 2285) = (2285 * 169 : ZMod 3329) from by ring] + rw [h_inv] + ring + -- ===== Assemble the 10-conjunct POST ===== + -- Note: target POST mentions `p.2.elements.val[k]!` (where p = (out2, cache_new)), + -- which is `cache_new.elements.val[k]!`. The tight canonical bound for the + -- cache POST was established above as `h_bj_zeta_canon`. + refine ⟨h_out2_len, h_out2_untouched, ?_, ?_, ?_, ?_, ?_, ?_, ?_⟩ + · -- Bound at 2*i. + rw [h_out2_at_2i, hoe_def] at * + rw [h_out2_at_2i] + rw [hoe_def] at h_new_e_bnd + exact h_new_e_bnd + · -- Bound at 2*i+1. + rw [h_out2_at_2i1] + rw [hoo_def] at h_new_o_bnd + exact h_new_o_bnd + · -- FE eq (even half) — identical algebra to L2.8c binomials_fc. + rw [h_out2_at_2i, hoe_def] + have h_modq_cast : ((bj_zeta'.val : Int) : ZMod 3329) + = ((bj_zeta_.val * 169 : Int) : ZMod 3329) := + modq_eq_cast_zmod _ _ h_bj_zeta_modq + rw [h_bj_zeta_eq2.symm] at h_modq_cast + rw [h_bj_zeta_val] at h_modq_cast + push_cast at h_modq_cast + apply L2_8c.mont_reduce_even_fe_eq + (out := out.val[2 * i.val]!) (r := new_e) + (ai := ai_v) (bi := bi_v) (aj := aj_v) (bj := bj_v) (zeta := zeta) + rw [← hoe_def, h_new_e_val, h_sum_e_val, h_ai_bi_val, h_aj_bj_zeta_val] + push_cast + rw [h_modq_cast] + have h_inv : ((2285 : ZMod 3329)) * 169 = 1 := by decide + calc ((old_e.val : ZMod 3329) + ((ai_v.val : ZMod 3329) * (bi_v.val : ZMod 3329) + + (aj_v.val : ZMod 3329) * ((bj_v.val : ZMod 3329) * (zeta.val : ZMod 3329) * 169))) + * 2285 + = (old_e.val : ZMod 3329) * 2285 + + (ai_v.val : ZMod 3329) * (bi_v.val : ZMod 3329) * 2285 + + (aj_v.val : ZMod 3329) * (bj_v.val : ZMod 3329) * (zeta.val : ZMod 3329) + * (2285 * 169) := by ring + _ = (old_e.val : ZMod 3329) * 2285 + + (ai_v.val : ZMod 3329) * (bi_v.val : ZMod 3329) * 2285 + + (aj_v.val : ZMod 3329) * (bj_v.val : ZMod 3329) * (zeta.val : ZMod 3329) := by + rw [h_inv]; ring + · -- FE eq (odd half). + rw [h_out2_at_2i1, hoo_def] + apply L2_8c.mont_reduce_odd_fe_eq + (out := out.val[2 * i.val + 1]!) (r := new_o) + (ai := ai_v) (bi := bi_v) (aj := aj_v) (bj := bj_v) + rw [← hoo_def, h_new_o_val, h_sum_o_val, h_ai_bj_val, h_aj_bi_val] + push_cast + ring + · -- Cache canonicity at slot i. + rw [h_cache_at_i]; exact h_bj_zeta_canon + · -- Cache FE-equation at slot i. + rw [h_cache_at_i, hbj_def]; exact h_cache_fe + · -- Cache unchanged outside slot i. + exact h_cache_untouched + +set_option maxHeartbeats 8000000 in +/-- Per-pair Triple for `accumulating_ntt_multiply_binomials_use_cache`. + Reads `cache[i]` (an I16) in place of the per-pair Mont-reduced + `b[2i+1]·zeta`. The cache PRE conjunct asserts canonicity + (`|cache[i].val| ≤ 3328`); the FE equation for the even half + leaves the cached lane symbolic (the outer use_cache Triple + rewrites it using the cache PRE FE equation). -/ +theorem accumulating_ntt_multiply_binomials_use_cache_fc + (a b : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (i : Std.Usize) + (out : Aeneas.Std.Slice Std.I32) + (cache : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_i : i.val < 8) + (h_out_len : out.length = 16) + (h_a : ∀ j : Fin 16, (a.elements.val[j.val]!).val.natAbs ≤ 3328) + (h_b : ∀ j : Fin 16, (b.elements.val[j.val]!).val.natAbs ≤ 3328) + (h_cache_i : (cache.elements.val[i.val]!).val.natAbs ≤ 3328) + (h_out_bnd : ∀ k : Fin 16, (out.val[k.val]!).val.natAbs ≤ 2^30 + 2^25) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_use_cache + a b i out cache + ⦃ ⇓ r => ⌜ r.length = 16 + ∧ (∀ k : Nat, k < 16 → k ≠ 2 * i.val → k ≠ 2 * i.val + 1 → + r.val[k]! = out.val[k]!) + ∧ (r.val[2 * i.val]!).val.natAbs + ≤ (out.val[2 * i.val]!).val.natAbs + 2^25 + ∧ (r.val[2 * i.val + 1]!).val.natAbs + ≤ (out.val[2 * i.val + 1]!).val.natAbs + 2^25 + ∧ Spec.mont_reduce_pure (lift_fe_int (r.val[2 * i.val]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (out.val[2 * i.val]!).val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (a.elements.val[2 * i.val]!)) + (lift_fe_mont (b.elements.val[2 * i.val]!))) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (a.elements.val[2 * i.val + 1]!)) + (lift_fe_mont (cache.elements.val[i.val]!)))) + ∧ Spec.mont_reduce_pure (lift_fe_int (r.val[2 * i.val + 1]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (out.val[2 * i.val + 1]!).val)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (a.elements.val[2 * i.val]!)) + (lift_fe_mont (b.elements.val[2 * i.val + 1]!))) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (a.elements.val[2 * i.val + 1]!)) + (lift_fe_mont (b.elements.val[2 * i.val]!)))) ⌝ ⦄ := by + -- ===== Setup ===== + have h_2i_lt : 2 * i.val < 16 := by omega + have h_2i1_lt : 2 * i.val + 1 < 16 := by omega + have h_a_len : a.elements.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length a + have h_b_len : b.elements.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length b + have h_cache_len : cache.elements.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length cache + have h_out_val_len : out.val.length = 16 := h_out_len + set ai_v : Std.I16 := a.elements.val[2 * i.val]! with hai_def + set bi_v : Std.I16 := b.elements.val[2 * i.val]! with hbi_def + set aj_v : Std.I16 := a.elements.val[2 * i.val + 1]! with haj_def + set bj_v : Std.I16 := b.elements.val[2 * i.val + 1]! with hbj_def + set c_v : Std.I16 := cache.elements.val[i.val]! with hcv_def + have h_ai : ai_v.val.natAbs ≤ 3328 := h_a ⟨2 * i.val, h_2i_lt⟩ + have h_bi : bi_v.val.natAbs ≤ 3328 := h_b ⟨2 * i.val, h_2i_lt⟩ + have h_aj : aj_v.val.natAbs ≤ 3328 := h_a ⟨2 * i.val + 1, h_2i1_lt⟩ + have h_bj : bj_v.val.natAbs ≤ 3328 := h_b ⟨2 * i.val + 1, h_2i1_lt⟩ + have h_cv : c_v.val.natAbs ≤ 3328 := h_cache_i + set old_e : Std.I32 := out.val[2 * i.val]! with hoe_def + set old_o : Std.I32 := out.val[2 * i.val + 1]! with hoo_def + have h_old_e_bnd : old_e.val.natAbs ≤ 2^30 + 2^25 := h_out_bnd ⟨2 * i.val, h_2i_lt⟩ + have h_old_o_bnd : old_o.val.natAbs ≤ 2^30 + 2^25 := h_out_bnd ⟨2 * i.val + 1, h_2i1_lt⟩ + -- ===== Index arithmetic ===== + obtain ⟨i1, h_i1_eq, h_i1_val⟩ := + usize_mul_ok_eq_fc 2#usize i (by scalar_tac) + have h_i1_val' : i1.val = 2 * i.val := by + rw [h_i1_val]; rfl + obtain ⟨i2, h_i2_eq, h_i2_val⟩ := + usize_add_ok_eq_fc i1 1#usize (by scalar_tac) + have h_i2_val' : i2.val = 2 * i.val + 1 := by + rw [h_i2_val, h_i1_val']; rfl + -- ===== Reads (with index_usize_ok_eq) ===== + have h_read_ai : + Aeneas.Std.Array.index_usize a.elements i1 = .ok ai_v := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a.elements i1 + (by rw [h_a_len, h_i1_val']; exact h_2i_lt) + rw [h, h_i1_val'] + have h_read_bi : + Aeneas.Std.Array.index_usize b.elements i1 = .ok bi_v := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq b.elements i1 + (by rw [h_b_len, h_i1_val']; exact h_2i_lt) + rw [h, h_i1_val'] + have h_read_aj : + Aeneas.Std.Array.index_usize a.elements i2 = .ok aj_v := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a.elements i2 + (by rw [h_a_len, h_i2_val']; exact h_2i1_lt) + rw [h, h_i2_val'] + have h_read_bj : + Aeneas.Std.Array.index_usize b.elements i2 = .ok bj_v := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq b.elements i2 + (by rw [h_b_len, h_i2_val']; exact h_2i1_lt) + rw [h, h_i2_val'] + have h_read_cv : + Aeneas.Std.Array.index_usize cache.elements i = .ok c_v := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq cache.elements i + (by rw [h_cache_len]; exact (by omega : i.val < 16)) + rw [h] + -- ===== as_i32 casts ===== + set ai32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 ai_v with hai32_def + set bi32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 bi_v with hbi32_def + set aj32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 aj_v with haj32_def + set bj32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 bj_v with hbj32_def + set c32 : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 c_v with hc32_def + have h_ai32_val : ai32.val = ai_v.val := L2_8c.cast_I32_val ai_v + have h_bi32_val : bi32.val = bi_v.val := L2_8c.cast_I32_val bi_v + have h_aj32_val : aj32.val = aj_v.val := L2_8c.cast_I32_val aj_v + have h_bj32_val : bj32.val = bj_v.val := L2_8c.cast_I32_val bj_v + have h_c32_val : c32.val = c_v.val := L2_8c.cast_I32_val c_v + have h_as_ai : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 ai_v = .ok ai32 := + L2_8c.as_i32_val_eq ai_v + have h_as_bi : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 bi_v = .ok bi32 := + L2_8c.as_i32_val_eq bi_v + have h_as_aj : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 aj_v = .ok aj32 := + L2_8c.as_i32_val_eq aj_v + have h_as_bj : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 bj_v = .ok bj32 := + L2_8c.as_i32_val_eq bj_v + have h_as_cv : libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32 c_v = .ok c32 := + L2_8c.as_i32_val_eq c_v + -- ===== Step: ai_bi = wrapping_mul ai32 bi32 ===== + set ai_bi : Std.I32 := Aeneas.Std.I32.wrapping_mul ai32 bi32 with habi_def + have h_ai_bi_eq : CoreModels.core.num.I32.wrapping_mul ai32 bi32 = .ok ai_bi := + L2_8c.cm_wrapping_mul_i32_ok_eq ai32 bi32 + have h_ai_bi_val : ai_bi.val = ai_v.val * bi_v.val := by + have h_bnd : (ai32.val * bi32.val).natAbs < 2^31 := by + rw [h_ai32_val, h_bi32_val] + have h := Int.natAbs_mul ai_v.val bi_v.val + have : ai_v.val.natAbs * bi_v.val.natAbs ≤ 3328 * 3328 := by + exact Nat.mul_le_mul h_ai h_bi + rw [h] + have : (3328 * 3328 : Nat) < 2^31 := by decide + omega + have := L2_8c.wrapping_mul_i32_no_overflow ai32 bi32 h_bnd + rw [this, h_ai32_val, h_bi32_val] + -- ===== Step: aj_bj_zeta = wrapping_mul aj32 c32 (uses cache directly) ===== + set aj_bj_zeta : Std.I32 := Aeneas.Std.I32.wrapping_mul aj32 c32 with habjz_def + have h_aj_bj_zeta_eq : CoreModels.core.num.I32.wrapping_mul aj32 c32 = .ok aj_bj_zeta := + L2_8c.cm_wrapping_mul_i32_ok_eq aj32 c32 + have h_aj_bj_zeta_val : aj_bj_zeta.val = aj_v.val * c_v.val := by + have h_bnd : (aj32.val * c32.val).natAbs < 2^31 := by + rw [h_aj32_val, h_c32_val, Int.natAbs_mul] + have h_mul : aj_v.val.natAbs * c_v.val.natAbs ≤ 3328 * 3328 := + Nat.mul_le_mul h_aj h_cv + have : (3328 * 3328 : Nat) < 2^31 := by decide + omega + have := L2_8c.wrapping_mul_i32_no_overflow aj32 c32 h_bnd + rw [this, h_aj32_val, h_c32_val] + -- ===== Step: ai_bi_aj_bj = wrapping_add ai_bi aj_bj_zeta ===== + set ai_bi_aj_bj : Std.I32 := Aeneas.Std.I32.wrapping_add ai_bi aj_bj_zeta with hsum_e_def + have h_sum_e_eq : CoreModels.core.num.I32.wrapping_add ai_bi aj_bj_zeta = .ok ai_bi_aj_bj := + L2_8c.cm_wrapping_add_i32_ok_eq ai_bi aj_bj_zeta + have h_sum_e_bnd : (ai_bi.val + aj_bj_zeta.val).natAbs ≤ 2 * 3328 * 3328 := by + rw [h_ai_bi_val, h_aj_bj_zeta_val] + have h_e1 : (ai_v.val * bi_v.val).natAbs ≤ 3328 * 3328 := by + rw [Int.natAbs_mul]; exact Nat.mul_le_mul h_ai h_bi + have h_e2 : (aj_v.val * c_v.val).natAbs ≤ 3328 * 3328 := by + rw [Int.natAbs_mul]; exact Nat.mul_le_mul h_aj h_cv + have h_tri : ((ai_v.val * bi_v.val) + (aj_v.val * c_v.val)).natAbs + ≤ (ai_v.val * bi_v.val).natAbs + (aj_v.val * c_v.val).natAbs := + Int.natAbs_add_le _ _ + omega + have h_sum_e_val : ai_bi_aj_bj.val = ai_bi.val + aj_bj_zeta.val := by + have h_bnd : (ai_bi.val + aj_bj_zeta.val).natAbs < 2^31 := by + have : (2 * 3328 * 3328 : Nat) < 2^31 := by decide + omega + exact L2_8c.wrapping_add_i32_no_overflow ai_bi aj_bj_zeta h_bnd + have h_delta_e_bnd : ai_bi_aj_bj.val.natAbs ≤ 2^25 := by + rw [h_sum_e_val] + have : (2 * 3328 * 3328 : Nat) ≤ 2^25 := by decide + omega + -- ===== Step: ai_bj = wrapping_mul ai32 bj32 ===== + set ai_bj_p : Std.I32 := Aeneas.Std.I32.wrapping_mul ai32 bj32 with haibj_def + have h_ai_bj_eq : CoreModels.core.num.I32.wrapping_mul ai32 bj32 = .ok ai_bj_p := + L2_8c.cm_wrapping_mul_i32_ok_eq ai32 bj32 + have h_ai_bj_val : ai_bj_p.val = ai_v.val * bj_v.val := by + have h_bnd : (ai32.val * bj32.val).natAbs < 2^31 := by + rw [h_ai32_val, h_bj32_val, Int.natAbs_mul] + have h_mul : ai_v.val.natAbs * bj_v.val.natAbs ≤ 3328 * 3328 := + Nat.mul_le_mul h_ai h_bj + have : (3328 * 3328 : Nat) < 2^31 := by decide + omega + have := L2_8c.wrapping_mul_i32_no_overflow ai32 bj32 h_bnd + rw [this, h_ai32_val, h_bj32_val] + -- ===== Step: aj_bi = wrapping_mul aj32 bi32 ===== + set aj_bi_p : Std.I32 := Aeneas.Std.I32.wrapping_mul aj32 bi32 with hajbi_def + have h_aj_bi_eq : CoreModels.core.num.I32.wrapping_mul aj32 bi32 = .ok aj_bi_p := + L2_8c.cm_wrapping_mul_i32_ok_eq aj32 bi32 + have h_aj_bi_val : aj_bi_p.val = aj_v.val * bi_v.val := by + have h_bnd : (aj32.val * bi32.val).natAbs < 2^31 := by + rw [h_aj32_val, h_bi32_val, Int.natAbs_mul] + have h_mul : aj_v.val.natAbs * bi_v.val.natAbs ≤ 3328 * 3328 := + Nat.mul_le_mul h_aj h_bi + have : (3328 * 3328 : Nat) < 2^31 := by decide + omega + have := L2_8c.wrapping_mul_i32_no_overflow aj32 bi32 h_bnd + rw [this, h_aj32_val, h_bi32_val] + -- ===== Step: ai_bj_aj_bi = wrapping_add ai_bj aj_bi ===== + set ai_bj_aj_bi : Std.I32 := Aeneas.Std.I32.wrapping_add ai_bj_p aj_bi_p with hsum_o_def + have h_sum_o_eq : CoreModels.core.num.I32.wrapping_add ai_bj_p aj_bi_p = .ok ai_bj_aj_bi := + L2_8c.cm_wrapping_add_i32_ok_eq ai_bj_p aj_bi_p + have h_sum_o_bnd : (ai_bj_p.val + aj_bi_p.val).natAbs ≤ 2 * 3328 * 3328 := by + rw [h_ai_bj_val, h_aj_bi_val] + have h_e1 : (ai_v.val * bj_v.val).natAbs ≤ 3328 * 3328 := by + rw [Int.natAbs_mul]; exact Nat.mul_le_mul h_ai h_bj + have h_e2 : (aj_v.val * bi_v.val).natAbs ≤ 3328 * 3328 := by + rw [Int.natAbs_mul]; exact Nat.mul_le_mul h_aj h_bi + have h_tri := Int.natAbs_add_le (ai_v.val * bj_v.val) (aj_v.val * bi_v.val) + omega + have h_sum_o_val : ai_bj_aj_bi.val = ai_bj_p.val + aj_bi_p.val := by + have h_bnd : (ai_bj_p.val + aj_bi_p.val).natAbs < 2^31 := by + have : (2 * 3328 * 3328 : Nat) < 2^31 := by decide + omega + exact L2_8c.wrapping_add_i32_no_overflow ai_bj_p aj_bi_p h_bnd + have h_delta_o_bnd : ai_bj_aj_bi.val.natAbs ≤ 2^25 := by + rw [h_sum_o_val] + have : (2 * 3328 * 3328 : Nat) ≤ 2^25 := by decide + omega + -- ===== Slice reads + writes for `out` ===== + have h_read_old_e : Aeneas.Std.Slice.index_usize out i1 = .ok old_e := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq out i1 + (by rw [h_out_val_len, h_i1_val']; exact h_2i_lt) + rw [h, h_i1_val'] + set new_e : Std.I32 := Aeneas.Std.I32.wrapping_add old_e ai_bi_aj_bj with hne_def + have h_new_e_eq : CoreModels.core.num.I32.wrapping_add old_e ai_bi_aj_bj = .ok new_e := + L2_8c.cm_wrapping_add_i32_ok_eq old_e ai_bi_aj_bj + have h_new_e_val : new_e.val = old_e.val + ai_bi_aj_bj.val := by + have h_bnd : (old_e.val + ai_bi_aj_bj.val).natAbs < 2^31 := by + have h_tri := Int.natAbs_add_le old_e.val ai_bi_aj_bj.val + have : (2^30 + 2^25 + 2^25 : Nat) < 2^31 := by decide + omega + exact L2_8c.wrapping_add_i32_no_overflow old_e ai_bi_aj_bj h_bnd + have h_new_e_bnd : new_e.val.natAbs ≤ old_e.val.natAbs + 2^25 := by + rw [h_new_e_val] + have h_tri := Int.natAbs_add_le old_e.val ai_bi_aj_bj.val + omega + have h_upd_e : Aeneas.Std.Slice.update out i1 new_e = .ok (out.set i1 new_e) := by + have hT := Aeneas.Std.Slice.update_spec out i1 new_e (by rw [h_out_len, h_i1_val']; exact h_2i_lt) + obtain ⟨v', h_eq, h_v'⟩ := Aeneas.Std.WP.spec_imp_exists hT + rw [h_eq, h_v'] + set out1 : Aeneas.Std.Slice Std.I32 := out.set i1 new_e with hout1_def + have h_out1_len : out1.length = 16 := by simp [hout1_def]; exact h_out_len + have h_out1_val_len : out1.val.length = 16 := h_out1_len + have h_old_o_in_out1 : out1.val[i2.val]! = old_o := by + have h_set_val : out1.val = out.val.set i1.val new_e := by + simp [hout1_def, Aeneas.Std.Slice.set_val_eq] + have h_ne : 2 * i.val + 1 ≠ i1.val := by rw [h_i1_val']; omega + have h_lt : 2 * i.val + 1 < out.val.length := by rw [h_out_val_len]; exact h_2i1_lt + rw [h_set_val, h_i2_val', hoo_def] + have h_lt_set : 2 * i.val + 1 < (out.val.set i1.val new_e).length := by + rw [List.length_set]; exact h_lt + rw [getElem!_pos (out.val.set i1.val new_e) (2 * i.val + 1) h_lt_set] + rw [getElem!_pos out.val (2 * i.val + 1) h_lt] + rw [List.getElem_set_ne (Ne.symm h_ne)] + have h_read_old_o : Aeneas.Std.Slice.index_usize out1 i2 = .ok old_o := by + have h := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.slice_index_usize_ok_eq out1 i2 + (by rw [h_out1_val_len, h_i2_val']; exact h_2i1_lt) + rw [h, h_old_o_in_out1] + set new_o : Std.I32 := Aeneas.Std.I32.wrapping_add old_o ai_bj_aj_bi with hno_def + have h_new_o_eq : CoreModels.core.num.I32.wrapping_add old_o ai_bj_aj_bi = .ok new_o := + L2_8c.cm_wrapping_add_i32_ok_eq old_o ai_bj_aj_bi + have h_new_o_val : new_o.val = old_o.val + ai_bj_aj_bi.val := by + have h_bnd : (old_o.val + ai_bj_aj_bi.val).natAbs < 2^31 := by + have h_tri := Int.natAbs_add_le old_o.val ai_bj_aj_bi.val + have : (2^30 + 2^25 + 2^25 : Nat) < 2^31 := by decide + omega + exact L2_8c.wrapping_add_i32_no_overflow old_o ai_bj_aj_bi h_bnd + have h_new_o_bnd : new_o.val.natAbs ≤ old_o.val.natAbs + 2^25 := by + rw [h_new_o_val] + have h_tri := Int.natAbs_add_le old_o.val ai_bj_aj_bi.val + omega + have h_upd_o : Aeneas.Std.Slice.update out1 i2 new_o = .ok (out1.set i2 new_o) := by + have hT := Aeneas.Std.Slice.update_spec out1 i2 new_o + (by rw [h_out1_len, h_i2_val']; exact h_2i1_lt) + obtain ⟨v', h_eq, h_v'⟩ := Aeneas.Std.WP.spec_imp_exists hT + rw [h_eq, h_v'] + set out2 : Aeneas.Std.Slice Std.I32 := out1.set i2 new_o with hout2_def + -- ===== Compose monadic body ===== + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_use_cache + a b i out cache = .ok out2 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_use_cache + simp only [h_i1_eq, h_i2_eq, h_read_ai, h_read_bi, h_read_aj, h_read_bj, h_read_cv, + h_as_ai, h_as_bi, h_as_aj, h_as_bj, h_as_cv, + h_ai_bi_eq, h_aj_bj_zeta_eq, + h_sum_e_eq, h_ai_bj_eq, h_aj_bi_eq, h_sum_o_eq, + h_read_old_e, h_new_e_eq, h_upd_e, + h_read_old_o, h_new_o_eq, h_upd_o, + Aeneas.Std.bind_tc_ok] + apply triple_of_ok_fc h_body + -- ===== POST: 7-conjunct ===== + have h_out2_val : out2.val = (out.val.set i1.val new_e).set i2.val new_o := by + show ((out.set i1 new_e).set i2 new_o).val = _ + rw [Aeneas.Std.Slice.set_val_eq, Aeneas.Std.Slice.set_val_eq] + have h_out2_len : out2.length = 16 := by + show ((out.set i1 new_e).set i2 new_o).length = 16 + rw [Aeneas.Std.Slice.set_length, Aeneas.Std.Slice.set_length]; exact h_out_len + have h_out2_val_len : out2.val.length = 16 := h_out2_len + have h_out2_at_2i : out2.val[2 * i.val]! = new_e := by + rw [h_out2_val, ← h_i1_val'] + have h_lt_out : i1.val < out.val.length := by rw [h_out_val_len, h_i1_val']; exact h_2i_lt + have h_lt1 : i1.val < (out.val.set i1.val new_e).length := by + rw [List.length_set]; exact h_lt_out + have h_lt2 : i1.val < ((out.val.set i1.val new_e).set i2.val new_o).length := by + rw [List.length_set]; exact h_lt1 + rw [getElem!_pos ((out.val.set i1.val new_e).set i2.val new_o) i1.val h_lt2] + rw [List.getElem_set_ne (by rw [h_i2_val', h_i1_val']; omega)] + rw [List.getElem_set_self] + have h_out2_at_2i1 : out2.val[2 * i.val + 1]! = new_o := by + rw [h_out2_val, ← h_i2_val'] + have h_lt_out : i2.val < out.val.length := by rw [h_out_val_len, h_i2_val']; exact h_2i1_lt + have h_lt1 : i2.val < (out.val.set i1.val new_e).length := by + rw [List.length_set]; exact h_lt_out + have h_lt2 : i2.val < ((out.val.set i1.val new_e).set i2.val new_o).length := by + rw [List.length_set]; exact h_lt1 + rw [getElem!_pos ((out.val.set i1.val new_e).set i2.val new_o) i2.val h_lt2] + rw [List.getElem_set_self] + have h_out2_untouched : ∀ k : Nat, k < 16 → k ≠ 2 * i.val → k ≠ 2 * i.val + 1 → + out2.val[k]! = out.val[k]! := by + intro k hk hki hkj + rw [h_out2_val] + have h_lt_out : k < out.val.length := by rw [h_out_val_len]; exact hk + have h_lt1 : k < (out.val.set i1.val new_e).length := by rw [List.length_set]; exact h_lt_out + have h_lt2 : k < ((out.val.set i1.val new_e).set i2.val new_o).length := by + rw [List.length_set]; exact h_lt1 + rw [getElem!_pos ((out.val.set i1.val new_e).set i2.val new_o) k h_lt2] + rw [getElem!_pos out.val k h_lt_out] + rw [List.getElem_set_ne (by rw [h_i2_val']; omega)] + rw [List.getElem_set_ne (by rw [h_i1_val']; omega)] + refine ⟨h_out2_len, h_out2_untouched, ?_, ?_, ?_, ?_⟩ + · rw [h_out2_at_2i] + rw [hoe_def] at h_new_e_bnd + exact h_new_e_bnd + · rw [h_out2_at_2i1] + rw [hoo_def] at h_new_o_bnd + exact h_new_o_bnd + · -- FE eq (even half) with symbolic cache lane. + rw [h_out2_at_2i, hoe_def] + apply L2_8d.mont_reduce_even_fe_eq_cache + (out := out.val[2 * i.val]!) (r := new_e) + (ai := ai_v) (bi := bi_v) (aj := aj_v) (c := c_v) + rw [← hoe_def, h_new_e_val, h_sum_e_val, h_ai_bi_val, h_aj_bj_zeta_val] + push_cast + ring + · -- FE eq (odd half) — same as L2.8c. + rw [h_out2_at_2i1, hoo_def] + apply L2_8d.mont_reduce_odd_fe_eq_cache + (out := out.val[2 * i.val + 1]!) (r := new_o) + (ai := ai_v) (bi := bi_v) (aj := aj_v) (bj := bj_v) + rw [← hoo_def, h_new_o_val, h_sum_o_val, h_ai_bj_val, h_aj_bi_val] + push_cast + ring + +set_option maxHeartbeats 16000000 in +/-- L2.8d — `vector.portable.ntt.accumulating_ntt_multiply_fill_cache`: + cache-filling variant. The impl chains 8 + `accumulating_ntt_multiply_binomials_fill_cache` calls; each + behaves like the base `_binomials` (per-pair degree-2 polynomial + multiply mod (X²−ζ²)) but additionally writes the Mont-reduced + `b[2i+1]·zeta_i` into `cache[i]`. + + POST shape mirrors `accumulating_ntt_multiply_fc` (length + relative + bound + `ntt_multiply_base_case_post` on the output slice) AND adds + a cache-side POST: each of the 8 cache slots stores the FE-projected + `mul_pure rhs[2i+1] zeta_eff_i` and is canonical; lanes 8..15 of + the cache are preserved from the input. + + Sibling adaptation of L2.8c reusing `L2_8c.*` infrastructure. -/ +@[spec] +theorem accumulating_ntt_multiply_fill_cache_fc + (lhs rhs : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (out : Aeneas.Std.Slice Std.I32) + (cache : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta0 zeta1 zeta2 zeta3 : Std.I16) + (h_out_len : out.length = 16) + (h_lhs : ∀ j : Fin 16, (lhs.elements.val[j.val]!).val.natAbs ≤ 3328) + (h_rhs : ∀ j : Fin 16, (rhs.elements.val[j.val]!).val.natAbs ≤ 3328) + (h_zeta0 : zeta0.val.natAbs ≤ 1664) + (h_zeta1 : zeta1.val.natAbs ≤ 1664) + (h_zeta2 : zeta2.val.natAbs ≤ 1664) + (h_zeta3 : zeta3.val.natAbs ≤ 1664) + (h_out_bnd : ∀ k : Fin 16, (out.val[k.val]!).val.natAbs ≤ 2^30) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_fill_cache + lhs rhs out cache zeta0 zeta1 zeta2 zeta3 + ⦃ ⇓ p => ⌜ p.1.length = 16 ∧ + (∀ k : Fin 16, (p.1.val[k.val]!).val.natAbs + ≤ (out.val[k.val]!).val.natAbs + 2^25) ∧ + ntt_multiply_base_case_post lhs rhs + zeta0 zeta1 zeta2 zeta3 out p.1 ∧ + Spec.ntt_multiply_cache_post rhs + zeta0 zeta1 zeta2 zeta3 p.2 ∧ + (∀ k : Nat, k < 16 → 8 ≤ k → + p.2.elements.val[k]! = cache.elements.val[k]!) ⌝ ⦄ := by + have h_zeta_within (z : Std.I16) (hz : z.val.natAbs ≤ 1664) : + z.val.natAbs ≤ 2^15 - 1 := by omega + have h_n0_val := L2_8c.wrapping_neg_val_eq zeta0 (h_zeta_within _ h_zeta0) + have h_n1_val := L2_8c.wrapping_neg_val_eq zeta1 (h_zeta_within _ h_zeta1) + have h_n2_val := L2_8c.wrapping_neg_val_eq zeta2 (h_zeta_within _ h_zeta2) + have h_n3_val := L2_8c.wrapping_neg_val_eq zeta3 (h_zeta_within _ h_zeta3) + set nzeta0 : Std.I16 := Aeneas.Std.I16.wrapping_sub (0#i16) zeta0 with hn0_def + set nzeta1 : Std.I16 := Aeneas.Std.I16.wrapping_sub (0#i16) zeta1 with hn1_def + set nzeta2 : Std.I16 := Aeneas.Std.I16.wrapping_sub (0#i16) zeta2 with hn2_def + set nzeta3 : Std.I16 := Aeneas.Std.I16.wrapping_sub (0#i16) zeta3 with hn3_def + have h_nz0_bnd : nzeta0.val.natAbs ≤ 1664 := by rw [h_n0_val]; omega + have h_nz1_bnd : nzeta1.val.natAbs ≤ 1664 := by rw [h_n1_val]; omega + have h_nz2_bnd : nzeta2.val.natAbs ≤ 1664 := by rw [h_n2_val]; omega + have h_nz3_bnd : nzeta3.val.natAbs ≤ 1664 := by rw [h_n3_val]; omega + have h_n0_fe : lift_fe_mont nzeta0 + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta0) := + L2_8c.lift_fe_mont_neg_pure_eq zeta0 nzeta0 (h_zeta_within _ h_zeta0) h_n0_val + have h_n1_fe : lift_fe_mont nzeta1 + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta1) := + L2_8c.lift_fe_mont_neg_pure_eq zeta1 nzeta1 (h_zeta_within _ h_zeta1) h_n1_val + have h_n2_fe : lift_fe_mont nzeta2 + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta2) := + L2_8c.lift_fe_mont_neg_pure_eq zeta2 nzeta2 (h_zeta_within _ h_zeta2) h_n2_val + have h_n3_fe : lift_fe_mont nzeta3 + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta3) := + L2_8c.lift_fe_mont_neg_pure_eq zeta3 nzeta3 (h_zeta_within _ h_zeta3) h_n3_val + have h_wn0 : core.num.I16.wrapping_neg zeta0 = .ok nzeta0 := + L2_8c.cm_wrapping_neg_i16_ok_eq zeta0 + have h_wn1 : core.num.I16.wrapping_neg zeta1 = .ok nzeta1 := + L2_8c.cm_wrapping_neg_i16_ok_eq zeta1 + have h_wn2 : core.num.I16.wrapping_neg zeta2 = .ok nzeta2 := + L2_8c.cm_wrapping_neg_i16_ok_eq zeta2 + have h_wn3 : core.num.I16.wrapping_neg zeta3 = .ok nzeta3 := + L2_8c.cm_wrapping_neg_i16_ok_eq zeta3 + have h_cz0 : libcrux_secrets.traits.Classify.Blanket.classify zeta0 = .ok zeta0 := + L2_8c.classify_ok_eq zeta0 + have h_cnz0 : libcrux_secrets.traits.Classify.Blanket.classify nzeta0 = .ok nzeta0 := + L2_8c.classify_ok_eq nzeta0 + have h_cz1 : libcrux_secrets.traits.Classify.Blanket.classify zeta1 = .ok zeta1 := + L2_8c.classify_ok_eq zeta1 + have h_cnz1 : libcrux_secrets.traits.Classify.Blanket.classify nzeta1 = .ok nzeta1 := + L2_8c.classify_ok_eq nzeta1 + have h_cz2 : libcrux_secrets.traits.Classify.Blanket.classify zeta2 = .ok zeta2 := + L2_8c.classify_ok_eq zeta2 + have h_cnz2 : libcrux_secrets.traits.Classify.Blanket.classify nzeta2 = .ok nzeta2 := + L2_8c.classify_ok_eq nzeta2 + have h_cz3 : libcrux_secrets.traits.Classify.Blanket.classify zeta3 = .ok zeta3 := + L2_8c.classify_ok_eq zeta3 + have h_cnz3 : libcrux_secrets.traits.Classify.Blanket.classify nzeta3 = .ok nzeta3 := + L2_8c.classify_ok_eq nzeta3 + have h_out_bnd_universal : ∀ k : Fin 16, (out.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := by + intro k; have := h_out_bnd k; omega + have h_cache_len : cache.elements.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length cache + -- ===== 8 chained calls — each returns (r{i}, cache_out{i}) ===== + -- Call 0: pair 0 with zeta0 (touches lanes 0, 1; writes cache[0]). + obtain ⟨p0, h_p0_eq, h_r0_len, h_r0_unc, h_r0_bnd_e, h_r0_bnd_o, + h_r0_fe_e, h_r0_fe_o, h_c0_canon, h_c0_fe, h_c0_unc⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fill_cache_fc lhs rhs zeta0 0#usize out cache + (by decide) h_out_len h_lhs h_rhs h_zeta0 h_out_bnd_universal) + set r0 : Aeneas.Std.Slice Std.I32 := p0.1 with hr0_def + set cache0 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := p0.2 + with hc0_def + have h_r0_at_even : (r0.val[0]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (0#usize : Std.Usize).val : Nat) = 0 := by decide + have h_b := h_r0_bnd_e + rw [h_eq] at h_b + have h_out_le := h_out_bnd ⟨0, by decide⟩ + simp only at h_out_le; omega + have h_r0_at_odd : (r0.val[1]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (0#usize : Std.Usize).val + 1 : Nat) = 1 := by decide + have h_b := h_r0_bnd_o + rw [h_eq] at h_b + have h_out_le := h_out_bnd ⟨1, by decide⟩ + simp only at h_out_le; omega + have h_r0_unc' : ∀ k : Nat, k < 16 → k ≠ 0 → k ≠ 1 → + r0.val[k]! = out.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (0#usize : Std.Usize).val : Nat) = 0 := by decide + have h_eq_o : (2 * (0#usize : Std.Usize).val + 1 : Nat) = 1 := by decide + apply h_r0_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_c0_unc' : ∀ k : Nat, k < 16 → k ≠ 0 → + cache0.elements.val[k]! = cache.elements.val[k]! := by + intro k hk hki + apply h_c0_unc k hk + show k ≠ (0#usize : Std.Usize).val; rw [show (0#usize : Std.Usize).val = 0 from rfl]; exact hki + have h_r0_bnd_universal : ∀ k : Fin 16, (r0.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step out r0 0 (by decide) h_out_bnd_universal + h_r0_unc' h_r0_at_even h_r0_at_odd + + -- Call 1: pair 1 with nzeta0 (touches lanes 2, 3; writes cache[1]). + obtain ⟨p1, h_p1_eq, h_r1_len, h_r1_unc, h_r1_bnd_e, h_r1_bnd_o, + h_r1_fe_e, h_r1_fe_o, h_c1_canon, h_c1_fe, h_c1_unc⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fill_cache_fc lhs rhs nzeta0 1#usize r0 cache0 + (by decide) h_r0_len h_lhs h_rhs h_nz0_bnd h_r0_bnd_universal) + set r1 : Aeneas.Std.Slice Std.I32 := p1.1 with hr1_def + set cache1 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := p1.2 + with hc1_def + have h_r1_at_even : (r1.val[2]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (1#usize : Std.Usize).val : Nat) = 2 := by decide + have h_b := h_r1_bnd_e + rw [h_eq] at h_b + have h_r0_eq2 : r0.val[2]! = out.val[2]! := h_r0_unc' 2 (by decide) (by decide) (by decide) + rw [h_r0_eq2] at h_b + have h_out_le := h_out_bnd ⟨2, by decide⟩ + simp only at h_out_le; omega + have h_r1_at_odd : (r1.val[3]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (1#usize : Std.Usize).val + 1 : Nat) = 3 := by decide + have h_b := h_r1_bnd_o + rw [h_eq] at h_b + have h_r0_eq3 : r0.val[3]! = out.val[3]! := h_r0_unc' 3 (by decide) (by decide) (by decide) + rw [h_r0_eq3] at h_b + have h_out_le := h_out_bnd ⟨3, by decide⟩ + simp only at h_out_le; omega + have h_r1_unc' : ∀ k : Nat, k < 16 → k ≠ 2 → k ≠ 3 → + r1.val[k]! = r0.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (1#usize : Std.Usize).val : Nat) = 2 := by decide + have h_eq_o : (2 * (1#usize : Std.Usize).val + 1 : Nat) = 3 := by decide + apply h_r1_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_c1_unc' : ∀ k : Nat, k < 16 → k ≠ 1 → + cache1.elements.val[k]! = cache0.elements.val[k]! := by + intro k hk hki + apply h_c1_unc k hk + show k ≠ (1#usize : Std.Usize).val; rw [show (1#usize : Std.Usize).val = 1 from rfl]; exact hki + have h_r1_bnd_universal : ∀ k : Fin 16, (r1.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r0 r1 1 (by decide) h_r0_bnd_universal + h_r1_unc' h_r1_at_even h_r1_at_odd + + -- Call 2: pair 2 with zeta1 (touches lanes 4, 5; writes cache[2]). + obtain ⟨p2, h_p2_eq, h_r2_len, h_r2_unc, h_r2_bnd_e, h_r2_bnd_o, + h_r2_fe_e, h_r2_fe_o, h_c2_canon, h_c2_fe, h_c2_unc⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fill_cache_fc lhs rhs zeta1 2#usize r1 cache1 + (by decide) h_r1_len h_lhs h_rhs h_zeta1 h_r1_bnd_universal) + set r2 : Aeneas.Std.Slice Std.I32 := p2.1 with hr2_def + set cache2 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := p2.2 + with hc2_def + have h_r2_at_even : (r2.val[4]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (2#usize : Std.Usize).val : Nat) = 4 := by decide + have h_b := h_r2_bnd_e + rw [h_eq] at h_b + have h_r1_eq4 : r1.val[4]! = out.val[4]! := by + rw [h_r1_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r1_eq4] at h_b + have h_out_le := h_out_bnd ⟨4, by decide⟩ + simp only at h_out_le; omega + have h_r2_at_odd : (r2.val[5]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (2#usize : Std.Usize).val + 1 : Nat) = 5 := by decide + have h_b := h_r2_bnd_o + rw [h_eq] at h_b + have h_r1_eq5 : r1.val[5]! = out.val[5]! := by + rw [h_r1_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r1_eq5] at h_b + have h_out_le := h_out_bnd ⟨5, by decide⟩ + simp only at h_out_le; omega + have h_r2_unc' : ∀ k : Nat, k < 16 → k ≠ 4 → k ≠ 5 → + r2.val[k]! = r1.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (2#usize : Std.Usize).val : Nat) = 4 := by decide + have h_eq_o : (2 * (2#usize : Std.Usize).val + 1 : Nat) = 5 := by decide + apply h_r2_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_c2_unc' : ∀ k : Nat, k < 16 → k ≠ 2 → + cache2.elements.val[k]! = cache1.elements.val[k]! := by + intro k hk hki + apply h_c2_unc k hk + show k ≠ (2#usize : Std.Usize).val; rw [show (2#usize : Std.Usize).val = 2 from rfl]; exact hki + have h_r2_bnd_universal : ∀ k : Fin 16, (r2.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r1 r2 2 (by decide) h_r1_bnd_universal + h_r2_unc' h_r2_at_even h_r2_at_odd + + -- Call 3: pair 3 with nzeta1 (touches lanes 6, 7; writes cache[3]). + obtain ⟨p3, h_p3_eq, h_r3_len, h_r3_unc, h_r3_bnd_e, h_r3_bnd_o, + h_r3_fe_e, h_r3_fe_o, h_c3_canon, h_c3_fe, h_c3_unc⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fill_cache_fc lhs rhs nzeta1 3#usize r2 cache2 + (by decide) h_r2_len h_lhs h_rhs h_nz1_bnd h_r2_bnd_universal) + set r3 : Aeneas.Std.Slice Std.I32 := p3.1 with hr3_def + set cache3 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := p3.2 + with hc3_def + have h_r3_at_even : (r3.val[6]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (3#usize : Std.Usize).val : Nat) = 6 := by decide + have h_b := h_r3_bnd_e + rw [h_eq] at h_b + have h_r2_eq6 : r2.val[6]! = out.val[6]! := by + rw [h_r2_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r2_eq6] at h_b + have h_out_le := h_out_bnd ⟨6, by decide⟩ + simp only at h_out_le; omega + have h_r3_at_odd : (r3.val[7]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (3#usize : Std.Usize).val + 1 : Nat) = 7 := by decide + have h_b := h_r3_bnd_o + rw [h_eq] at h_b + have h_r2_eq7 : r2.val[7]! = out.val[7]! := by + rw [h_r2_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r2_eq7] at h_b + have h_out_le := h_out_bnd ⟨7, by decide⟩ + simp only at h_out_le; omega + have h_r3_unc' : ∀ k : Nat, k < 16 → k ≠ 6 → k ≠ 7 → + r3.val[k]! = r2.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (3#usize : Std.Usize).val : Nat) = 6 := by decide + have h_eq_o : (2 * (3#usize : Std.Usize).val + 1 : Nat) = 7 := by decide + apply h_r3_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_c3_unc' : ∀ k : Nat, k < 16 → k ≠ 3 → + cache3.elements.val[k]! = cache2.elements.val[k]! := by + intro k hk hki + apply h_c3_unc k hk + show k ≠ (3#usize : Std.Usize).val; rw [show (3#usize : Std.Usize).val = 3 from rfl]; exact hki + have h_r3_bnd_universal : ∀ k : Fin 16, (r3.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r2 r3 3 (by decide) h_r2_bnd_universal + h_r3_unc' h_r3_at_even h_r3_at_odd + + -- Call 4: pair 4 with zeta2 (touches lanes 8, 9; writes cache[4]). + obtain ⟨p4, h_p4_eq, h_r4_len, h_r4_unc, h_r4_bnd_e, h_r4_bnd_o, + h_r4_fe_e, h_r4_fe_o, h_c4_canon, h_c4_fe, h_c4_unc⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fill_cache_fc lhs rhs zeta2 4#usize r3 cache3 + (by decide) h_r3_len h_lhs h_rhs h_zeta2 h_r3_bnd_universal) + set r4 : Aeneas.Std.Slice Std.I32 := p4.1 with hr4_def + set cache4 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := p4.2 + with hc4_def + have h_r4_at_even : (r4.val[8]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (4#usize : Std.Usize).val : Nat) = 8 := by decide + have h_b := h_r4_bnd_e + rw [h_eq] at h_b + have h_r3_eq8 : r3.val[8]! = out.val[8]! := by + rw [h_r3_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r3_eq8] at h_b + have h_out_le := h_out_bnd ⟨8, by decide⟩ + simp only at h_out_le; omega + have h_r4_at_odd : (r4.val[9]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (4#usize : Std.Usize).val + 1 : Nat) = 9 := by decide + have h_b := h_r4_bnd_o + rw [h_eq] at h_b + have h_r3_eq9 : r3.val[9]! = out.val[9]! := by + rw [h_r3_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r3_eq9] at h_b + have h_out_le := h_out_bnd ⟨9, by decide⟩ + simp only at h_out_le; omega + have h_r4_unc' : ∀ k : Nat, k < 16 → k ≠ 8 → k ≠ 9 → + r4.val[k]! = r3.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (4#usize : Std.Usize).val : Nat) = 8 := by decide + have h_eq_o : (2 * (4#usize : Std.Usize).val + 1 : Nat) = 9 := by decide + apply h_r4_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_c4_unc' : ∀ k : Nat, k < 16 → k ≠ 4 → + cache4.elements.val[k]! = cache3.elements.val[k]! := by + intro k hk hki + apply h_c4_unc k hk + show k ≠ (4#usize : Std.Usize).val; rw [show (4#usize : Std.Usize).val = 4 from rfl]; exact hki + have h_r4_bnd_universal : ∀ k : Fin 16, (r4.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r3 r4 4 (by decide) h_r3_bnd_universal + h_r4_unc' h_r4_at_even h_r4_at_odd + + -- Call 5: pair 5 with nzeta2 (touches lanes 10, 11; writes cache[5]). + obtain ⟨p5, h_p5_eq, h_r5_len, h_r5_unc, h_r5_bnd_e, h_r5_bnd_o, + h_r5_fe_e, h_r5_fe_o, h_c5_canon, h_c5_fe, h_c5_unc⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fill_cache_fc lhs rhs nzeta2 5#usize r4 cache4 + (by decide) h_r4_len h_lhs h_rhs h_nz2_bnd h_r4_bnd_universal) + set r5 : Aeneas.Std.Slice Std.I32 := p5.1 with hr5_def + set cache5 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := p5.2 + with hc5_def + have h_r5_at_even : (r5.val[10]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (5#usize : Std.Usize).val : Nat) = 10 := by decide + have h_b := h_r5_bnd_e + rw [h_eq] at h_b + have h_r4_eq10 : r4.val[10]! = out.val[10]! := by + rw [h_r4_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r4_eq10] at h_b + have h_out_le := h_out_bnd ⟨10, by decide⟩ + simp only at h_out_le; omega + have h_r5_at_odd : (r5.val[11]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (5#usize : Std.Usize).val + 1 : Nat) = 11 := by decide + have h_b := h_r5_bnd_o + rw [h_eq] at h_b + have h_r4_eq11 : r4.val[11]! = out.val[11]! := by + rw [h_r4_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r4_eq11] at h_b + have h_out_le := h_out_bnd ⟨11, by decide⟩ + simp only at h_out_le; omega + have h_r5_unc' : ∀ k : Nat, k < 16 → k ≠ 10 → k ≠ 11 → + r5.val[k]! = r4.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (5#usize : Std.Usize).val : Nat) = 10 := by decide + have h_eq_o : (2 * (5#usize : Std.Usize).val + 1 : Nat) = 11 := by decide + apply h_r5_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_c5_unc' : ∀ k : Nat, k < 16 → k ≠ 5 → + cache5.elements.val[k]! = cache4.elements.val[k]! := by + intro k hk hki + apply h_c5_unc k hk + show k ≠ (5#usize : Std.Usize).val; rw [show (5#usize : Std.Usize).val = 5 from rfl]; exact hki + have h_r5_bnd_universal : ∀ k : Fin 16, (r5.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r4 r5 5 (by decide) h_r4_bnd_universal + h_r5_unc' h_r5_at_even h_r5_at_odd + + -- Call 6: pair 6 with zeta3 (touches lanes 12, 13; writes cache[6]). + obtain ⟨p6, h_p6_eq, h_r6_len, h_r6_unc, h_r6_bnd_e, h_r6_bnd_o, + h_r6_fe_e, h_r6_fe_o, h_c6_canon, h_c6_fe, h_c6_unc⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fill_cache_fc lhs rhs zeta3 6#usize r5 cache5 + (by decide) h_r5_len h_lhs h_rhs h_zeta3 h_r5_bnd_universal) + set r6 : Aeneas.Std.Slice Std.I32 := p6.1 with hr6_def + set cache6 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := p6.2 + with hc6_def + have h_r6_at_even : (r6.val[12]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (6#usize : Std.Usize).val : Nat) = 12 := by decide + have h_b := h_r6_bnd_e + rw [h_eq] at h_b + have h_r5_eq12 : r5.val[12]! = out.val[12]! := by + rw [h_r5_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r5_eq12] at h_b + have h_out_le := h_out_bnd ⟨12, by decide⟩ + simp only at h_out_le; omega + have h_r6_at_odd : (r6.val[13]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (6#usize : Std.Usize).val + 1 : Nat) = 13 := by decide + have h_b := h_r6_bnd_o + rw [h_eq] at h_b + have h_r5_eq13 : r5.val[13]! = out.val[13]! := by + rw [h_r5_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r5_eq13] at h_b + have h_out_le := h_out_bnd ⟨13, by decide⟩ + simp only at h_out_le; omega + have h_r6_unc' : ∀ k : Nat, k < 16 → k ≠ 12 → k ≠ 13 → + r6.val[k]! = r5.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (6#usize : Std.Usize).val : Nat) = 12 := by decide + have h_eq_o : (2 * (6#usize : Std.Usize).val + 1 : Nat) = 13 := by decide + apply h_r6_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_c6_unc' : ∀ k : Nat, k < 16 → k ≠ 6 → + cache6.elements.val[k]! = cache5.elements.val[k]! := by + intro k hk hki + apply h_c6_unc k hk + show k ≠ (6#usize : Std.Usize).val; rw [show (6#usize : Std.Usize).val = 6 from rfl]; exact hki + have h_r6_bnd_universal : ∀ k : Fin 16, (r6.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r5 r6 6 (by decide) h_r5_bnd_universal + h_r6_unc' h_r6_at_even h_r6_at_odd + + -- Call 7: pair 7 with nzeta3 (touches lanes 14, 15; writes cache[7]). + obtain ⟨p7, h_p7_eq, h_r7_len, h_r7_unc, h_r7_bnd_e, h_r7_bnd_o, + h_r7_fe_e, h_r7_fe_o, h_c7_canon, h_c7_fe, h_c7_unc⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_fill_cache_fc lhs rhs nzeta3 7#usize r6 cache6 + (by decide) h_r6_len h_lhs h_rhs h_nz3_bnd h_r6_bnd_universal) + set r7 : Aeneas.Std.Slice Std.I32 := p7.1 with hr7_def + set cache7 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := p7.2 + with hc7_def + have h_r7_at_even : (r7.val[14]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (7#usize : Std.Usize).val : Nat) = 14 := by decide + have h_b := h_r7_bnd_e + rw [h_eq] at h_b + have h_r6_eq14 : r6.val[14]! = out.val[14]! := by + rw [h_r6_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r6_eq14] at h_b + have h_out_le := h_out_bnd ⟨14, by decide⟩ + simp only at h_out_le; omega + have h_r7_at_odd : (r7.val[15]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (7#usize : Std.Usize).val + 1 : Nat) = 15 := by decide + have h_b := h_r7_bnd_o + rw [h_eq] at h_b + have h_r6_eq15 : r6.val[15]! = out.val[15]! := by + rw [h_r6_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r6_eq15] at h_b + have h_out_le := h_out_bnd ⟨15, by decide⟩ + simp only at h_out_le; omega + have h_r7_unc' : ∀ k : Nat, k < 16 → k ≠ 14 → k ≠ 15 → + r7.val[k]! = r6.val[k]! := by + intro k hk hke hko + have h_eq_e : (2 * (7#usize : Std.Usize).val : Nat) = 14 := by decide + have h_eq_o : (2 * (7#usize : Std.Usize).val + 1 : Nat) = 15 := by decide + apply h_r7_unc k hk + · rw [h_eq_e]; exact hke + · rw [h_eq_o]; exact hko + have h_c7_unc' : ∀ k : Nat, k < 16 → k ≠ 7 → + cache7.elements.val[k]! = cache6.elements.val[k]! := by + intro k hk hki + apply h_c7_unc k hk + show k ≠ (7#usize : Std.Usize).val; rw [show (7#usize : Std.Usize).val = 7 from rfl]; exact hki + + -- Compose the monadic body. + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_fill_cache + lhs rhs out cache zeta0 zeta1 zeta2 zeta3 = .ok (r7, cache7) := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_fill_cache + -- The result of each call is `pK = (rK, cacheK)`. The impl binds these as + -- `let (outK, cacheK) ← ...`, which expects a destructured pattern. + -- Convert the .ok pK into the destructured form via pair eta. + have h_p0_eq' : + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs zeta0 0#usize out cache = .ok (r0, cache0) := by + rw [h_p0_eq] + have h_p1_eq' : + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs nzeta0 1#usize r0 cache0 = .ok (r1, cache1) := by + rw [h_p1_eq] + have h_p2_eq' : + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs zeta1 2#usize r1 cache1 = .ok (r2, cache2) := by + rw [h_p2_eq] + have h_p3_eq' : + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs nzeta1 3#usize r2 cache2 = .ok (r3, cache3) := by + rw [h_p3_eq] + have h_p4_eq' : + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs zeta2 4#usize r3 cache3 = .ok (r4, cache4) := by + rw [h_p4_eq] + have h_p5_eq' : + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs nzeta2 5#usize r4 cache4 = .ok (r5, cache5) := by + rw [h_p5_eq] + have h_p6_eq' : + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs zeta3 6#usize r5 cache5 = .ok (r6, cache6) := by + rw [h_p6_eq] + have h_p7_eq' : + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs nzeta3 7#usize r6 cache6 = .ok (r7, cache7) := by + rw [h_p7_eq] + rw [h_wn0]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_wn1]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_wn2]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_wn3]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_cz0]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_p0_eq'] + simp only [Aeneas.Std.bind_tc_ok] + change (do + let i1 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta0 + let (out2, cache2) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i1 1#usize r0 cache0 + let i2 ← libcrux_secrets.traits.Classify.Blanket.classify zeta1 + let (out3, cache3) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i2 2#usize out2 cache2 + let i3 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta1 + let (out4, cache4) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i3 3#usize out3 cache3 + let i4 ← libcrux_secrets.traits.Classify.Blanket.classify zeta2 + let (out5, cache5) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i4 4#usize out4 cache4 + let i5 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta2 + let (out6, cache6) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i5 5#usize out5 cache5 + let i6 ← libcrux_secrets.traits.Classify.Blanket.classify zeta3 + let (out7, cache7) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i6 6#usize out6 cache6 + let i7 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta3 + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i7 7#usize out7 cache7) = .ok (r7, cache7) + rw [h_cnz0]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_p1_eq'] + simp only [Aeneas.Std.bind_tc_ok] + change (do + let i2 ← libcrux_secrets.traits.Classify.Blanket.classify zeta1 + let (out3, cache3) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i2 2#usize r1 cache1 + let i3 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta1 + let (out4, cache4) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i3 3#usize out3 cache3 + let i4 ← libcrux_secrets.traits.Classify.Blanket.classify zeta2 + let (out5, cache5) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i4 4#usize out4 cache4 + let i5 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta2 + let (out6, cache6) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i5 5#usize out5 cache5 + let i6 ← libcrux_secrets.traits.Classify.Blanket.classify zeta3 + let (out7, cache7) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i6 6#usize out6 cache6 + let i7 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta3 + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i7 7#usize out7 cache7) = .ok (r7, cache7) + rw [h_cz1]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_p2_eq'] + simp only [Aeneas.Std.bind_tc_ok] + change (do + let i3 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta1 + let (out4, cache4) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i3 3#usize r2 cache2 + let i4 ← libcrux_secrets.traits.Classify.Blanket.classify zeta2 + let (out5, cache5) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i4 4#usize out4 cache4 + let i5 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta2 + let (out6, cache6) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i5 5#usize out5 cache5 + let i6 ← libcrux_secrets.traits.Classify.Blanket.classify zeta3 + let (out7, cache7) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i6 6#usize out6 cache6 + let i7 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta3 + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i7 7#usize out7 cache7) = .ok (r7, cache7) + rw [h_cnz1]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_p3_eq'] + simp only [Aeneas.Std.bind_tc_ok] + change (do + let i4 ← libcrux_secrets.traits.Classify.Blanket.classify zeta2 + let (out5, cache5) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i4 4#usize r3 cache3 + let i5 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta2 + let (out6, cache6) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i5 5#usize out5 cache5 + let i6 ← libcrux_secrets.traits.Classify.Blanket.classify zeta3 + let (out7, cache7) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i6 6#usize out6 cache6 + let i7 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta3 + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i7 7#usize out7 cache7) = .ok (r7, cache7) + rw [h_cz2]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_p4_eq'] + simp only [Aeneas.Std.bind_tc_ok] + change (do + let i5 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta2 + let (out6, cache6) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i5 5#usize r4 cache4 + let i6 ← libcrux_secrets.traits.Classify.Blanket.classify zeta3 + let (out7, cache7) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i6 6#usize out6 cache6 + let i7 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta3 + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i7 7#usize out7 cache7) = .ok (r7, cache7) + rw [h_cnz2]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_p5_eq'] + simp only [Aeneas.Std.bind_tc_ok] + change (do + let i6 ← libcrux_secrets.traits.Classify.Blanket.classify zeta3 + let (out7, cache7) ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i6 6#usize r5 cache5 + let i7 ← libcrux_secrets.traits.Classify.Blanket.classify nzeta3 + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_binomials_fill_cache + lhs rhs i7 7#usize out7 cache7) = .ok (r7, cache7) + rw [h_cz3]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_p6_eq'] + simp only [Aeneas.Std.bind_tc_ok] + rw [h_cnz3]; simp only [Aeneas.Std.bind_tc_ok] + exact h_p7_eq' + apply triple_of_ok_fc h_body + -- ===== POST: 5-conjunct ===== + refine ⟨h_r7_len, ?_, ?_, ?_, ?_⟩ + · -- Relative bound: ∀ k, r7.val[k]!.natAbs ≤ out.val[k]!.natAbs + 2^25. + intro k + rcases k with ⟨k, hk⟩ + interval_cases k + · have h_r7_at_0 : r7.val[0]! = r0.val[0]! := by + rw [h_r7_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r7_at_0] + have h_eq : (2 * (0#usize : Std.Usize).val : Nat) = 0 := by decide + have h_b := h_r0_bnd_e + rw [h_eq] at h_b + exact h_b + · have h_r7_at_1 : r7.val[1]! = r0.val[1]! := by + rw [h_r7_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r7_at_1] + have h_eq : (2 * (0#usize : Std.Usize).val + 1 : Nat) = 1 := by decide + have h_b := h_r0_bnd_o + rw [h_eq] at h_b + exact h_b + · have h_r7_at_2 : r7.val[2]! = r1.val[2]! := by + rw [h_r7_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r7_at_2] + have h_eq : (2 * (1#usize : Std.Usize).val : Nat) = 2 := by decide + have h_b := h_r1_bnd_e + rw [h_eq] at h_b + have h_r0_at_2 : r0.val[2]! = out.val[2]! := by + rw [h_r0_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r0_at_2] at h_b + exact h_b + · have h_r7_at_3 : r7.val[3]! = r1.val[3]! := by + rw [h_r7_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r7_at_3] + have h_eq : (2 * (1#usize : Std.Usize).val + 1 : Nat) = 3 := by decide + have h_b := h_r1_bnd_o + rw [h_eq] at h_b + have h_r0_at_3 : r0.val[3]! = out.val[3]! := by + rw [h_r0_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r0_at_3] at h_b + exact h_b + · have h_r7_at_4 : r7.val[4]! = r2.val[4]! := by + rw [h_r7_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r7_at_4] + have h_eq : (2 * (2#usize : Std.Usize).val : Nat) = 4 := by decide + have h_b := h_r2_bnd_e + rw [h_eq] at h_b + have h_r1_at_4 : r1.val[4]! = out.val[4]! := by + rw [h_r1_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r1_at_4] at h_b + exact h_b + · have h_r7_at_5 : r7.val[5]! = r2.val[5]! := by + rw [h_r7_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r7_at_5] + have h_eq : (2 * (2#usize : Std.Usize).val + 1 : Nat) = 5 := by decide + have h_b := h_r2_bnd_o + rw [h_eq] at h_b + have h_r1_at_5 : r1.val[5]! = out.val[5]! := by + rw [h_r1_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r1_at_5] at h_b + exact h_b + · have h_r7_at_6 : r7.val[6]! = r3.val[6]! := by + rw [h_r7_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r7_at_6] + have h_eq : (2 * (3#usize : Std.Usize).val : Nat) = 6 := by decide + have h_b := h_r3_bnd_e + rw [h_eq] at h_b + have h_r2_at_6 : r2.val[6]! = out.val[6]! := by + rw [h_r2_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r2_at_6] at h_b + exact h_b + · have h_r7_at_7 : r7.val[7]! = r3.val[7]! := by + rw [h_r7_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r7_at_7] + have h_eq : (2 * (3#usize : Std.Usize).val + 1 : Nat) = 7 := by decide + have h_b := h_r3_bnd_o + rw [h_eq] at h_b + have h_r2_at_7 : r2.val[7]! = out.val[7]! := by + rw [h_r2_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r2_at_7] at h_b + exact h_b + · have h_r7_at_8 : r7.val[8]! = r4.val[8]! := by + rw [h_r7_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r7_at_8] + have h_eq : (2 * (4#usize : Std.Usize).val : Nat) = 8 := by decide + have h_b := h_r4_bnd_e + rw [h_eq] at h_b + have h_r3_at_8 : r3.val[8]! = out.val[8]! := by + rw [h_r3_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r3_at_8] at h_b + exact h_b + · have h_r7_at_9 : r7.val[9]! = r4.val[9]! := by + rw [h_r7_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r7_at_9] + have h_eq : (2 * (4#usize : Std.Usize).val + 1 : Nat) = 9 := by decide + have h_b := h_r4_bnd_o + rw [h_eq] at h_b + have h_r3_at_9 : r3.val[9]! = out.val[9]! := by + rw [h_r3_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r3_at_9] at h_b + exact h_b + · have h_r7_at_10 : r7.val[10]! = r5.val[10]! := by + rw [h_r7_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r7_at_10] + have h_eq : (2 * (5#usize : Std.Usize).val : Nat) = 10 := by decide + have h_b := h_r5_bnd_e + rw [h_eq] at h_b + have h_r4_at_10 : r4.val[10]! = out.val[10]! := by + rw [h_r4_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r4_at_10] at h_b + exact h_b + · have h_r7_at_11 : r7.val[11]! = r5.val[11]! := by + rw [h_r7_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r7_at_11] + have h_eq : (2 * (5#usize : Std.Usize).val + 1 : Nat) = 11 := by decide + have h_b := h_r5_bnd_o + rw [h_eq] at h_b + have h_r4_at_11 : r4.val[11]! = out.val[11]! := by + rw [h_r4_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r4_at_11] at h_b + exact h_b + · have h_r7_at_12 : r7.val[12]! = r6.val[12]! := by + rw [h_r7_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r7_at_12] + have h_eq : (2 * (6#usize : Std.Usize).val : Nat) = 12 := by decide + have h_b := h_r6_bnd_e + rw [h_eq] at h_b + have h_r5_at_12 : r5.val[12]! = out.val[12]! := by + rw [h_r5_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r5_at_12] at h_b + exact h_b + · have h_r7_at_13 : r7.val[13]! = r6.val[13]! := by + rw [h_r7_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r7_at_13] + have h_eq : (2 * (6#usize : Std.Usize).val + 1 : Nat) = 13 := by decide + have h_b := h_r6_bnd_o + rw [h_eq] at h_b + have h_r5_at_13 : r5.val[13]! = out.val[13]! := by + rw [h_r5_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r5_at_13] at h_b + exact h_b + · have h_eq : (2 * (7#usize : Std.Usize).val : Nat) = 14 := by decide + have h_b := h_r7_bnd_e + rw [h_eq] at h_b + have h_r6_at_14 : r6.val[14]! = out.val[14]! := by + rw [h_r6_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r6_at_14] at h_b + exact h_b + · have h_eq : (2 * (7#usize : Std.Usize).val + 1 : Nat) = 15 := by decide + have h_b := h_r7_bnd_o + rw [h_eq] at h_b + have h_r6_at_15 : r6.val[15]! = out.val[15]! := by + rw [h_r6_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r6_at_15] at h_b + exact h_b + · -- ntt_multiply_base_case_post: per-lane FE equation. + unfold ntt_multiply_base_case_post ntt_multiply_base_case_alg + apply Subtype.ext + have h_lhs_val : (Spec.chunk_reducing_from_i32_array_pure r7).val + = (List.range 16).map (fun i => Spec.mont_reduce_pure (lift_fe_int (r7.val[i]!).val)) := by + unfold Spec.chunk_reducing_from_i32_array_pure; rfl + have h_rhs_val : (Spec.chunk_add_pure + (Spec.chunk_reducing_from_i32_array_pure out) + (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3))).val + = (List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Spec.chunk_reducing_from_i32_array_pure out).val[i]!) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[i]!)) := by + unfold Spec.chunk_add_pure; rfl + rw [h_lhs_val, h_rhs_val] + apply List.ext_getElem + · simp + · intro k hk1 hk2 + have hk : k < 16 := by simp at hk1; exact hk1 + rw [List.getElem_map, List.getElem_map, List.getElem_range] + interval_cases k + · -- Lane 0: touched by call 0 (zeta0, even). + have h_r7_at_lane : r7.val[0]! = r0.val[0]! := by + rw [h_r7_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_fe := h_r0_fe_e + simp only [ + show (2 * (0#usize : Std.Usize).val : Nat) = 0 from by decide] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[0]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[0]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[0]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[0]!) + ((lift_chunk_mont rhs).val[0]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[1]!) + ((lift_chunk_mont rhs).val[1]!)) + (lift_fe_mont zeta0)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_0 : (lift_chunk_mont lhs).val[0]! + = lift_fe_mont (lhs.elements.val[0]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[0]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 0 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 0 (by rw [h_l]; decide)] + have h_lcm_lhs_1 : (lift_chunk_mont lhs).val[1]! + = lift_fe_mont (lhs.elements.val[1]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[1]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 1 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 1 (by rw [h_l]; decide)] + have h_lcm_rhs_0 : (lift_chunk_mont rhs).val[0]! + = lift_fe_mont (rhs.elements.val[0]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[0]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 0 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 0 (by rw [h_l]; decide)] + have h_lcm_rhs_1 : (lift_chunk_mont rhs).val[1]! + = lift_fe_mont (rhs.elements.val[1]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[1]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 1 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 1 (by rw [h_l]; decide)] + rw [h_lcm_lhs_0, h_lcm_lhs_1, h_lcm_rhs_0, h_lcm_rhs_1] + · -- Lane 1: touched by call 0 (zeta0, odd). + have h_r7_at_lane : r7.val[1]! = r0.val[1]! := by + rw [h_r7_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_fe := h_r0_fe_o + simp only [ + show (2 * (0#usize : Std.Usize).val : Nat) = 0 from by decide] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[1]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[1]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[1]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[0]!) + ((lift_chunk_mont rhs).val[1]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[1]!) + ((lift_chunk_mont rhs).val[0]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_0 : (lift_chunk_mont lhs).val[0]! + = lift_fe_mont (lhs.elements.val[0]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[0]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 0 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 0 (by rw [h_l]; decide)] + have h_lcm_lhs_1 : (lift_chunk_mont lhs).val[1]! + = lift_fe_mont (lhs.elements.val[1]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[1]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 1 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 1 (by rw [h_l]; decide)] + have h_lcm_rhs_0 : (lift_chunk_mont rhs).val[0]! + = lift_fe_mont (rhs.elements.val[0]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[0]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 0 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 0 (by rw [h_l]; decide)] + have h_lcm_rhs_1 : (lift_chunk_mont rhs).val[1]! + = lift_fe_mont (rhs.elements.val[1]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[1]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 1 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 1 (by rw [h_l]; decide)] + rw [h_lcm_lhs_0, h_lcm_lhs_1, h_lcm_rhs_0, h_lcm_rhs_1] + · -- Lane 2: touched by call 1 (nzeta0, even). + have h_r7_at_lane : r7.val[2]! = r1.val[2]! := by + rw [h_r7_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r0.val[2]! = out.val[2]! := by + rw [h_r0_unc' 2 (by decide) (by decide) (by decide)] + have h_src_at_odd : r0.val[3]! = out.val[3]! := by + rw [h_r0_unc' 3 (by decide) (by decide) (by decide)] + have h_fe := h_r1_fe_e + simp only [ + show (2 * (1#usize : Std.Usize).val : Nat) = 2 from by decide] at h_fe + rw [h_src_at_even] at h_fe + rw [h_n0_fe] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[2]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[2]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[2]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[2]!) + ((lift_chunk_mont rhs).val[2]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[3]!) + ((lift_chunk_mont rhs).val[3]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta0))) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_2 : (lift_chunk_mont lhs).val[2]! + = lift_fe_mont (lhs.elements.val[2]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[2]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 2 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 2 (by rw [h_l]; decide)] + have h_lcm_lhs_3 : (lift_chunk_mont lhs).val[3]! + = lift_fe_mont (lhs.elements.val[3]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[3]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 3 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 3 (by rw [h_l]; decide)] + have h_lcm_rhs_2 : (lift_chunk_mont rhs).val[2]! + = lift_fe_mont (rhs.elements.val[2]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[2]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 2 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 2 (by rw [h_l]; decide)] + have h_lcm_rhs_3 : (lift_chunk_mont rhs).val[3]! + = lift_fe_mont (rhs.elements.val[3]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[3]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 3 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 3 (by rw [h_l]; decide)] + rw [h_lcm_lhs_2, h_lcm_lhs_3, h_lcm_rhs_2, h_lcm_rhs_3] + · -- Lane 3: touched by call 1 (nzeta0, odd). + have h_r7_at_lane : r7.val[3]! = r1.val[3]! := by + rw [h_r7_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r0.val[2]! = out.val[2]! := by + rw [h_r0_unc' 2 (by decide) (by decide) (by decide)] + have h_src_at_odd : r0.val[3]! = out.val[3]! := by + rw [h_r0_unc' 3 (by decide) (by decide) (by decide)] + have h_fe := h_r1_fe_o + simp only [ + show (2 * (1#usize : Std.Usize).val : Nat) = 2 from by decide] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[3]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[3]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[3]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[2]!) + ((lift_chunk_mont rhs).val[3]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[3]!) + ((lift_chunk_mont rhs).val[2]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_2 : (lift_chunk_mont lhs).val[2]! + = lift_fe_mont (lhs.elements.val[2]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[2]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 2 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 2 (by rw [h_l]; decide)] + have h_lcm_lhs_3 : (lift_chunk_mont lhs).val[3]! + = lift_fe_mont (lhs.elements.val[3]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[3]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 3 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 3 (by rw [h_l]; decide)] + have h_lcm_rhs_2 : (lift_chunk_mont rhs).val[2]! + = lift_fe_mont (rhs.elements.val[2]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[2]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 2 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 2 (by rw [h_l]; decide)] + have h_lcm_rhs_3 : (lift_chunk_mont rhs).val[3]! + = lift_fe_mont (rhs.elements.val[3]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[3]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 3 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 3 (by rw [h_l]; decide)] + rw [h_lcm_lhs_2, h_lcm_lhs_3, h_lcm_rhs_2, h_lcm_rhs_3] + · -- Lane 4: touched by call 2 (zeta1, even). + have h_r7_at_lane : r7.val[4]! = r2.val[4]! := by + rw [h_r7_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r1.val[4]! = out.val[4]! := by + rw [h_r1_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 4 (by decide) (by decide) (by decide)] + have h_src_at_odd : r1.val[5]! = out.val[5]! := by + rw [h_r1_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 5 (by decide) (by decide) (by decide)] + have h_fe := h_r2_fe_e + simp only [ + show (2 * (2#usize : Std.Usize).val : Nat) = 4 from by decide] at h_fe + rw [h_src_at_even] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[4]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[4]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[4]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[4]!) + ((lift_chunk_mont rhs).val[4]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[5]!) + ((lift_chunk_mont rhs).val[5]!)) + (lift_fe_mont zeta1)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_4 : (lift_chunk_mont lhs).val[4]! + = lift_fe_mont (lhs.elements.val[4]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[4]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 4 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 4 (by rw [h_l]; decide)] + have h_lcm_lhs_5 : (lift_chunk_mont lhs).val[5]! + = lift_fe_mont (lhs.elements.val[5]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[5]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 5 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 5 (by rw [h_l]; decide)] + have h_lcm_rhs_4 : (lift_chunk_mont rhs).val[4]! + = lift_fe_mont (rhs.elements.val[4]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[4]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 4 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 4 (by rw [h_l]; decide)] + have h_lcm_rhs_5 : (lift_chunk_mont rhs).val[5]! + = lift_fe_mont (rhs.elements.val[5]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[5]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 5 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 5 (by rw [h_l]; decide)] + rw [h_lcm_lhs_4, h_lcm_lhs_5, h_lcm_rhs_4, h_lcm_rhs_5] + · -- Lane 5: touched by call 2 (zeta1, odd). + have h_r7_at_lane : r7.val[5]! = r2.val[5]! := by + rw [h_r7_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r1.val[4]! = out.val[4]! := by + rw [h_r1_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 4 (by decide) (by decide) (by decide)] + have h_src_at_odd : r1.val[5]! = out.val[5]! := by + rw [h_r1_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 5 (by decide) (by decide) (by decide)] + have h_fe := h_r2_fe_o + simp only [ + show (2 * (2#usize : Std.Usize).val : Nat) = 4 from by decide] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[5]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[5]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[5]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[4]!) + ((lift_chunk_mont rhs).val[5]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[5]!) + ((lift_chunk_mont rhs).val[4]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_4 : (lift_chunk_mont lhs).val[4]! + = lift_fe_mont (lhs.elements.val[4]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[4]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 4 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 4 (by rw [h_l]; decide)] + have h_lcm_lhs_5 : (lift_chunk_mont lhs).val[5]! + = lift_fe_mont (lhs.elements.val[5]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[5]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 5 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 5 (by rw [h_l]; decide)] + have h_lcm_rhs_4 : (lift_chunk_mont rhs).val[4]! + = lift_fe_mont (rhs.elements.val[4]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[4]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 4 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 4 (by rw [h_l]; decide)] + have h_lcm_rhs_5 : (lift_chunk_mont rhs).val[5]! + = lift_fe_mont (rhs.elements.val[5]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[5]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 5 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 5 (by rw [h_l]; decide)] + rw [h_lcm_lhs_4, h_lcm_lhs_5, h_lcm_rhs_4, h_lcm_rhs_5] + · -- Lane 6: touched by call 3 (nzeta1, even). + have h_r7_at_lane : r7.val[6]! = r3.val[6]! := by + rw [h_r7_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r2.val[6]! = out.val[6]! := by + rw [h_r2_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 6 (by decide) (by decide) (by decide)] + have h_src_at_odd : r2.val[7]! = out.val[7]! := by + rw [h_r2_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 7 (by decide) (by decide) (by decide)] + have h_fe := h_r3_fe_e + simp only [ + show (2 * (3#usize : Std.Usize).val : Nat) = 6 from by decide] at h_fe + rw [h_src_at_even] at h_fe + rw [h_n1_fe] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[6]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[6]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[6]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[6]!) + ((lift_chunk_mont rhs).val[6]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[7]!) + ((lift_chunk_mont rhs).val[7]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta1))) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_6 : (lift_chunk_mont lhs).val[6]! + = lift_fe_mont (lhs.elements.val[6]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[6]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 6 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 6 (by rw [h_l]; decide)] + have h_lcm_lhs_7 : (lift_chunk_mont lhs).val[7]! + = lift_fe_mont (lhs.elements.val[7]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[7]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 7 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 7 (by rw [h_l]; decide)] + have h_lcm_rhs_6 : (lift_chunk_mont rhs).val[6]! + = lift_fe_mont (rhs.elements.val[6]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[6]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 6 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 6 (by rw [h_l]; decide)] + have h_lcm_rhs_7 : (lift_chunk_mont rhs).val[7]! + = lift_fe_mont (rhs.elements.val[7]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[7]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 7 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 7 (by rw [h_l]; decide)] + rw [h_lcm_lhs_6, h_lcm_lhs_7, h_lcm_rhs_6, h_lcm_rhs_7] + · -- Lane 7: touched by call 3 (nzeta1, odd). + have h_r7_at_lane : r7.val[7]! = r3.val[7]! := by + rw [h_r7_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r2.val[6]! = out.val[6]! := by + rw [h_r2_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 6 (by decide) (by decide) (by decide)] + have h_src_at_odd : r2.val[7]! = out.val[7]! := by + rw [h_r2_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 7 (by decide) (by decide) (by decide)] + have h_fe := h_r3_fe_o + simp only [ + show (2 * (3#usize : Std.Usize).val : Nat) = 6 from by decide] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[7]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[7]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[7]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[6]!) + ((lift_chunk_mont rhs).val[7]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[7]!) + ((lift_chunk_mont rhs).val[6]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_6 : (lift_chunk_mont lhs).val[6]! + = lift_fe_mont (lhs.elements.val[6]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[6]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 6 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 6 (by rw [h_l]; decide)] + have h_lcm_lhs_7 : (lift_chunk_mont lhs).val[7]! + = lift_fe_mont (lhs.elements.val[7]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[7]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 7 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 7 (by rw [h_l]; decide)] + have h_lcm_rhs_6 : (lift_chunk_mont rhs).val[6]! + = lift_fe_mont (rhs.elements.val[6]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[6]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 6 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 6 (by rw [h_l]; decide)] + have h_lcm_rhs_7 : (lift_chunk_mont rhs).val[7]! + = lift_fe_mont (rhs.elements.val[7]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[7]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 7 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 7 (by rw [h_l]; decide)] + rw [h_lcm_lhs_6, h_lcm_lhs_7, h_lcm_rhs_6, h_lcm_rhs_7] + · -- Lane 8: touched by call 4 (zeta2, even). + have h_r7_at_lane : r7.val[8]! = r4.val[8]! := by + rw [h_r7_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r3.val[8]! = out.val[8]! := by + rw [h_r3_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 8 (by decide) (by decide) (by decide)] + have h_src_at_odd : r3.val[9]! = out.val[9]! := by + rw [h_r3_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 9 (by decide) (by decide) (by decide)] + have h_fe := h_r4_fe_e + simp only [ + show (2 * (4#usize : Std.Usize).val : Nat) = 8 from by decide] at h_fe + rw [h_src_at_even] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[8]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[8]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[8]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[8]!) + ((lift_chunk_mont rhs).val[8]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[9]!) + ((lift_chunk_mont rhs).val[9]!)) + (lift_fe_mont zeta2)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_8 : (lift_chunk_mont lhs).val[8]! + = lift_fe_mont (lhs.elements.val[8]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[8]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 8 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 8 (by rw [h_l]; decide)] + have h_lcm_lhs_9 : (lift_chunk_mont lhs).val[9]! + = lift_fe_mont (lhs.elements.val[9]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[9]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 9 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 9 (by rw [h_l]; decide)] + have h_lcm_rhs_8 : (lift_chunk_mont rhs).val[8]! + = lift_fe_mont (rhs.elements.val[8]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[8]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 8 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 8 (by rw [h_l]; decide)] + have h_lcm_rhs_9 : (lift_chunk_mont rhs).val[9]! + = lift_fe_mont (rhs.elements.val[9]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[9]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 9 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 9 (by rw [h_l]; decide)] + rw [h_lcm_lhs_8, h_lcm_lhs_9, h_lcm_rhs_8, h_lcm_rhs_9] + · -- Lane 9: touched by call 4 (zeta2, odd). + have h_r7_at_lane : r7.val[9]! = r4.val[9]! := by + rw [h_r7_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r3.val[8]! = out.val[8]! := by + rw [h_r3_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 8 (by decide) (by decide) (by decide)] + have h_src_at_odd : r3.val[9]! = out.val[9]! := by + rw [h_r3_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 9 (by decide) (by decide) (by decide)] + have h_fe := h_r4_fe_o + simp only [ + show (2 * (4#usize : Std.Usize).val : Nat) = 8 from by decide] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[9]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[9]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[9]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[8]!) + ((lift_chunk_mont rhs).val[9]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[9]!) + ((lift_chunk_mont rhs).val[8]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_8 : (lift_chunk_mont lhs).val[8]! + = lift_fe_mont (lhs.elements.val[8]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[8]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 8 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 8 (by rw [h_l]; decide)] + have h_lcm_lhs_9 : (lift_chunk_mont lhs).val[9]! + = lift_fe_mont (lhs.elements.val[9]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[9]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 9 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 9 (by rw [h_l]; decide)] + have h_lcm_rhs_8 : (lift_chunk_mont rhs).val[8]! + = lift_fe_mont (rhs.elements.val[8]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[8]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 8 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 8 (by rw [h_l]; decide)] + have h_lcm_rhs_9 : (lift_chunk_mont rhs).val[9]! + = lift_fe_mont (rhs.elements.val[9]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[9]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 9 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 9 (by rw [h_l]; decide)] + rw [h_lcm_lhs_8, h_lcm_lhs_9, h_lcm_rhs_8, h_lcm_rhs_9] + · -- Lane 10: touched by call 5 (nzeta2, even). + have h_r7_at_lane : r7.val[10]! = r5.val[10]! := by + rw [h_r7_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r4.val[10]! = out.val[10]! := by + rw [h_r4_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 10 (by decide) (by decide) (by decide)] + have h_src_at_odd : r4.val[11]! = out.val[11]! := by + rw [h_r4_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 11 (by decide) (by decide) (by decide)] + have h_fe := h_r5_fe_e + simp only [ + show (2 * (5#usize : Std.Usize).val : Nat) = 10 from by decide] at h_fe + rw [h_src_at_even] at h_fe + rw [h_n2_fe] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[10]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[10]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[10]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[10]!) + ((lift_chunk_mont rhs).val[10]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[11]!) + ((lift_chunk_mont rhs).val[11]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta2))) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_10 : (lift_chunk_mont lhs).val[10]! + = lift_fe_mont (lhs.elements.val[10]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[10]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 10 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 10 (by rw [h_l]; decide)] + have h_lcm_lhs_11 : (lift_chunk_mont lhs).val[11]! + = lift_fe_mont (lhs.elements.val[11]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[11]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 11 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 11 (by rw [h_l]; decide)] + have h_lcm_rhs_10 : (lift_chunk_mont rhs).val[10]! + = lift_fe_mont (rhs.elements.val[10]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[10]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 10 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 10 (by rw [h_l]; decide)] + have h_lcm_rhs_11 : (lift_chunk_mont rhs).val[11]! + = lift_fe_mont (rhs.elements.val[11]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[11]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 11 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 11 (by rw [h_l]; decide)] + rw [h_lcm_lhs_10, h_lcm_lhs_11, h_lcm_rhs_10, h_lcm_rhs_11] + · -- Lane 11: touched by call 5 (nzeta2, odd). + have h_r7_at_lane : r7.val[11]! = r5.val[11]! := by + rw [h_r7_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r4.val[10]! = out.val[10]! := by + rw [h_r4_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 10 (by decide) (by decide) (by decide)] + have h_src_at_odd : r4.val[11]! = out.val[11]! := by + rw [h_r4_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 11 (by decide) (by decide) (by decide)] + have h_fe := h_r5_fe_o + simp only [ + show (2 * (5#usize : Std.Usize).val : Nat) = 10 from by decide] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[11]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[11]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[11]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[10]!) + ((lift_chunk_mont rhs).val[11]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[11]!) + ((lift_chunk_mont rhs).val[10]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_10 : (lift_chunk_mont lhs).val[10]! + = lift_fe_mont (lhs.elements.val[10]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[10]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 10 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 10 (by rw [h_l]; decide)] + have h_lcm_lhs_11 : (lift_chunk_mont lhs).val[11]! + = lift_fe_mont (lhs.elements.val[11]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[11]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 11 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 11 (by rw [h_l]; decide)] + have h_lcm_rhs_10 : (lift_chunk_mont rhs).val[10]! + = lift_fe_mont (rhs.elements.val[10]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[10]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 10 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 10 (by rw [h_l]; decide)] + have h_lcm_rhs_11 : (lift_chunk_mont rhs).val[11]! + = lift_fe_mont (rhs.elements.val[11]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[11]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 11 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 11 (by rw [h_l]; decide)] + rw [h_lcm_lhs_10, h_lcm_lhs_11, h_lcm_rhs_10, h_lcm_rhs_11] + · -- Lane 12: touched by call 6 (zeta3, even). + have h_r7_at_lane : r7.val[12]! = r6.val[12]! := by + rw [h_r7_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r5.val[12]! = out.val[12]! := by + rw [h_r5_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 12 (by decide) (by decide) (by decide)] + have h_src_at_odd : r5.val[13]! = out.val[13]! := by + rw [h_r5_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 13 (by decide) (by decide) (by decide)] + have h_fe := h_r6_fe_e + simp only [ + show (2 * (6#usize : Std.Usize).val : Nat) = 12 from by decide] at h_fe + rw [h_src_at_even] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[12]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[12]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[12]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[12]!) + ((lift_chunk_mont rhs).val[12]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[13]!) + ((lift_chunk_mont rhs).val[13]!)) + (lift_fe_mont zeta3)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_12 : (lift_chunk_mont lhs).val[12]! + = lift_fe_mont (lhs.elements.val[12]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[12]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 12 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 12 (by rw [h_l]; decide)] + have h_lcm_lhs_13 : (lift_chunk_mont lhs).val[13]! + = lift_fe_mont (lhs.elements.val[13]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[13]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 13 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 13 (by rw [h_l]; decide)] + have h_lcm_rhs_12 : (lift_chunk_mont rhs).val[12]! + = lift_fe_mont (rhs.elements.val[12]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[12]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 12 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 12 (by rw [h_l]; decide)] + have h_lcm_rhs_13 : (lift_chunk_mont rhs).val[13]! + = lift_fe_mont (rhs.elements.val[13]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[13]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 13 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 13 (by rw [h_l]; decide)] + rw [h_lcm_lhs_12, h_lcm_lhs_13, h_lcm_rhs_12, h_lcm_rhs_13] + · -- Lane 13: touched by call 6 (zeta3, odd). + have h_r7_at_lane : r7.val[13]! = r6.val[13]! := by + rw [h_r7_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r5.val[12]! = out.val[12]! := by + rw [h_r5_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 12 (by decide) (by decide) (by decide)] + have h_src_at_odd : r5.val[13]! = out.val[13]! := by + rw [h_r5_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 13 (by decide) (by decide) (by decide)] + have h_fe := h_r6_fe_o + simp only [ + show (2 * (6#usize : Std.Usize).val : Nat) = 12 from by decide] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[13]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[13]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[13]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[12]!) + ((lift_chunk_mont rhs).val[13]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[13]!) + ((lift_chunk_mont rhs).val[12]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_12 : (lift_chunk_mont lhs).val[12]! + = lift_fe_mont (lhs.elements.val[12]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[12]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 12 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 12 (by rw [h_l]; decide)] + have h_lcm_lhs_13 : (lift_chunk_mont lhs).val[13]! + = lift_fe_mont (lhs.elements.val[13]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[13]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 13 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 13 (by rw [h_l]; decide)] + have h_lcm_rhs_12 : (lift_chunk_mont rhs).val[12]! + = lift_fe_mont (rhs.elements.val[12]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[12]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 12 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 12 (by rw [h_l]; decide)] + have h_lcm_rhs_13 : (lift_chunk_mont rhs).val[13]! + = lift_fe_mont (rhs.elements.val[13]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[13]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 13 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 13 (by rw [h_l]; decide)] + rw [h_lcm_lhs_12, h_lcm_lhs_13, h_lcm_rhs_12, h_lcm_rhs_13] + · -- Lane 14: touched by call 7 (nzeta3, even). + have h_src_at_even : r6.val[14]! = out.val[14]! := by + rw [h_r6_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 14 (by decide) (by decide) (by decide)] + have h_src_at_odd : r6.val[15]! = out.val[15]! := by + rw [h_r6_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 15 (by decide) (by decide) (by decide)] + have h_fe := h_r7_fe_e + simp only [ + show (2 * (7#usize : Std.Usize).val : Nat) = 14 from by decide] at h_fe + rw [h_src_at_even] at h_fe + rw [h_n3_fe] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[14]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[14]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[14]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[14]!) + ((lift_chunk_mont rhs).val[14]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[15]!) + ((lift_chunk_mont rhs).val[15]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta3))) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_14 : (lift_chunk_mont lhs).val[14]! + = lift_fe_mont (lhs.elements.val[14]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[14]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 14 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 14 (by rw [h_l]; decide)] + have h_lcm_lhs_15 : (lift_chunk_mont lhs).val[15]! + = lift_fe_mont (lhs.elements.val[15]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[15]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 15 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 15 (by rw [h_l]; decide)] + have h_lcm_rhs_14 : (lift_chunk_mont rhs).val[14]! + = lift_fe_mont (rhs.elements.val[14]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[14]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 14 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 14 (by rw [h_l]; decide)] + have h_lcm_rhs_15 : (lift_chunk_mont rhs).val[15]! + = lift_fe_mont (rhs.elements.val[15]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[15]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 15 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 15 (by rw [h_l]; decide)] + rw [h_lcm_lhs_14, h_lcm_lhs_15, h_lcm_rhs_14, h_lcm_rhs_15] + · -- Lane 15: touched by call 7 (nzeta3, odd). + have h_src_at_even : r6.val[14]! = out.val[14]! := by + rw [h_r6_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 14 (by decide) (by decide) (by decide)] + have h_src_at_odd : r6.val[15]! = out.val[15]! := by + rw [h_r6_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 15 (by decide) (by decide) (by decide)] + have h_fe := h_r7_fe_o + simp only [ + show (2 * (7#usize : Std.Usize).val : Nat) = 14 from by decide] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[15]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[15]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[15]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[14]!) + ((lift_chunk_mont rhs).val[15]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[15]!) + ((lift_chunk_mont rhs).val[14]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_14 : (lift_chunk_mont lhs).val[14]! + = lift_fe_mont (lhs.elements.val[14]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[14]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 14 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 14 (by rw [h_l]; decide)] + have h_lcm_lhs_15 : (lift_chunk_mont lhs).val[15]! + = lift_fe_mont (lhs.elements.val[15]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[15]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 15 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 15 (by rw [h_l]; decide)] + have h_lcm_rhs_14 : (lift_chunk_mont rhs).val[14]! + = lift_fe_mont (rhs.elements.val[14]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[14]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 14 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 14 (by rw [h_l]; decide)] + have h_lcm_rhs_15 : (lift_chunk_mont rhs).val[15]! + = lift_fe_mont (rhs.elements.val[15]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[15]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 15 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 15 (by rw [h_l]; decide)] + rw [h_lcm_lhs_14, h_lcm_lhs_15, h_lcm_rhs_14, h_lcm_rhs_15] + · -- Spec.ntt_multiply_cache_post — cache POST: 8 conjuncts (one per pair). + intro i + -- For each pair index i, the cache7 lane at slot i.val equals cache_i lane at slot i.val + -- (since calls i+1..7 write different slots). Conclude via h_c{j}_unc' chain + the + -- per-pair canonicity + FE-equation conjuncts h_c{j}_canon, h_c{j}_fe. + rcases i with ⟨i, hi⟩ + -- Normalise (J#usize).val = J in the per-pair canonicity/FE hypotheses. + have h_uv0 : (0#usize : Std.Usize).val = 0 := rfl + have h_uv1 : (1#usize : Std.Usize).val = 1 := rfl + have h_uv2 : (2#usize : Std.Usize).val = 2 := rfl + have h_uv3 : (3#usize : Std.Usize).val = 3 := rfl + have h_uv4 : (4#usize : Std.Usize).val = 4 := rfl + have h_uv5 : (5#usize : Std.Usize).val = 5 := rfl + have h_uv6 : (6#usize : Std.Usize).val = 6 := rfl + have h_uv7 : (7#usize : Std.Usize).val = 7 := rfl + -- Index-arithmetic normalisations on (2 * (J#usize).val + 1). + interval_cases i + · -- Pair 0: cache7[0] = cache0[0] (calls 1..7 don't touch slot 0). + have h_chain : cache7.elements.val[0]! = cache0.elements.val[0]! := by + rw [h_c7_unc' 0 (by decide) (by decide)] + rw [h_c6_unc' 0 (by decide) (by decide)] + rw [h_c5_unc' 0 (by decide) (by decide)] + rw [h_c4_unc' 0 (by decide) (by decide)] + rw [h_c3_unc' 0 (by decide) (by decide)] + rw [h_c2_unc' 0 (by decide) (by decide)] + rw [h_c1_unc' 0 (by decide) (by decide)] + refine ⟨?_, ?_⟩ + · -- canonical: cache0[0] ≤ 3328. + rw [show ((⟨0, hi⟩ : Fin 8) : Fin 8).val = 0 from rfl, h_chain] + rw [h_uv0] at h_c0_canon; exact h_c0_canon + · -- FE eq: lift_fe_mont cache7[0] = mul_pure (lift_fe_mont rhs[1]) (zeta0_fe). + rw [show ((⟨0, hi⟩ : Fin 8) : Fin 8).val = 0 from rfl, h_chain] + rw [h_uv0] at h_c0_fe; rw [h_c0_fe] + -- effective_zeta_fe ⟨0, _⟩ ... = zeta0_fe. + show _ = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[2 * 0 + 1]!)) + (Spec.effective_zeta_fe ⟨0, hi⟩ + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)) + unfold Spec.effective_zeta_fe; simp + · -- Pair 1: cache7[1] = cache1[1]. + have h_chain : cache7.elements.val[1]! = cache1.elements.val[1]! := by + rw [h_c7_unc' 1 (by decide) (by decide)] + rw [h_c6_unc' 1 (by decide) (by decide)] + rw [h_c5_unc' 1 (by decide) (by decide)] + rw [h_c4_unc' 1 (by decide) (by decide)] + rw [h_c3_unc' 1 (by decide) (by decide)] + rw [h_c2_unc' 1 (by decide) (by decide)] + refine ⟨?_, ?_⟩ + · rw [show ((⟨1, hi⟩ : Fin 8) : Fin 8).val = 1 from rfl, h_chain] + rw [h_uv1] at h_c1_canon; exact h_c1_canon + · rw [show ((⟨1, hi⟩ : Fin 8) : Fin 8).val = 1 from rfl, h_chain] + rw [h_uv1] at h_c1_fe; rw [h_c1_fe, h_n0_fe] + show _ = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[2 * 1 + 1]!)) + (Spec.effective_zeta_fe ⟨1, hi⟩ + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)) + unfold Spec.effective_zeta_fe; simp + · -- Pair 2: cache7[2] = cache2[2]. + have h_chain : cache7.elements.val[2]! = cache2.elements.val[2]! := by + rw [h_c7_unc' 2 (by decide) (by decide)] + rw [h_c6_unc' 2 (by decide) (by decide)] + rw [h_c5_unc' 2 (by decide) (by decide)] + rw [h_c4_unc' 2 (by decide) (by decide)] + rw [h_c3_unc' 2 (by decide) (by decide)] + refine ⟨?_, ?_⟩ + · rw [show ((⟨2, hi⟩ : Fin 8) : Fin 8).val = 2 from rfl, h_chain] + rw [h_uv2] at h_c2_canon; exact h_c2_canon + · rw [show ((⟨2, hi⟩ : Fin 8) : Fin 8).val = 2 from rfl, h_chain] + rw [h_uv2] at h_c2_fe; rw [h_c2_fe] + show _ = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[2 * 2 + 1]!)) + (Spec.effective_zeta_fe ⟨2, hi⟩ + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)) + unfold Spec.effective_zeta_fe; simp + · -- Pair 3: cache7[3] = cache3[3]. + have h_chain : cache7.elements.val[3]! = cache3.elements.val[3]! := by + rw [h_c7_unc' 3 (by decide) (by decide)] + rw [h_c6_unc' 3 (by decide) (by decide)] + rw [h_c5_unc' 3 (by decide) (by decide)] + rw [h_c4_unc' 3 (by decide) (by decide)] + refine ⟨?_, ?_⟩ + · rw [show ((⟨3, hi⟩ : Fin 8) : Fin 8).val = 3 from rfl, h_chain] + rw [h_uv3] at h_c3_canon; exact h_c3_canon + · rw [show ((⟨3, hi⟩ : Fin 8) : Fin 8).val = 3 from rfl, h_chain] + rw [h_uv3] at h_c3_fe; rw [h_c3_fe, h_n1_fe] + show _ = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[2 * 3 + 1]!)) + (Spec.effective_zeta_fe ⟨3, hi⟩ + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)) + unfold Spec.effective_zeta_fe; simp + · -- Pair 4: cache7[4] = cache4[4]. + have h_chain : cache7.elements.val[4]! = cache4.elements.val[4]! := by + rw [h_c7_unc' 4 (by decide) (by decide)] + rw [h_c6_unc' 4 (by decide) (by decide)] + rw [h_c5_unc' 4 (by decide) (by decide)] + refine ⟨?_, ?_⟩ + · rw [show ((⟨4, hi⟩ : Fin 8) : Fin 8).val = 4 from rfl, h_chain] + rw [h_uv4] at h_c4_canon; exact h_c4_canon + · rw [show ((⟨4, hi⟩ : Fin 8) : Fin 8).val = 4 from rfl, h_chain] + rw [h_uv4] at h_c4_fe; rw [h_c4_fe] + show _ = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[2 * 4 + 1]!)) + (Spec.effective_zeta_fe ⟨4, hi⟩ + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)) + unfold Spec.effective_zeta_fe; simp + · -- Pair 5: cache7[5] = cache5[5]. + have h_chain : cache7.elements.val[5]! = cache5.elements.val[5]! := by + rw [h_c7_unc' 5 (by decide) (by decide)] + rw [h_c6_unc' 5 (by decide) (by decide)] + refine ⟨?_, ?_⟩ + · rw [show ((⟨5, hi⟩ : Fin 8) : Fin 8).val = 5 from rfl, h_chain] + rw [h_uv5] at h_c5_canon; exact h_c5_canon + · rw [show ((⟨5, hi⟩ : Fin 8) : Fin 8).val = 5 from rfl, h_chain] + rw [h_uv5] at h_c5_fe; rw [h_c5_fe, h_n2_fe] + show _ = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[2 * 5 + 1]!)) + (Spec.effective_zeta_fe ⟨5, hi⟩ + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)) + unfold Spec.effective_zeta_fe; simp + · -- Pair 6: cache7[6] = cache6[6]. + have h_chain : cache7.elements.val[6]! = cache6.elements.val[6]! := by + rw [h_c7_unc' 6 (by decide) (by decide)] + refine ⟨?_, ?_⟩ + · rw [show ((⟨6, hi⟩ : Fin 8) : Fin 8).val = 6 from rfl, h_chain] + rw [h_uv6] at h_c6_canon; exact h_c6_canon + · rw [show ((⟨6, hi⟩ : Fin 8) : Fin 8).val = 6 from rfl, h_chain] + rw [h_uv6] at h_c6_fe; rw [h_c6_fe] + show _ = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[2 * 6 + 1]!)) + (Spec.effective_zeta_fe ⟨6, hi⟩ + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)) + unfold Spec.effective_zeta_fe; simp + · -- Pair 7: cache7[7] (last call wrote slot 7). + refine ⟨?_, ?_⟩ + · rw [show ((⟨7, hi⟩ : Fin 8) : Fin 8).val = 7 from rfl] + rw [h_uv7] at h_c7_canon; exact h_c7_canon + · rw [show ((⟨7, hi⟩ : Fin 8) : Fin 8).val = 7 from rfl] + rw [h_uv7] at h_c7_fe; rw [h_c7_fe, h_n3_fe] + show _ = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[2 * 7 + 1]!)) + (Spec.effective_zeta_fe ⟨7, hi⟩ + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)) + unfold Spec.effective_zeta_fe; simp + · -- Lanes 8..15: untouched. + intro k hk h8 + -- cache7[k] = cache6[k] = cache5[k] = cache4[k] = cache3[k] = cache2[k] = cache1[k] = cache0[k] = cache[k] + -- since each cacheJ.unc' k holds whenever k ≠ J. + have hk_ne_7 : k ≠ 7 := by omega + have hk_ne_6 : k ≠ 6 := by omega + have hk_ne_5 : k ≠ 5 := by omega + have hk_ne_4 : k ≠ 4 := by omega + have hk_ne_3 : k ≠ 3 := by omega + have hk_ne_2 : k ≠ 2 := by omega + have hk_ne_1 : k ≠ 1 := by omega + have hk_ne_0 : k ≠ 0 := by omega + rw [h_c7_unc' k hk hk_ne_7] + rw [h_c6_unc' k hk hk_ne_6] + rw [h_c5_unc' k hk hk_ne_5] + rw [h_c4_unc' k hk hk_ne_4] + rw [h_c3_unc' k hk hk_ne_3] + rw [h_c2_unc' k hk hk_ne_2] + rw [h_c1_unc' k hk hk_ne_1] + rw [h_c0_unc' k hk hk_ne_0] + +set_option maxHeartbeats 16000000 in +/-- L2.8d — `vector.portable.ntt.accumulating_ntt_multiply_use_cache`: + cache-using variant. The impl chains 8 + `accumulating_ntt_multiply_binomials_use_cache` calls; each reads + `cache[i]` instead of recomputing `mont_reduce(b[2i+1]·zeta_i)`. + + POST identical to `accumulating_ntt_multiply_fc` (the four zetas + are ghost arguments — they appear only in the cache PRE-condition + and the POST's `ntt_multiply_base_case_post`, NOT in the impl call). + Compose with `accumulating_ntt_multiply_fill_cache_fc`: discharge + the cache PRE from a prior `_fill_cache` POST's cache conjunct. -/ +@[spec] +theorem accumulating_ntt_multiply_use_cache_fc + (lhs rhs : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (out : Aeneas.Std.Slice Std.I32) + (cache : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta0 zeta1 zeta2 zeta3 : Std.I16) + (h_out_len : out.length = 16) + (h_lhs : ∀ j : Fin 16, (lhs.elements.val[j.val]!).val.natAbs ≤ 3328) + (h_rhs : ∀ j : Fin 16, (rhs.elements.val[j.val]!).val.natAbs ≤ 3328) + (h_zeta0 : zeta0.val.natAbs ≤ 1664) + (h_zeta1 : zeta1.val.natAbs ≤ 1664) + (h_zeta2 : zeta2.val.natAbs ≤ 1664) + (h_zeta3 : zeta3.val.natAbs ≤ 1664) + (h_out_bnd : ∀ k : Fin 16, (out.val[k.val]!).val.natAbs ≤ 2^30) + (h_cache : Spec.ntt_multiply_cache_post rhs zeta0 zeta1 zeta2 zeta3 cache) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_use_cache + lhs rhs out cache + ⦃ ⇓ r => ⌜ r.length = 16 ∧ + (∀ k : Fin 16, (r.val[k.val]!).val.natAbs + ≤ (out.val[k.val]!).val.natAbs + 2^25) ∧ + ntt_multiply_base_case_post lhs rhs + zeta0 zeta1 zeta2 zeta3 out r ⌝ ⦄ := by + have h_cache_canon : ∀ i : Fin 8, + (cache.elements.val[i.val]!).val.natAbs ≤ 3328 := fun i => (h_cache i).1 + have h_cache_fe : ∀ i : Fin 8, + lift_fe_mont (cache.elements.val[i.val]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[2 * i.val + 1]!)) + (Spec.effective_zeta_fe i + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)) := fun i => (h_cache i).2 + have h_out_bnd_universal : ∀ k : Fin 16, (out.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := by + intro k; have := h_out_bnd k; omega + -- Cache FE-equations specialised at each pair index 0..7 (Spec.effective_zeta_fe + -- collapses to the appropriate zeta_j or neg_pure zeta_j). + have h_cache0_fe : lift_fe_mont (cache.elements.val[0]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[1]!)) (lift_fe_mont zeta0) := by + have h := h_cache_fe ⟨0, by decide⟩ + rw [show ((⟨0, by decide⟩ : Fin 8) : Fin 8).val = 0 from rfl] at h + rw [h]; unfold Spec.effective_zeta_fe; simp + have h_cache1_fe : lift_fe_mont (cache.elements.val[1]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[3]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta0)) := by + have h := h_cache_fe ⟨1, by decide⟩ + rw [show ((⟨1, by decide⟩ : Fin 8) : Fin 8).val = 1 from rfl] at h + rw [h]; unfold Spec.effective_zeta_fe; simp + have h_cache2_fe : lift_fe_mont (cache.elements.val[2]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[5]!)) (lift_fe_mont zeta1) := by + have h := h_cache_fe ⟨2, by decide⟩ + rw [show ((⟨2, by decide⟩ : Fin 8) : Fin 8).val = 2 from rfl] at h + rw [h]; unfold Spec.effective_zeta_fe; simp + have h_cache3_fe : lift_fe_mont (cache.elements.val[3]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[7]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta1)) := by + have h := h_cache_fe ⟨3, by decide⟩ + rw [show ((⟨3, by decide⟩ : Fin 8) : Fin 8).val = 3 from rfl] at h + rw [h]; unfold Spec.effective_zeta_fe; simp + have h_cache4_fe : lift_fe_mont (cache.elements.val[4]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[9]!)) (lift_fe_mont zeta2) := by + have h := h_cache_fe ⟨4, by decide⟩ + rw [show ((⟨4, by decide⟩ : Fin 8) : Fin 8).val = 4 from rfl] at h + rw [h]; unfold Spec.effective_zeta_fe; simp + have h_cache5_fe : lift_fe_mont (cache.elements.val[5]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[11]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta2)) := by + have h := h_cache_fe ⟨5, by decide⟩ + rw [show ((⟨5, by decide⟩ : Fin 8) : Fin 8).val = 5 from rfl] at h + rw [h]; unfold Spec.effective_zeta_fe; simp + have h_cache6_fe : lift_fe_mont (cache.elements.val[6]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[13]!)) (lift_fe_mont zeta3) := by + have h := h_cache_fe ⟨6, by decide⟩ + rw [show ((⟨6, by decide⟩ : Fin 8) : Fin 8).val = 6 from rfl] at h + rw [h]; unfold Spec.effective_zeta_fe; simp + have h_cache7_fe : lift_fe_mont (cache.elements.val[7]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont (rhs.elements.val[15]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta3)) := by + have h := h_cache_fe ⟨7, by decide⟩ + rw [show ((⟨7, by decide⟩ : Fin 8) : Fin 8).val = 7 from rfl] at h + rw [h]; unfold Spec.effective_zeta_fe; simp + -- ===== 8 chained calls ===== + -- Call 0: + obtain ⟨r0, h_r0_eq, h_r0_len, h_r0_unc, h_r0_bnd_e, h_r0_bnd_o, h_r0_fe_e, h_r0_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_use_cache_fc lhs rhs 0#usize out cache + (by decide) h_out_len h_lhs h_rhs (h_cache_canon ⟨0, by decide⟩) + h_out_bnd_universal) + have h_r0_at_even : (r0.val[0]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (0#usize : Std.Usize).val : Nat) = 0 := by decide + have h_b := h_r0_bnd_e; rw [h_eq] at h_b + have h_out_le := h_out_bnd ⟨0, by decide⟩; simp only at h_out_le; omega + have h_r0_at_odd : (r0.val[1]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (0#usize : Std.Usize).val + 1 : Nat) = 1 := by decide + have h_b := h_r0_bnd_o; rw [h_eq] at h_b + have h_out_le := h_out_bnd ⟨1, by decide⟩; simp only at h_out_le; omega + have h_r0_unc' : ∀ k : Nat, k < 16 → k ≠ 0 → k ≠ 1 → + r0.val[k]! = out.val[k]! := by + intro k hk hke hko + apply h_r0_unc k hk + · show k ≠ 2 * (0#usize : Std.Usize).val; rw [show (0#usize : Std.Usize).val = 0 from rfl]; omega + · show k ≠ 2 * (0#usize : Std.Usize).val + 1; rw [show (0#usize : Std.Usize).val = 0 from rfl]; omega + have h_r0_bnd_universal : ∀ k : Fin 16, (r0.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step out r0 0 (by decide) h_out_bnd_universal + h_r0_unc' h_r0_at_even h_r0_at_odd + -- Call 1: + obtain ⟨r1, h_r1_eq, h_r1_len, h_r1_unc, h_r1_bnd_e, h_r1_bnd_o, h_r1_fe_e, h_r1_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_use_cache_fc lhs rhs 1#usize r0 cache + (by decide) h_r0_len h_lhs h_rhs (h_cache_canon ⟨1, by decide⟩) + h_r0_bnd_universal) + have h_r1_at_even : (r1.val[2]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (1#usize : Std.Usize).val : Nat) = 2 := by decide + have h_b := h_r1_bnd_e; rw [h_eq] at h_b + have h_r0_eq2 : r0.val[2]! = out.val[2]! := h_r0_unc' 2 (by decide) (by decide) (by decide) + rw [h_r0_eq2] at h_b + have h_out_le := h_out_bnd ⟨2, by decide⟩; simp only at h_out_le; omega + have h_r1_at_odd : (r1.val[3]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (1#usize : Std.Usize).val + 1 : Nat) = 3 := by decide + have h_b := h_r1_bnd_o; rw [h_eq] at h_b + have h_r0_eq3 : r0.val[3]! = out.val[3]! := h_r0_unc' 3 (by decide) (by decide) (by decide) + rw [h_r0_eq3] at h_b + have h_out_le := h_out_bnd ⟨3, by decide⟩; simp only at h_out_le; omega + have h_r1_unc' : ∀ k : Nat, k < 16 → k ≠ 2 → k ≠ 3 → + r1.val[k]! = r0.val[k]! := by + intro k hk hke hko + apply h_r1_unc k hk + · show k ≠ 2 * (1#usize : Std.Usize).val; rw [show (1#usize : Std.Usize).val = 1 from rfl]; omega + · show k ≠ 2 * (1#usize : Std.Usize).val + 1; rw [show (1#usize : Std.Usize).val = 1 from rfl]; omega + have h_r1_bnd_universal : ∀ k : Fin 16, (r1.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r0 r1 1 (by decide) h_r0_bnd_universal + h_r1_unc' h_r1_at_even h_r1_at_odd + -- Call 2: + obtain ⟨r2, h_r2_eq, h_r2_len, h_r2_unc, h_r2_bnd_e, h_r2_bnd_o, h_r2_fe_e, h_r2_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_use_cache_fc lhs rhs 2#usize r1 cache + (by decide) h_r1_len h_lhs h_rhs (h_cache_canon ⟨2, by decide⟩) + h_r1_bnd_universal) + have h_r2_at_even : (r2.val[4]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (2#usize : Std.Usize).val : Nat) = 4 := by decide + have h_b := h_r2_bnd_e; rw [h_eq] at h_b + have h_r1_eq4 : r1.val[4]! = out.val[4]! := by + rw [h_r1_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r1_eq4] at h_b + have h_out_le := h_out_bnd ⟨4, by decide⟩; simp only at h_out_le; omega + have h_r2_at_odd : (r2.val[5]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (2#usize : Std.Usize).val + 1 : Nat) = 5 := by decide + have h_b := h_r2_bnd_o; rw [h_eq] at h_b + have h_r1_eq5 : r1.val[5]! = out.val[5]! := by + rw [h_r1_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r1_eq5] at h_b + have h_out_le := h_out_bnd ⟨5, by decide⟩; simp only at h_out_le; omega + have h_r2_unc' : ∀ k : Nat, k < 16 → k ≠ 4 → k ≠ 5 → + r2.val[k]! = r1.val[k]! := by + intro k hk hke hko + apply h_r2_unc k hk + · show k ≠ 2 * (2#usize : Std.Usize).val; rw [show (2#usize : Std.Usize).val = 2 from rfl]; omega + · show k ≠ 2 * (2#usize : Std.Usize).val + 1; rw [show (2#usize : Std.Usize).val = 2 from rfl]; omega + have h_r2_bnd_universal : ∀ k : Fin 16, (r2.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r1 r2 2 (by decide) h_r1_bnd_universal + h_r2_unc' h_r2_at_even h_r2_at_odd + -- Call 3: + obtain ⟨r3, h_r3_eq, h_r3_len, h_r3_unc, h_r3_bnd_e, h_r3_bnd_o, h_r3_fe_e, h_r3_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_use_cache_fc lhs rhs 3#usize r2 cache + (by decide) h_r2_len h_lhs h_rhs (h_cache_canon ⟨3, by decide⟩) + h_r2_bnd_universal) + have h_r3_at_even : (r3.val[6]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (3#usize : Std.Usize).val : Nat) = 6 := by decide + have h_b := h_r3_bnd_e; rw [h_eq] at h_b + have h_r2_eq6 : r2.val[6]! = out.val[6]! := by + rw [h_r2_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r2_eq6] at h_b + have h_out_le := h_out_bnd ⟨6, by decide⟩; simp only at h_out_le; omega + have h_r3_at_odd : (r3.val[7]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (3#usize : Std.Usize).val + 1 : Nat) = 7 := by decide + have h_b := h_r3_bnd_o; rw [h_eq] at h_b + have h_r2_eq7 : r2.val[7]! = out.val[7]! := by + rw [h_r2_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r2_eq7] at h_b + have h_out_le := h_out_bnd ⟨7, by decide⟩; simp only at h_out_le; omega + have h_r3_unc' : ∀ k : Nat, k < 16 → k ≠ 6 → k ≠ 7 → + r3.val[k]! = r2.val[k]! := by + intro k hk hke hko + apply h_r3_unc k hk + · show k ≠ 2 * (3#usize : Std.Usize).val; rw [show (3#usize : Std.Usize).val = 3 from rfl]; omega + · show k ≠ 2 * (3#usize : Std.Usize).val + 1; rw [show (3#usize : Std.Usize).val = 3 from rfl]; omega + have h_r3_bnd_universal : ∀ k : Fin 16, (r3.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r2 r3 3 (by decide) h_r2_bnd_universal + h_r3_unc' h_r3_at_even h_r3_at_odd + -- Call 4: + obtain ⟨r4, h_r4_eq, h_r4_len, h_r4_unc, h_r4_bnd_e, h_r4_bnd_o, h_r4_fe_e, h_r4_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_use_cache_fc lhs rhs 4#usize r3 cache + (by decide) h_r3_len h_lhs h_rhs (h_cache_canon ⟨4, by decide⟩) + h_r3_bnd_universal) + have h_r4_at_even : (r4.val[8]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (4#usize : Std.Usize).val : Nat) = 8 := by decide + have h_b := h_r4_bnd_e; rw [h_eq] at h_b + have h_r3_eq8 : r3.val[8]! = out.val[8]! := by + rw [h_r3_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r3_eq8] at h_b + have h_out_le := h_out_bnd ⟨8, by decide⟩; simp only at h_out_le; omega + have h_r4_at_odd : (r4.val[9]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (4#usize : Std.Usize).val + 1 : Nat) = 9 := by decide + have h_b := h_r4_bnd_o; rw [h_eq] at h_b + have h_r3_eq9 : r3.val[9]! = out.val[9]! := by + rw [h_r3_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r3_eq9] at h_b + have h_out_le := h_out_bnd ⟨9, by decide⟩; simp only at h_out_le; omega + have h_r4_unc' : ∀ k : Nat, k < 16 → k ≠ 8 → k ≠ 9 → + r4.val[k]! = r3.val[k]! := by + intro k hk hke hko + apply h_r4_unc k hk + · show k ≠ 2 * (4#usize : Std.Usize).val; rw [show (4#usize : Std.Usize).val = 4 from rfl]; omega + · show k ≠ 2 * (4#usize : Std.Usize).val + 1; rw [show (4#usize : Std.Usize).val = 4 from rfl]; omega + have h_r4_bnd_universal : ∀ k : Fin 16, (r4.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r3 r4 4 (by decide) h_r3_bnd_universal + h_r4_unc' h_r4_at_even h_r4_at_odd + -- Call 5: + obtain ⟨r5, h_r5_eq, h_r5_len, h_r5_unc, h_r5_bnd_e, h_r5_bnd_o, h_r5_fe_e, h_r5_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_use_cache_fc lhs rhs 5#usize r4 cache + (by decide) h_r4_len h_lhs h_rhs (h_cache_canon ⟨5, by decide⟩) + h_r4_bnd_universal) + have h_r5_at_even : (r5.val[10]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (5#usize : Std.Usize).val : Nat) = 10 := by decide + have h_b := h_r5_bnd_e; rw [h_eq] at h_b + have h_r4_eq10 : r4.val[10]! = out.val[10]! := by + rw [h_r4_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r4_eq10] at h_b + have h_out_le := h_out_bnd ⟨10, by decide⟩; simp only at h_out_le; omega + have h_r5_at_odd : (r5.val[11]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (5#usize : Std.Usize).val + 1 : Nat) = 11 := by decide + have h_b := h_r5_bnd_o; rw [h_eq] at h_b + have h_r4_eq11 : r4.val[11]! = out.val[11]! := by + rw [h_r4_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r4_eq11] at h_b + have h_out_le := h_out_bnd ⟨11, by decide⟩; simp only at h_out_le; omega + have h_r5_unc' : ∀ k : Nat, k < 16 → k ≠ 10 → k ≠ 11 → + r5.val[k]! = r4.val[k]! := by + intro k hk hke hko + apply h_r5_unc k hk + · show k ≠ 2 * (5#usize : Std.Usize).val; rw [show (5#usize : Std.Usize).val = 5 from rfl]; omega + · show k ≠ 2 * (5#usize : Std.Usize).val + 1; rw [show (5#usize : Std.Usize).val = 5 from rfl]; omega + have h_r5_bnd_universal : ∀ k : Fin 16, (r5.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r4 r5 5 (by decide) h_r4_bnd_universal + h_r5_unc' h_r5_at_even h_r5_at_odd + -- Call 6: + obtain ⟨r6, h_r6_eq, h_r6_len, h_r6_unc, h_r6_bnd_e, h_r6_bnd_o, h_r6_fe_e, h_r6_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_use_cache_fc lhs rhs 6#usize r5 cache + (by decide) h_r5_len h_lhs h_rhs (h_cache_canon ⟨6, by decide⟩) + h_r5_bnd_universal) + have h_r6_at_even : (r6.val[12]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (6#usize : Std.Usize).val : Nat) = 12 := by decide + have h_b := h_r6_bnd_e; rw [h_eq] at h_b + have h_r5_eq12 : r5.val[12]! = out.val[12]! := by + rw [h_r5_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r5_eq12] at h_b + have h_out_le := h_out_bnd ⟨12, by decide⟩; simp only at h_out_le; omega + have h_r6_at_odd : (r6.val[13]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (6#usize : Std.Usize).val + 1 : Nat) = 13 := by decide + have h_b := h_r6_bnd_o; rw [h_eq] at h_b + have h_r5_eq13 : r5.val[13]! = out.val[13]! := by + rw [h_r5_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r5_eq13] at h_b + have h_out_le := h_out_bnd ⟨13, by decide⟩; simp only at h_out_le; omega + have h_r6_unc' : ∀ k : Nat, k < 16 → k ≠ 12 → k ≠ 13 → + r6.val[k]! = r5.val[k]! := by + intro k hk hke hko + apply h_r6_unc k hk + · show k ≠ 2 * (6#usize : Std.Usize).val; rw [show (6#usize : Std.Usize).val = 6 from rfl]; omega + · show k ≠ 2 * (6#usize : Std.Usize).val + 1; rw [show (6#usize : Std.Usize).val = 6 from rfl]; omega + have h_r6_bnd_universal : ∀ k : Fin 16, (r6.val[k.val]!).val.natAbs ≤ 2^30 + 2^25 := + L2_8c.bnd_universal_step r5 r6 6 (by decide) h_r5_bnd_universal + h_r6_unc' h_r6_at_even h_r6_at_odd + -- Call 7: + obtain ⟨r7, h_r7_eq, h_r7_len, h_r7_unc, h_r7_bnd_e, h_r7_bnd_o, h_r7_fe_e, h_r7_fe_o⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_binomials_use_cache_fc lhs rhs 7#usize r6 cache + (by decide) h_r6_len h_lhs h_rhs (h_cache_canon ⟨7, by decide⟩) + h_r6_bnd_universal) + have h_r7_at_even : (r7.val[14]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (7#usize : Std.Usize).val : Nat) = 14 := by decide + have h_b := h_r7_bnd_e; rw [h_eq] at h_b + have h_r6_eq14 : r6.val[14]! = out.val[14]! := by + rw [h_r6_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r6_eq14] at h_b + have h_out_le := h_out_bnd ⟨14, by decide⟩; simp only at h_out_le; omega + have h_r7_at_odd : (r7.val[15]!).val.natAbs ≤ 2^30 + 2^25 := by + have h_eq : (2 * (7#usize : Std.Usize).val + 1 : Nat) = 15 := by decide + have h_b := h_r7_bnd_o; rw [h_eq] at h_b + have h_r6_eq15 : r6.val[15]! = out.val[15]! := by + rw [h_r6_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r6_eq15] at h_b + have h_out_le := h_out_bnd ⟨15, by decide⟩; simp only at h_out_le; omega + have h_r7_unc' : ∀ k : Nat, k < 16 → k ≠ 14 → k ≠ 15 → + r7.val[k]! = r6.val[k]! := by + intro k hk hke hko + apply h_r7_unc k hk + · show k ≠ 2 * (7#usize : Std.Usize).val; rw [show (7#usize : Std.Usize).val = 7 from rfl]; omega + · show k ≠ 2 * (7#usize : Std.Usize).val + 1; rw [show (7#usize : Std.Usize).val = 7 from rfl]; omega + -- ===== Pre-rewrite each h_r{i}_fe_e to L2.8c form ===== + -- h_r{i}_fe_e (even) has the shape: + -- mont_reduce_pure (lift_fe_int r{i}[2j].val) = add_pure (mr prev[2j]) (add_pure + -- (mul_pure (lift lhs[2j]) (lift rhs[2j])) (mul_pure (lift lhs[2j+1]) c_m)) + -- where c_m = lift_fe_mont cache[j.val]. Using h_cache{j}_fe + mul_pure_assoc: + -- mul_pure (lift lhs[2j+1]) c_m = mul_pure (lift lhs[2j+1]) (mul_pure (lift rhs[2j+1]) zeta_eff) + -- = mul_pure (mul_pure (lift lhs[2j+1]) (lift rhs[2j+1])) zeta_eff. + -- That's the L2.8c shape. + have h_uv0 : (0#usize : Std.Usize).val = 0 := rfl + have h_uv1 : (1#usize : Std.Usize).val = 1 := rfl + have h_uv2 : (2#usize : Std.Usize).val = 2 := rfl + have h_uv3 : (3#usize : Std.Usize).val = 3 := rfl + have h_uv4 : (4#usize : Std.Usize).val = 4 := rfl + have h_uv5 : (5#usize : Std.Usize).val = 5 := rfl + have h_uv6 : (6#usize : Std.Usize).val = 6 := rfl + have h_uv7 : (7#usize : Std.Usize).val = 7 := rfl + rw [h_uv0] at h_r0_fe_e h_r0_fe_o + rw [h_uv1] at h_r1_fe_e h_r1_fe_o + rw [h_uv2] at h_r2_fe_e h_r2_fe_o + rw [h_uv3] at h_r3_fe_e h_r3_fe_o + rw [h_uv4] at h_r4_fe_e h_r4_fe_o + rw [h_uv5] at h_r5_fe_e h_r5_fe_o + rw [h_uv6] at h_r6_fe_e h_r6_fe_o + rw [h_uv7] at h_r7_fe_e h_r7_fe_o + rw [h_cache0_fe, L2_8d.mul_pure_assoc] at h_r0_fe_e + rw [h_cache1_fe, L2_8d.mul_pure_assoc] at h_r1_fe_e + rw [h_cache2_fe, L2_8d.mul_pure_assoc] at h_r2_fe_e + rw [h_cache3_fe, L2_8d.mul_pure_assoc] at h_r3_fe_e + rw [h_cache4_fe, L2_8d.mul_pure_assoc] at h_r4_fe_e + rw [h_cache5_fe, L2_8d.mul_pure_assoc] at h_r5_fe_e + rw [h_cache6_fe, L2_8d.mul_pure_assoc] at h_r6_fe_e + rw [h_cache7_fe, L2_8d.mul_pure_assoc] at h_r7_fe_e + -- After rewrite: h_r{i}_fe_e matches L2.8c per-pair FE shape with effective zeta. + -- Now h_n0_fe..h_n3_fe are unbound. Set their identity by way of synthesis from + -- effective_zeta_fe. + have h_n0_fe : + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta0)) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta0) := rfl + have h_n1_fe : + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta1)) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta1) := rfl + have h_n2_fe : + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta2)) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta2) := rfl + have h_n3_fe : + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta3)) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta3) := rfl + -- Compose monadic body. + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_use_cache + lhs rhs out cache = .ok r7 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_use_cache + simp only [h_r0_eq, h_r1_eq, h_r2_eq, h_r3_eq, + h_r4_eq, h_r5_eq, h_r6_eq, h_r7_eq, + Aeneas.Std.bind_tc_ok] + apply triple_of_ok_fc h_body + -- POST: 3-conjunct. + refine ⟨h_r7_len, ?_, ?_⟩ + · -- Relative bound (same 16-way enumeration as L2.8c). + intro k + rcases k with ⟨k, hk⟩ + interval_cases k + · have h_r7_at_0 : r7.val[0]! = r0.val[0]! := by + rw [h_r7_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r7_at_0] + have h_b := h_r0_bnd_e + rw [show (2 * (0#usize : Std.Usize).val : Nat) = 0 from by decide] at h_b + exact h_b + · have h_r7_at_1 : r7.val[1]! = r0.val[1]! := by + rw [h_r7_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r7_at_1] + have h_b := h_r0_bnd_o + rw [show (2 * (0#usize : Std.Usize).val + 1 : Nat) = 1 from by decide] at h_b + exact h_b + · have h_r7_at_2 : r7.val[2]! = r1.val[2]! := by + rw [h_r7_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r7_at_2] + have h_b := h_r1_bnd_e + rw [show (2 * (1#usize : Std.Usize).val : Nat) = 2 from by decide] at h_b + have h_r0_at_2 : r0.val[2]! = out.val[2]! := h_r0_unc' 2 (by decide) (by decide) (by decide) + rw [h_r0_at_2] at h_b + exact h_b + · have h_r7_at_3 : r7.val[3]! = r1.val[3]! := by + rw [h_r7_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r7_at_3] + have h_b := h_r1_bnd_o + rw [show (2 * (1#usize : Std.Usize).val + 1 : Nat) = 3 from by decide] at h_b + have h_r0_at_3 : r0.val[3]! = out.val[3]! := h_r0_unc' 3 (by decide) (by decide) (by decide) + rw [h_r0_at_3] at h_b + exact h_b + · have h_r7_at_4 : r7.val[4]! = r2.val[4]! := by + rw [h_r7_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r7_at_4] + have h_b := h_r2_bnd_e + rw [show (2 * (2#usize : Std.Usize).val : Nat) = 4 from by decide] at h_b + have h_r1_at_4 : r1.val[4]! = out.val[4]! := by + rw [h_r1_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r1_at_4] at h_b + exact h_b + · have h_r7_at_5 : r7.val[5]! = r2.val[5]! := by + rw [h_r7_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r7_at_5] + have h_b := h_r2_bnd_o + rw [show (2 * (2#usize : Std.Usize).val + 1 : Nat) = 5 from by decide] at h_b + have h_r1_at_5 : r1.val[5]! = out.val[5]! := by + rw [h_r1_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r1_at_5] at h_b + exact h_b + · have h_r7_at_6 : r7.val[6]! = r3.val[6]! := by + rw [h_r7_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r7_at_6] + have h_b := h_r3_bnd_e + rw [show (2 * (3#usize : Std.Usize).val : Nat) = 6 from by decide] at h_b + have h_r2_at_6 : r2.val[6]! = out.val[6]! := by + rw [h_r2_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r2_at_6] at h_b + exact h_b + · have h_r7_at_7 : r7.val[7]! = r3.val[7]! := by + rw [h_r7_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r7_at_7] + have h_b := h_r3_bnd_o + rw [show (2 * (3#usize : Std.Usize).val + 1 : Nat) = 7 from by decide] at h_b + have h_r2_at_7 : r2.val[7]! = out.val[7]! := by + rw [h_r2_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r2_at_7] at h_b + exact h_b + · have h_r7_at_8 : r7.val[8]! = r4.val[8]! := by + rw [h_r7_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r7_at_8] + have h_b := h_r4_bnd_e + rw [show (2 * (4#usize : Std.Usize).val : Nat) = 8 from by decide] at h_b + have h_r3_at_8 : r3.val[8]! = out.val[8]! := by + rw [h_r3_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r3_at_8] at h_b + exact h_b + · have h_r7_at_9 : r7.val[9]! = r4.val[9]! := by + rw [h_r7_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r7_at_9] + have h_b := h_r4_bnd_o + rw [show (2 * (4#usize : Std.Usize).val + 1 : Nat) = 9 from by decide] at h_b + have h_r3_at_9 : r3.val[9]! = out.val[9]! := by + rw [h_r3_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r3_at_9] at h_b + exact h_b + · have h_r7_at_10 : r7.val[10]! = r5.val[10]! := by + rw [h_r7_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r7_at_10] + have h_b := h_r5_bnd_e + rw [show (2 * (5#usize : Std.Usize).val : Nat) = 10 from by decide] at h_b + have h_r4_at_10 : r4.val[10]! = out.val[10]! := by + rw [h_r4_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r4_at_10] at h_b + exact h_b + · have h_r7_at_11 : r7.val[11]! = r5.val[11]! := by + rw [h_r7_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r7_at_11] + have h_b := h_r5_bnd_o + rw [show (2 * (5#usize : Std.Usize).val + 1 : Nat) = 11 from by decide] at h_b + have h_r4_at_11 : r4.val[11]! = out.val[11]! := by + rw [h_r4_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r4_at_11] at h_b + exact h_b + · have h_r7_at_12 : r7.val[12]! = r6.val[12]! := by + rw [h_r7_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r7_at_12] + have h_b := h_r6_bnd_e + rw [show (2 * (6#usize : Std.Usize).val : Nat) = 12 from by decide] at h_b + have h_r5_at_12 : r5.val[12]! = out.val[12]! := by + rw [h_r5_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r5_at_12] at h_b + exact h_b + · have h_r7_at_13 : r7.val[13]! = r6.val[13]! := by + rw [h_r7_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r7_at_13] + have h_b := h_r6_bnd_o + rw [show (2 * (6#usize : Std.Usize).val + 1 : Nat) = 13 from by decide] at h_b + have h_r5_at_13 : r5.val[13]! = out.val[13]! := by + rw [h_r5_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r5_at_13] at h_b + exact h_b + · have h_b := h_r7_bnd_e + rw [show (2 * (7#usize : Std.Usize).val : Nat) = 14 from by decide] at h_b + have h_r6_at_14 : r6.val[14]! = out.val[14]! := by + rw [h_r6_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r6_at_14] at h_b + exact h_b + · have h_b := h_r7_bnd_o + rw [show (2 * (7#usize : Std.Usize).val + 1 : Nat) = 15 from by decide] at h_b + have h_r6_at_15 : r6.val[15]! = out.val[15]! := by + rw [h_r6_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r6_at_15] at h_b + exact h_b + · -- ntt_multiply_base_case_post: per-lane FE equation. + unfold ntt_multiply_base_case_post ntt_multiply_base_case_alg + apply Subtype.ext + have h_lhs_val : (Spec.chunk_reducing_from_i32_array_pure r7).val + = (List.range 16).map (fun i => Spec.mont_reduce_pure (lift_fe_int (r7.val[i]!).val)) := by + unfold Spec.chunk_reducing_from_i32_array_pure; rfl + have h_rhs_val : (Spec.chunk_add_pure + (Spec.chunk_reducing_from_i32_array_pure out) + (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3))).val + = (List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Spec.chunk_reducing_from_i32_array_pure out).val[i]!) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[i]!)) := by + unfold Spec.chunk_add_pure; rfl + rw [h_lhs_val, h_rhs_val] + apply List.ext_getElem + · simp + · intro k hk1 hk2 + have hk : k < 16 := by simp at hk1; exact hk1 + rw [List.getElem_map, List.getElem_map, List.getElem_range] + interval_cases k + · -- Lane 0: touched by call 0 (zeta0, even). + have h_r7_at_lane : r7.val[0]! = r0.val[0]! := by + rw [h_r7_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 0 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_fe := h_r0_fe_e + simp only [] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[0]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[0]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[0]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[0]!) + ((lift_chunk_mont rhs).val[0]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[1]!) + ((lift_chunk_mont rhs).val[1]!)) + (lift_fe_mont zeta0)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_0 : (lift_chunk_mont lhs).val[0]! + = lift_fe_mont (lhs.elements.val[0]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[0]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 0 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 0 (by rw [h_l]; decide)] + have h_lcm_lhs_1 : (lift_chunk_mont lhs).val[1]! + = lift_fe_mont (lhs.elements.val[1]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[1]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 1 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 1 (by rw [h_l]; decide)] + have h_lcm_rhs_0 : (lift_chunk_mont rhs).val[0]! + = lift_fe_mont (rhs.elements.val[0]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[0]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 0 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 0 (by rw [h_l]; decide)] + have h_lcm_rhs_1 : (lift_chunk_mont rhs).val[1]! + = lift_fe_mont (rhs.elements.val[1]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[1]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 1 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 1 (by rw [h_l]; decide)] + rw [h_lcm_lhs_0, h_lcm_lhs_1, h_lcm_rhs_0, h_lcm_rhs_1] + · -- Lane 1: touched by call 0 (zeta0, odd). + have h_r7_at_lane : r7.val[1]! = r0.val[1]! := by + rw [h_r7_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 1 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_fe := h_r0_fe_o + simp only [] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[1]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[1]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[1]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[0]!) + ((lift_chunk_mont rhs).val[1]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[1]!) + ((lift_chunk_mont rhs).val[0]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_0 : (lift_chunk_mont lhs).val[0]! + = lift_fe_mont (lhs.elements.val[0]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[0]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 0 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 0 (by rw [h_l]; decide)] + have h_lcm_lhs_1 : (lift_chunk_mont lhs).val[1]! + = lift_fe_mont (lhs.elements.val[1]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[1]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 1 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 1 (by rw [h_l]; decide)] + have h_lcm_rhs_0 : (lift_chunk_mont rhs).val[0]! + = lift_fe_mont (rhs.elements.val[0]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[0]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 0 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 0 (by rw [h_l]; decide)] + have h_lcm_rhs_1 : (lift_chunk_mont rhs).val[1]! + = lift_fe_mont (rhs.elements.val[1]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[1]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 1 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 1 (by rw [h_l]; decide)] + rw [h_lcm_lhs_0, h_lcm_lhs_1, h_lcm_rhs_0, h_lcm_rhs_1] + · -- Lane 2: touched by call 1 (nzeta0, even). + have h_r7_at_lane : r7.val[2]! = r1.val[2]! := by + rw [h_r7_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 2 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r0.val[2]! = out.val[2]! := by + rw [h_r0_unc' 2 (by decide) (by decide) (by decide)] + have h_src_at_odd : r0.val[3]! = out.val[3]! := by + rw [h_r0_unc' 3 (by decide) (by decide) (by decide)] + have h_fe := h_r1_fe_e + simp only [] at h_fe + rw [h_src_at_even] at h_fe + rw [h_n0_fe] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[2]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[2]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[2]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[2]!) + ((lift_chunk_mont rhs).val[2]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[3]!) + ((lift_chunk_mont rhs).val[3]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta0))) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_2 : (lift_chunk_mont lhs).val[2]! + = lift_fe_mont (lhs.elements.val[2]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[2]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 2 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 2 (by rw [h_l]; decide)] + have h_lcm_lhs_3 : (lift_chunk_mont lhs).val[3]! + = lift_fe_mont (lhs.elements.val[3]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[3]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 3 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 3 (by rw [h_l]; decide)] + have h_lcm_rhs_2 : (lift_chunk_mont rhs).val[2]! + = lift_fe_mont (rhs.elements.val[2]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[2]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 2 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 2 (by rw [h_l]; decide)] + have h_lcm_rhs_3 : (lift_chunk_mont rhs).val[3]! + = lift_fe_mont (rhs.elements.val[3]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[3]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 3 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 3 (by rw [h_l]; decide)] + rw [h_lcm_lhs_2, h_lcm_lhs_3, h_lcm_rhs_2, h_lcm_rhs_3] + · -- Lane 3: touched by call 1 (nzeta0, odd). + have h_r7_at_lane : r7.val[3]! = r1.val[3]! := by + rw [h_r7_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 3 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r0.val[2]! = out.val[2]! := by + rw [h_r0_unc' 2 (by decide) (by decide) (by decide)] + have h_src_at_odd : r0.val[3]! = out.val[3]! := by + rw [h_r0_unc' 3 (by decide) (by decide) (by decide)] + have h_fe := h_r1_fe_o + simp only [] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[3]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[3]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[3]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[2]!) + ((lift_chunk_mont rhs).val[3]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[3]!) + ((lift_chunk_mont rhs).val[2]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_2 : (lift_chunk_mont lhs).val[2]! + = lift_fe_mont (lhs.elements.val[2]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[2]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 2 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 2 (by rw [h_l]; decide)] + have h_lcm_lhs_3 : (lift_chunk_mont lhs).val[3]! + = lift_fe_mont (lhs.elements.val[3]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[3]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 3 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 3 (by rw [h_l]; decide)] + have h_lcm_rhs_2 : (lift_chunk_mont rhs).val[2]! + = lift_fe_mont (rhs.elements.val[2]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[2]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 2 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 2 (by rw [h_l]; decide)] + have h_lcm_rhs_3 : (lift_chunk_mont rhs).val[3]! + = lift_fe_mont (rhs.elements.val[3]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[3]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 3 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 3 (by rw [h_l]; decide)] + rw [h_lcm_lhs_2, h_lcm_lhs_3, h_lcm_rhs_2, h_lcm_rhs_3] + · -- Lane 4: touched by call 2 (zeta1, even). + have h_r7_at_lane : r7.val[4]! = r2.val[4]! := by + rw [h_r7_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r1.val[4]! = out.val[4]! := by + rw [h_r1_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 4 (by decide) (by decide) (by decide)] + have h_src_at_odd : r1.val[5]! = out.val[5]! := by + rw [h_r1_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 5 (by decide) (by decide) (by decide)] + have h_fe := h_r2_fe_e + simp only [] at h_fe + rw [h_src_at_even] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[4]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[4]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[4]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[4]!) + ((lift_chunk_mont rhs).val[4]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[5]!) + ((lift_chunk_mont rhs).val[5]!)) + (lift_fe_mont zeta1)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_4 : (lift_chunk_mont lhs).val[4]! + = lift_fe_mont (lhs.elements.val[4]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[4]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 4 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 4 (by rw [h_l]; decide)] + have h_lcm_lhs_5 : (lift_chunk_mont lhs).val[5]! + = lift_fe_mont (lhs.elements.val[5]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[5]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 5 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 5 (by rw [h_l]; decide)] + have h_lcm_rhs_4 : (lift_chunk_mont rhs).val[4]! + = lift_fe_mont (rhs.elements.val[4]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[4]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 4 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 4 (by rw [h_l]; decide)] + have h_lcm_rhs_5 : (lift_chunk_mont rhs).val[5]! + = lift_fe_mont (rhs.elements.val[5]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[5]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 5 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 5 (by rw [h_l]; decide)] + rw [h_lcm_lhs_4, h_lcm_lhs_5, h_lcm_rhs_4, h_lcm_rhs_5] + · -- Lane 5: touched by call 2 (zeta1, odd). + have h_r7_at_lane : r7.val[5]! = r2.val[5]! := by + rw [h_r7_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r1.val[4]! = out.val[4]! := by + rw [h_r1_unc' 4 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 4 (by decide) (by decide) (by decide)] + have h_src_at_odd : r1.val[5]! = out.val[5]! := by + rw [h_r1_unc' 5 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 5 (by decide) (by decide) (by decide)] + have h_fe := h_r2_fe_o + simp only [] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[5]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[5]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[5]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[4]!) + ((lift_chunk_mont rhs).val[5]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[5]!) + ((lift_chunk_mont rhs).val[4]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_4 : (lift_chunk_mont lhs).val[4]! + = lift_fe_mont (lhs.elements.val[4]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[4]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 4 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 4 (by rw [h_l]; decide)] + have h_lcm_lhs_5 : (lift_chunk_mont lhs).val[5]! + = lift_fe_mont (lhs.elements.val[5]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[5]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 5 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 5 (by rw [h_l]; decide)] + have h_lcm_rhs_4 : (lift_chunk_mont rhs).val[4]! + = lift_fe_mont (rhs.elements.val[4]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[4]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 4 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 4 (by rw [h_l]; decide)] + have h_lcm_rhs_5 : (lift_chunk_mont rhs).val[5]! + = lift_fe_mont (rhs.elements.val[5]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[5]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 5 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 5 (by rw [h_l]; decide)] + rw [h_lcm_lhs_4, h_lcm_lhs_5, h_lcm_rhs_4, h_lcm_rhs_5] + · -- Lane 6: touched by call 3 (nzeta1, even). + have h_r7_at_lane : r7.val[6]! = r3.val[6]! := by + rw [h_r7_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r2.val[6]! = out.val[6]! := by + rw [h_r2_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 6 (by decide) (by decide) (by decide)] + have h_src_at_odd : r2.val[7]! = out.val[7]! := by + rw [h_r2_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 7 (by decide) (by decide) (by decide)] + have h_fe := h_r3_fe_e + simp only [] at h_fe + rw [h_src_at_even] at h_fe + rw [h_n1_fe] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[6]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[6]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[6]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[6]!) + ((lift_chunk_mont rhs).val[6]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[7]!) + ((lift_chunk_mont rhs).val[7]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta1))) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_6 : (lift_chunk_mont lhs).val[6]! + = lift_fe_mont (lhs.elements.val[6]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[6]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 6 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 6 (by rw [h_l]; decide)] + have h_lcm_lhs_7 : (lift_chunk_mont lhs).val[7]! + = lift_fe_mont (lhs.elements.val[7]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[7]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 7 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 7 (by rw [h_l]; decide)] + have h_lcm_rhs_6 : (lift_chunk_mont rhs).val[6]! + = lift_fe_mont (rhs.elements.val[6]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[6]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 6 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 6 (by rw [h_l]; decide)] + have h_lcm_rhs_7 : (lift_chunk_mont rhs).val[7]! + = lift_fe_mont (rhs.elements.val[7]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[7]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 7 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 7 (by rw [h_l]; decide)] + rw [h_lcm_lhs_6, h_lcm_lhs_7, h_lcm_rhs_6, h_lcm_rhs_7] + · -- Lane 7: touched by call 3 (nzeta1, odd). + have h_r7_at_lane : r7.val[7]! = r3.val[7]! := by + rw [h_r7_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r2.val[6]! = out.val[6]! := by + rw [h_r2_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 6 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 6 (by decide) (by decide) (by decide)] + have h_src_at_odd : r2.val[7]! = out.val[7]! := by + rw [h_r2_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 7 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 7 (by decide) (by decide) (by decide)] + have h_fe := h_r3_fe_o + simp only [] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[7]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[7]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[7]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[6]!) + ((lift_chunk_mont rhs).val[7]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[7]!) + ((lift_chunk_mont rhs).val[6]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_6 : (lift_chunk_mont lhs).val[6]! + = lift_fe_mont (lhs.elements.val[6]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[6]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 6 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 6 (by rw [h_l]; decide)] + have h_lcm_lhs_7 : (lift_chunk_mont lhs).val[7]! + = lift_fe_mont (lhs.elements.val[7]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[7]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 7 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 7 (by rw [h_l]; decide)] + have h_lcm_rhs_6 : (lift_chunk_mont rhs).val[6]! + = lift_fe_mont (rhs.elements.val[6]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[6]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 6 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 6 (by rw [h_l]; decide)] + have h_lcm_rhs_7 : (lift_chunk_mont rhs).val[7]! + = lift_fe_mont (rhs.elements.val[7]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[7]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 7 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 7 (by rw [h_l]; decide)] + rw [h_lcm_lhs_6, h_lcm_lhs_7, h_lcm_rhs_6, h_lcm_rhs_7] + · -- Lane 8: touched by call 4 (zeta2, even). + have h_r7_at_lane : r7.val[8]! = r4.val[8]! := by + rw [h_r7_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r3.val[8]! = out.val[8]! := by + rw [h_r3_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 8 (by decide) (by decide) (by decide)] + have h_src_at_odd : r3.val[9]! = out.val[9]! := by + rw [h_r3_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 9 (by decide) (by decide) (by decide)] + have h_fe := h_r4_fe_e + simp only [] at h_fe + rw [h_src_at_even] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[8]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[8]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[8]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[8]!) + ((lift_chunk_mont rhs).val[8]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[9]!) + ((lift_chunk_mont rhs).val[9]!)) + (lift_fe_mont zeta2)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_8 : (lift_chunk_mont lhs).val[8]! + = lift_fe_mont (lhs.elements.val[8]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[8]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 8 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 8 (by rw [h_l]; decide)] + have h_lcm_lhs_9 : (lift_chunk_mont lhs).val[9]! + = lift_fe_mont (lhs.elements.val[9]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[9]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 9 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 9 (by rw [h_l]; decide)] + have h_lcm_rhs_8 : (lift_chunk_mont rhs).val[8]! + = lift_fe_mont (rhs.elements.val[8]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[8]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 8 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 8 (by rw [h_l]; decide)] + have h_lcm_rhs_9 : (lift_chunk_mont rhs).val[9]! + = lift_fe_mont (rhs.elements.val[9]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[9]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 9 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 9 (by rw [h_l]; decide)] + rw [h_lcm_lhs_8, h_lcm_lhs_9, h_lcm_rhs_8, h_lcm_rhs_9] + · -- Lane 9: touched by call 4 (zeta2, odd). + have h_r7_at_lane : r7.val[9]! = r4.val[9]! := by + rw [h_r7_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r3.val[8]! = out.val[8]! := by + rw [h_r3_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 8 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 8 (by decide) (by decide) (by decide)] + have h_src_at_odd : r3.val[9]! = out.val[9]! := by + rw [h_r3_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 9 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 9 (by decide) (by decide) (by decide)] + have h_fe := h_r4_fe_o + simp only [] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[9]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[9]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[9]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[8]!) + ((lift_chunk_mont rhs).val[9]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[9]!) + ((lift_chunk_mont rhs).val[8]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_8 : (lift_chunk_mont lhs).val[8]! + = lift_fe_mont (lhs.elements.val[8]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[8]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 8 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 8 (by rw [h_l]; decide)] + have h_lcm_lhs_9 : (lift_chunk_mont lhs).val[9]! + = lift_fe_mont (lhs.elements.val[9]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[9]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 9 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 9 (by rw [h_l]; decide)] + have h_lcm_rhs_8 : (lift_chunk_mont rhs).val[8]! + = lift_fe_mont (rhs.elements.val[8]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[8]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 8 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 8 (by rw [h_l]; decide)] + have h_lcm_rhs_9 : (lift_chunk_mont rhs).val[9]! + = lift_fe_mont (rhs.elements.val[9]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[9]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 9 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 9 (by rw [h_l]; decide)] + rw [h_lcm_lhs_8, h_lcm_lhs_9, h_lcm_rhs_8, h_lcm_rhs_9] + · -- Lane 10: touched by call 5 (nzeta2, even). + have h_r7_at_lane : r7.val[10]! = r5.val[10]! := by + rw [h_r7_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r4.val[10]! = out.val[10]! := by + rw [h_r4_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 10 (by decide) (by decide) (by decide)] + have h_src_at_odd : r4.val[11]! = out.val[11]! := by + rw [h_r4_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 11 (by decide) (by decide) (by decide)] + have h_fe := h_r5_fe_e + simp only [] at h_fe + rw [h_src_at_even] at h_fe + rw [h_n2_fe] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[10]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[10]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[10]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[10]!) + ((lift_chunk_mont rhs).val[10]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[11]!) + ((lift_chunk_mont rhs).val[11]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta2))) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_10 : (lift_chunk_mont lhs).val[10]! + = lift_fe_mont (lhs.elements.val[10]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[10]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 10 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 10 (by rw [h_l]; decide)] + have h_lcm_lhs_11 : (lift_chunk_mont lhs).val[11]! + = lift_fe_mont (lhs.elements.val[11]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[11]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 11 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 11 (by rw [h_l]; decide)] + have h_lcm_rhs_10 : (lift_chunk_mont rhs).val[10]! + = lift_fe_mont (rhs.elements.val[10]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[10]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 10 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 10 (by rw [h_l]; decide)] + have h_lcm_rhs_11 : (lift_chunk_mont rhs).val[11]! + = lift_fe_mont (rhs.elements.val[11]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[11]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 11 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 11 (by rw [h_l]; decide)] + rw [h_lcm_lhs_10, h_lcm_lhs_11, h_lcm_rhs_10, h_lcm_rhs_11] + · -- Lane 11: touched by call 5 (nzeta2, odd). + have h_r7_at_lane : r7.val[11]! = r5.val[11]! := by + rw [h_r7_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r6_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r4.val[10]! = out.val[10]! := by + rw [h_r4_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 10 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 10 (by decide) (by decide) (by decide)] + have h_src_at_odd : r4.val[11]! = out.val[11]! := by + rw [h_r4_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 11 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 11 (by decide) (by decide) (by decide)] + have h_fe := h_r5_fe_o + simp only [] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[11]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[11]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[11]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[10]!) + ((lift_chunk_mont rhs).val[11]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[11]!) + ((lift_chunk_mont rhs).val[10]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_10 : (lift_chunk_mont lhs).val[10]! + = lift_fe_mont (lhs.elements.val[10]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[10]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 10 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 10 (by rw [h_l]; decide)] + have h_lcm_lhs_11 : (lift_chunk_mont lhs).val[11]! + = lift_fe_mont (lhs.elements.val[11]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[11]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 11 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 11 (by rw [h_l]; decide)] + have h_lcm_rhs_10 : (lift_chunk_mont rhs).val[10]! + = lift_fe_mont (rhs.elements.val[10]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[10]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 10 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 10 (by rw [h_l]; decide)] + have h_lcm_rhs_11 : (lift_chunk_mont rhs).val[11]! + = lift_fe_mont (rhs.elements.val[11]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[11]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 11 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 11 (by rw [h_l]; decide)] + rw [h_lcm_lhs_10, h_lcm_lhs_11, h_lcm_rhs_10, h_lcm_rhs_11] + · -- Lane 12: touched by call 6 (zeta3, even). + have h_r7_at_lane : r7.val[12]! = r6.val[12]! := by + rw [h_r7_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r5.val[12]! = out.val[12]! := by + rw [h_r5_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 12 (by decide) (by decide) (by decide)] + have h_src_at_odd : r5.val[13]! = out.val[13]! := by + rw [h_r5_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 13 (by decide) (by decide) (by decide)] + have h_fe := h_r6_fe_e + simp only [] at h_fe + rw [h_src_at_even] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[12]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[12]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[12]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[12]!) + ((lift_chunk_mont rhs).val[12]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[13]!) + ((lift_chunk_mont rhs).val[13]!)) + (lift_fe_mont zeta3)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_12 : (lift_chunk_mont lhs).val[12]! + = lift_fe_mont (lhs.elements.val[12]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[12]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 12 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 12 (by rw [h_l]; decide)] + have h_lcm_lhs_13 : (lift_chunk_mont lhs).val[13]! + = lift_fe_mont (lhs.elements.val[13]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[13]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 13 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 13 (by rw [h_l]; decide)] + have h_lcm_rhs_12 : (lift_chunk_mont rhs).val[12]! + = lift_fe_mont (rhs.elements.val[12]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[12]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 12 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 12 (by rw [h_l]; decide)] + have h_lcm_rhs_13 : (lift_chunk_mont rhs).val[13]! + = lift_fe_mont (rhs.elements.val[13]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[13]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 13 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 13 (by rw [h_l]; decide)] + rw [h_lcm_lhs_12, h_lcm_lhs_13, h_lcm_rhs_12, h_lcm_rhs_13] + · -- Lane 13: touched by call 6 (zeta3, odd). + have h_r7_at_lane : r7.val[13]! = r6.val[13]! := by + rw [h_r7_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r7_at_lane] + have h_src_at_even : r5.val[12]! = out.val[12]! := by + rw [h_r5_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 12 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 12 (by decide) (by decide) (by decide)] + have h_src_at_odd : r5.val[13]! = out.val[13]! := by + rw [h_r5_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 13 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 13 (by decide) (by decide) (by decide)] + have h_fe := h_r6_fe_o + simp only [] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[13]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[13]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[13]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[12]!) + ((lift_chunk_mont rhs).val[13]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[13]!) + ((lift_chunk_mont rhs).val[12]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_12 : (lift_chunk_mont lhs).val[12]! + = lift_fe_mont (lhs.elements.val[12]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[12]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 12 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 12 (by rw [h_l]; decide)] + have h_lcm_lhs_13 : (lift_chunk_mont lhs).val[13]! + = lift_fe_mont (lhs.elements.val[13]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[13]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 13 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 13 (by rw [h_l]; decide)] + have h_lcm_rhs_12 : (lift_chunk_mont rhs).val[12]! + = lift_fe_mont (rhs.elements.val[12]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[12]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 12 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 12 (by rw [h_l]; decide)] + have h_lcm_rhs_13 : (lift_chunk_mont rhs).val[13]! + = lift_fe_mont (rhs.elements.val[13]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[13]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 13 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 13 (by rw [h_l]; decide)] + rw [h_lcm_lhs_12, h_lcm_lhs_13, h_lcm_rhs_12, h_lcm_rhs_13] + · -- Lane 14: touched by call 7 (nzeta3, even). + have h_src_at_even : r6.val[14]! = out.val[14]! := by + rw [h_r6_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 14 (by decide) (by decide) (by decide)] + have h_src_at_odd : r6.val[15]! = out.val[15]! := by + rw [h_r6_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 15 (by decide) (by decide) (by decide)] + have h_fe := h_r7_fe_e + simp only [] at h_fe + rw [h_src_at_even] at h_fe + rw [h_n3_fe] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[14]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[14]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[14]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[14]!) + ((lift_chunk_mont rhs).val[14]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[15]!) + ((lift_chunk_mont rhs).val[15]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe_mont zeta3))) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_14 : (lift_chunk_mont lhs).val[14]! + = lift_fe_mont (lhs.elements.val[14]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[14]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 14 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 14 (by rw [h_l]; decide)] + have h_lcm_lhs_15 : (lift_chunk_mont lhs).val[15]! + = lift_fe_mont (lhs.elements.val[15]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[15]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 15 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 15 (by rw [h_l]; decide)] + have h_lcm_rhs_14 : (lift_chunk_mont rhs).val[14]! + = lift_fe_mont (rhs.elements.val[14]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[14]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 14 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 14 (by rw [h_l]; decide)] + have h_lcm_rhs_15 : (lift_chunk_mont rhs).val[15]! + = lift_fe_mont (rhs.elements.val[15]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[15]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 15 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 15 (by rw [h_l]; decide)] + rw [h_lcm_lhs_14, h_lcm_lhs_15, h_lcm_rhs_14, h_lcm_rhs_15] + · -- Lane 15: touched by call 7 (nzeta3, odd). + have h_src_at_even : r6.val[14]! = out.val[14]! := by + rw [h_r6_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 14 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 14 (by decide) (by decide) (by decide)] + have h_src_at_odd : r6.val[15]! = out.val[15]! := by + rw [h_r6_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r5_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r4_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r3_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r2_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r1_unc' 15 (by decide) (by decide) (by decide)] + rw [h_r0_unc' 15 (by decide) (by decide) (by decide)] + have h_fe := h_r7_fe_o + simp only [] at h_fe + rw [h_src_at_odd] at h_fe + rw [h_fe] + have h_red_out : (Spec.chunk_reducing_from_i32_array_pure out).val[15]! + = Spec.mont_reduce_pure (lift_fe_int (out.val[15]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + rfl + rw [h_red_out] + have h_red_no_acc : (Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont lhs) (lift_chunk_mont rhs) + (lift_fe_mont zeta0) (lift_fe_mont zeta1) + (lift_fe_mont zeta2) (lift_fe_mont zeta3)).val[15]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[14]!) + ((lift_chunk_mont rhs).val[15]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk_mont lhs).val[15]!) + ((lift_chunk_mont rhs).val[14]!)) := by + unfold Spec.ntt_multiply_pure_no_acc + rfl + rw [h_red_no_acc] + have h_lcm_lhs_14 : (lift_chunk_mont lhs).val[14]! + = lift_fe_mont (lhs.elements.val[14]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[14]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 14 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 14 (by rw [h_l]; decide)] + have h_lcm_lhs_15 : (lift_chunk_mont lhs).val[15]! + = lift_fe_mont (lhs.elements.val[15]!) := by + unfold lift_chunk_mont + have h_l : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + show (lhs.elements.val.map lift_fe_mont)[15]! = _ + have h_ml : (lhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (lhs.elements.val.map lift_fe_mont) 15 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val 15 (by rw [h_l]; decide)] + have h_lcm_rhs_14 : (lift_chunk_mont rhs).val[14]! + = lift_fe_mont (rhs.elements.val[14]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[14]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 14 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 14 (by rw [h_l]; decide)] + have h_lcm_rhs_15 : (lift_chunk_mont rhs).val[15]! + = lift_fe_mont (rhs.elements.val[15]!) := by + unfold lift_chunk_mont + have h_l : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + show (rhs.elements.val.map lift_fe_mont)[15]! = _ + have h_ml : (rhs.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_l + rw [getElem!_pos (rhs.elements.val.map lift_fe_mont) 15 (by rw [h_ml]; decide)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val 15 (by rw [h_l]; decide)] + rw [h_lcm_lhs_14, h_lcm_lhs_15, h_lcm_rhs_14, h_lcm_rhs_15] + + +/-- Algebraic POST predicate for the L6.3 polynomial-level NTT + multiply. Relates the resulting I32 accumulator array `r` to the + polynomial inputs and the initial accumulator state via a per-chunk + × per-lane equation in Mont-domain `FieldElement` space: + + `mont_reduce_pure (lift_fe_int r[16j+ℓ])` + `= mont_reduce_pure (lift_fe_int accumulator[16j+ℓ])` + `+ no_acc_product j ℓ` + + where the per-chunk product `no_acc_product j ℓ` is the ℓ-th lane + of `Spec.ntt_multiply_pure_no_acc` applied to Mont-domain lifts of + the j-th coefficient vectors and the four zetas at + `Spec.zeta_at (64 + 4j + {0,1,2,3})`. + + Composes the L2.8 per-chunk POST (`ntt_multiply_base_case_post`) + via the chunk_add_pure decomposition baked into + `ntt_multiply_base_case_alg`: at each j ∈ 0..15, applying L2.8 to + the 16-lane window `[16j..16(j+1)]` gives the per-chunk equation + `chunk_reducing_from_i32_array_pure r_chunk = + chunk_add_pure (chunk_reducing_from_i32_array_pure acc_chunk) + (ntt_multiply_pure_no_acc ...)`, + which extracts per-lane to the equation above (since + `chunk_reducing_from_i32_array_pure x .val[ℓ] = mont_reduce_pure + (lift_fe_int (x.val[ℓ]!).val)` definitionally). -/ +noncomputable def accumulating_ntt_multiply_poly_post + (myself rhs : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (accumulator r : Std.Array Std.I32 256#usize) : Prop := + ∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (r.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (accumulator.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (myself.coefficients.val[j]!)) + (lift_chunk_mont (rhs.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!) + +namespace UseCacheFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Step-local accumulator: 256-lane `I32` array. -/ +abbrev Acc := Std.Array Std.I32 256#usize + +abbrev Poly := + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC loop invariant for `accumulating_ntt_multiply_loop`. Per-lane: + * (a) Touched chunks (`j < k`): per-lane FC equation against L2.8's + `Spec.ntt_multiply_pure_no_acc` plus initial accumulator. + * (b) Chunks `j ≥ k`: per-lane unchanged from `acc_init`. + * (c) Universal: per-lane bound — `|acc[n]| ≤ |acc_init[n]| + 2^25` + for touched lanes; `acc[n] = acc_init[n]` (hence bound 0) for + untouched. We encode the bound directly over all lanes since + (c) ⇒ touched case bound. -/ +def inv (myself rhs : Poly) (acc_init : Acc) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + (∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (myself.coefficients.val[j]!)) + (lift_chunk_mont (rhs.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → ∀ ℓ : Nat, ℓ < 16 → + acc.val[16 * j + ℓ]! = acc_init.val[16 * j + ℓ]!) + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + 2^25)) + +/-- Step-post for `loop_range_spec_usize`. -/ +def step_post (myself rhs : Poly) (acc_init : Acc) (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv myself rhs acc_init iter'.start acc').holds + | .done y => (inv myself rhs acc_init 16#usize y).holds + +end UseCacheFC + +/-- Array sub-slice extraction via `index_mut` over a `Range Usize`, + in `.ok`-form. Mirrors `slice_index_range_ok_eq_fc` but for + `Std.Array`: routes through `Array.to_slice_mut` to obtain a + sub-slice `s` plus a write-back closure satisfying + `(back s').val = a.val.setSlice! r.start.val s'.val` whenever `s'` + has length `r.end.val - r.start.val`. -/ +theorem array_index_mut_range_ok_eq_fc + {T : Type} {N : Std.Usize} (a : Std.Array T N) + (r : CoreModels.core.ops.range.Range Std.Usize) + (h0 : r.start.val ≤ r.end.val) (h1 : r.end.val ≤ a.val.length) : + ∃ (s : Slice T) (back : Slice T → Std.Array T N), + core.Array.Insts.CoreOpsIndexIndexMut.index_mut + (core.Slice.Insts.CoreOpsIndexIndexMut + (core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice T)) + a { start := r.start, «end» := r.end } = .ok (s, back) + ∧ s.val = a.val.slice r.start.val r.end.val + ∧ s.val.length = r.end.val - r.start.val + ∧ (∀ s' : Slice T, s'.val.length = r.end.val - r.start.val → + (back s').val = a.val.setSlice! r.start.val s'.val) := by + -- Unfold the Array-level index_mut to the to_slice_mut + slice index_mut composition. + set a_slice : Slice T := Aeneas.Std.Array.to_slice a with ha_slice_def + have h_a_slice_val : a_slice.val = a.val := + Aeneas.Std.Array.val_to_slice a + have h_a_slice_len : a_slice.val.length = a.val.length := by rw [h_a_slice_val] + have h1' : r.end.val ≤ a_slice.val.length := by rw [h_a_slice_len]; exact h1 + -- Slice-level index_mut over the same range. + have hT := libcrux_iot_ml_kem.Util.SliceSpecs.core_models_Slice_Insts_index_mut_RangeUsize_spec + (T := T) a_slice + ({ start := r.start, «end» := r.end } : CoreModels.core.ops.range.Range Std.Usize) + h0 h1' + obtain ⟨p, h_p_eq, h_p_post⟩ := triple_exists_ok_fc hT + obtain ⟨h_p_val, h_p_len, h_p_back⟩ := h_p_post + -- The Array-level closure: fun o => Array.from_slice a (slice_back o). + refine ⟨p.1, fun o => Aeneas.Std.Array.from_slice a (p.2 o), ?_, ?_, ?_, ?_⟩ + · -- The Array index_mut reduces to `do let (s, back) ← Slice.index_mut ...; ok (s, ...)`. + unfold core.Array.Insts.CoreOpsIndexIndexMut.index_mut + -- to_slice_mut := (to_slice a, from_slice a). + show (do + let p ← core.Slice.Insts.CoreOpsIndexIndexMut.index_mut + (core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice T) + a_slice { start := r.start, «end» := r.end } + .ok (p.1, fun o => Aeneas.Std.Array.from_slice a (p.2 o))) + = .ok (p.1, fun o => Aeneas.Std.Array.from_slice a (p.2 o)) + rw [h_p_eq]; rfl + · -- Sub-slice val. + rw [h_p_val]; rw [h_a_slice_val] + · -- Sub-slice length. + exact h_p_len + · -- Write-back: `(from_slice a (slice_back s')).val = a.val.setSlice! r.start.val s'.val`. + intro s' hs'_len + have h_back_val : (p.2 s').val = a_slice.val.setSlice! r.start.val s'.val := h_p_back s' + have h_back_len : (p.2 s').val.length = N.val := by + rw [h_back_val, h_a_slice_val, List.length_setSlice!] + exact Std.Array.length_eq a + have h_from_slice_val : + (Aeneas.Std.Array.from_slice a (p.2 s')).val = (p.2 s').val := + Aeneas.Std.Array.from_slice_val a (p.2 s') h_back_len + rw [h_from_slice_val, h_back_val, h_a_slice_val] + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for + `polynomial.PolynomialRingElement.accumulating_ntt_multiply`. -/ +theorem accumulating_ntt_multiply_poly_step_lemma_fc + (myself rhs : UseCacheFC.Poly) (acc_init : UseCacheFC.Acc) + (h_self : ∀ i : Fin 16, ∀ j : Fin 16, + ((myself.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_rhs : ∀ i : Fin 16, ∀ j : Fin 16, + ((rhs.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, (acc_init.val[n.val]!).val.natAbs ≤ 2^30) + (acc : UseCacheFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (UseCacheFC.inv myself rhs acc_init k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_loop.body + (vectortraitsOperationsInst := portable_ops_inst) myself rhs + { start := k, «end» := 16#usize } acc + ⦃ ⇓ r => ⌜ UseCacheFC.step_post myself rhs acc_init k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_acc_len : acc.val.length = 256 := + (Std.Array.length_eq acc) + have h_acc_init_len : acc_init.val.length = 256 := + (Std.Array.length_eq acc_init) + have h_self_coef_len : myself.coefficients.length = 16 := + Std.Array.length_eq _ + have h_rhs_coef_len : rhs.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_acc_done, h_acc_undone, h_acc_bnd_rel⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some i = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s_iter, hs_val_eq, h_iter_some⟩ := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + -- (1) t := self.coefficients[k] and t1 := rhs.coefficients[k]. + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + myself.coefficients.val[k.val]! with ht_def + set t1 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + rhs.coefficients.val[k.val]! with ht1_def + have h_idx_t : Aeneas.Std.Array.index_usize myself.coefficients k = .ok t := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq myself.coefficients k + (by rw [h_self_coef_len]; exact hk_16) + have h_idx_t1 : Aeneas.Std.Array.index_usize rhs.coefficients k = .ok t1 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq rhs.coefficients k + (by rw [h_rhs_coef_len]; exact hk_16) + -- (2) i1 := k * 16, i2 := k + 1, i3 := i2 * 16. + have hi1_max : k.val * (16#usize : Std.Usize).val ≤ Std.Usize.max := by + have hk_15 : k.val ≤ 15 := by omega + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hum] + have h1 : k.val * 16 ≤ 15 * 16 := Nat.mul_le_mul_right 16 hk_15 + have : (15 * 16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i1, hi1_eq, hi1_val⟩ := usize_mul_ok_eq_fc k 16#usize hi1_max + have hi1_val_eq : i1.val = 16 * k.val := by + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hi1_val, hum]; omega + have hi2_max : k.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + have hk_15 : k.val ≤ 15 := by omega + have hum : (1#usize : Std.Usize).val = 1 := rfl + rw [hum] + have : (16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i2, hi2_eq, hi2_val⟩ := usize_add_ok_eq_fc k 1#usize hi2_max + have hi2_val_eq : i2.val = k.val + 1 := by + have hum : (1#usize : Std.Usize).val = 1 := rfl + rw [hi2_val, hum] + have hi3_max : i2.val * (16#usize : Std.Usize).val ≤ Std.Usize.max := by + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hum, hi2_val_eq] + have : k.val + 1 ≤ 16 := by omega + have h1 : (k.val + 1) * 16 ≤ 16 * 16 := Nat.mul_le_mul_right 16 this + have : (16 * 16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i3, hi3_eq, hi3_val⟩ := usize_mul_ok_eq_fc i2 16#usize hi3_max + have hi3_val_eq : i3.val = 16 * (k.val + 1) := by + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hi3_val, hi2_val_eq, hum]; omega + -- (3) Sub-slice via Array index_mut RangeUsize. + have h0_le : i1.val ≤ i3.val := by rw [hi1_val_eq, hi3_val_eq]; omega + have hi3_le : i3.val ≤ acc.val.length := by + rw [h_acc_len, hi3_val_eq] + have : k.val + 1 ≤ 16 := by omega + have h1 : 16 * (k.val + 1) ≤ 16 * 16 := Nat.mul_le_mul_left _ this + omega + obtain ⟨s, back, h_imt_eq, h_s_val_eq, h_s_len_eq, h_back_eq⟩ := + array_index_mut_range_ok_eq_fc acc + ({ start := i1, «end» := i3 } : CoreModels.core.ops.range.Range Std.Usize) + h0_le hi3_le + have h_s_len16 : s.length = 16 := by + show s.val.length = 16 + rw [h_s_len_eq] + show i3.val - i1.val = 16 + rw [hi3_val_eq, hi1_val_eq]; omega + -- Per-lane lookup: s.val[ℓ]! = acc.val[16*k + ℓ]!. + have h_s_lane : ∀ ℓ : Nat, ℓ < 16 → + s.val[ℓ]! = acc.val[16 * k.val + ℓ]! := by + intro ℓ hℓ + rw [h_s_val_eq] + have h_idx_lt : i1.val + ℓ < i3.val := by + rw [hi1_val_eq, hi3_val_eq]; omega + have h_bnd : i3.val ≤ acc.val.length ∧ i1.val + ℓ < i3.val := ⟨hi3_le, h_idx_lt⟩ + rw [List.getElem!_slice i1.val i3.val ℓ acc.val h_bnd] + rw [hi1_val_eq] + -- The lanes [16k, 16(k+1)) are untouched in `acc` (j ≥ k). + have h_s_lane_init : ∀ ℓ : Nat, ℓ < 16 → + s.val[ℓ]! = acc_init.val[16 * k.val + ℓ]! := by + intro ℓ hℓ + rw [h_s_lane ℓ hℓ] + exact h_acc_undone k.val (Nat.le_refl _) hk_16 ℓ hℓ + -- (4) Per-lane bound on `s` (≤ 2^30 from h_acc_bnd). + have h_s_bnd : ∀ k' : Fin 16, (s.val[k'.val]!).val.natAbs ≤ 2^30 := by + intro k' + rw [h_s_lane_init k'.val k'.isLt] + have h_lt : 16 * k.val + k'.val < 256 := by + have : k.val ≤ 15 := by omega + have hk' : k'.val < 16 := k'.isLt + have : 16 * k.val ≤ 16 * 15 := Nat.mul_le_mul_left 16 this + omega + exact h_acc_bnd ⟨16 * k.val + k'.val, h_lt⟩ + -- (5) Read 4 zetas via polynomial.zeta_eq_ok_fc. + have hi4_max : (4#usize : Std.Usize).val * k.val ≤ Std.Usize.max := by + have hk_15 : k.val ≤ 15 := by omega + have hum : (4#usize : Std.Usize).val = 4 := rfl + rw [hum] + have : 4 * k.val ≤ 4 * 15 := Nat.mul_le_mul_left 4 hk_15 + have : (4 * 15 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i4, hi4_eq, hi4_val⟩ := usize_mul_ok_eq_fc 4#usize k hi4_max + have hi4_val_eq : i4.val = 4 * k.val := by + have hum : (4#usize : Std.Usize).val = 4 := rfl + rw [hi4_val, hum] + have hi5_max : (64#usize : Std.Usize).val + i4.val ≤ Std.Usize.max := by + have hum : (64#usize : Std.Usize).val = 64 := rfl + rw [hum, hi4_val_eq] + have hk_15 : k.val ≤ 15 := by omega + have : 4 * k.val ≤ 4 * 15 := Nat.mul_le_mul_left 4 hk_15 + have : (64 + 4 * 15 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + -- i5..i14 are 4 add-then-zeta sequences. + obtain ⟨i5, hi5_eq, hi5_val⟩ := usize_add_ok_eq_fc 64#usize i4 hi5_max + have hi5_val_eq : i5.val = 64 + 4 * k.val := by + have hum : (64#usize : Std.Usize).val = 64 := rfl + rw [hi5_val, hum, hi4_val_eq] + have hi5_lt_128 : i5.val < 128 := by rw [hi5_val_eq]; omega + obtain ⟨z0, hz0_eq, hz0_post⟩ := + triple_exists_ok_fc (polynomial.zeta_fc i5 hi5_lt_128) + obtain ⟨hz0_val_eq, hz0_bnd, hz0_lift⟩ := hz0_post + -- i8 := i5 + 1 (since i7 = i5 after the duplicate `64 + i4` rewrite). + have hi8_max : i5.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + have hum : (1#usize : Std.Usize).val = 1 := rfl + rw [hum, hi5_val_eq] + have hk_15 : k.val ≤ 15 := by omega + have : (64 + 4 * 15 + 1 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i8, hi8_eq, hi8_val⟩ := usize_add_ok_eq_fc i5 1#usize hi8_max + have hi8_val_eq : i8.val = 64 + 4 * k.val + 1 := by + have hum : (1#usize : Std.Usize).val = 1 := rfl + rw [hi8_val, hi5_val_eq, hum] + have hi8_lt_128 : i8.val < 128 := by rw [hi8_val_eq]; omega + obtain ⟨z1, hz1_eq, hz1_post⟩ := + triple_exists_ok_fc (polynomial.zeta_fc i8 hi8_lt_128) + obtain ⟨hz1_val_eq, hz1_bnd, hz1_lift⟩ := hz1_post + have hi11_max : i5.val + (2#usize : Std.Usize).val ≤ Std.Usize.max := by + have hum : (2#usize : Std.Usize).val = 2 := rfl + rw [hum, hi5_val_eq] + have hk_15 : k.val ≤ 15 := by omega + have : (64 + 4 * 15 + 2 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i11, hi11_eq, hi11_val⟩ := usize_add_ok_eq_fc i5 2#usize hi11_max + have hi11_val_eq : i11.val = 64 + 4 * k.val + 2 := by + have hum : (2#usize : Std.Usize).val = 2 := rfl + rw [hi11_val, hi5_val_eq, hum] + have hi11_lt_128 : i11.val < 128 := by rw [hi11_val_eq]; omega + obtain ⟨z2, hz2_eq, hz2_post⟩ := + triple_exists_ok_fc (polynomial.zeta_fc i11 hi11_lt_128) + obtain ⟨hz2_val_eq, hz2_bnd, hz2_lift⟩ := hz2_post + have hi14_max : i5.val + (3#usize : Std.Usize).val ≤ Std.Usize.max := by + have hum : (3#usize : Std.Usize).val = 3 := rfl + rw [hum, hi5_val_eq] + have hk_15 : k.val ≤ 15 := by omega + have : (64 + 4 * 15 + 3 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i14, hi14_eq, hi14_val⟩ := usize_add_ok_eq_fc i5 3#usize hi14_max + have hi14_val_eq : i14.val = 64 + 4 * k.val + 3 := by + have hum : (3#usize : Std.Usize).val = 3 := rfl + rw [hi14_val, hi5_val_eq, hum] + have hi14_lt_128 : i14.val < 128 := by rw [hi14_val_eq]; omega + obtain ⟨z3, hz3_eq, hz3_post⟩ := + triple_exists_ok_fc (polynomial.zeta_fc i14 hi14_lt_128) + obtain ⟨hz3_val_eq, hz3_bnd, hz3_lift⟩ := hz3_post + -- (6) Apply L2.8 to get s1 satisfying ntt_multiply_base_case_post + bound. + have h_t_lhs : ∀ j : Fin 16, (t.elements.val[j.val]!).val.natAbs ≤ 3328 := by + intro j + exact h_self ⟨k.val, hk_16⟩ j + have h_t1_rhs : ∀ j : Fin 16, (t1.elements.val[j.val]!).val.natAbs ≤ 3328 := by + intro j + exact h_rhs ⟨k.val, hk_16⟩ j + obtain ⟨s1, h_s1_eq, h_s1_len, h_s1_bnd, h_s1_post⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_fc t t1 s z0 z1 z2 z3 h_s_len16 + h_t_lhs h_t1_rhs hz0_bnd hz1_bnd hz2_bnd hz3_bnd h_s_bnd) + -- s1's bound vs s lanes (s.val[k'] = acc_init[16k+k']). + have h_s1_bnd_abs : ∀ k' : Nat, k' < 16 → + (s1.val[k']!).val.natAbs ≤ (acc_init.val[16 * k.val + k']!).val.natAbs + 2^25 := by + intro k' hk' + have h_step_bnd := h_s1_bnd ⟨k', hk'⟩ + simp only at h_step_bnd + rw [h_s_lane_init k' hk'] at h_step_bnd + exact h_step_bnd + -- (7) Compose acc1 := back s1. + set acc1 : UseCacheFC.Acc := back s1 with hacc1_def + have h_acc1_val : acc1.val = acc.val.setSlice! i1.val s1.val := + h_back_eq s1 (by show s1.val.length = i3.val - i1.val; rw [← h_s_len_eq]; + show s1.length = s.length; rw [h_s1_len, h_s_len16]) + have h_acc1_len : acc1.val.length = 256 := by + rw [h_acc1_val, List.length_setSlice!, h_acc_len] + -- (8) Per-lane lookup of acc1 in the touched window: acc1[16k+ℓ] = s1[ℓ]. + have h_acc1_in : ∀ ℓ : Nat, ℓ < 16 → + acc1.val[16 * k.val + ℓ]! = s1.val[ℓ]! := by + intro ℓ hℓ + rw [h_acc1_val] + have h_get : (acc.val.setSlice! i1.val s1.val)[16 * k.val + ℓ]! + = s1.val[(16 * k.val + ℓ) - i1.val]! := by + apply List.getElem!_setSlice!_middle + refine ⟨?_, ?_, ?_⟩ + · rw [hi1_val_eq]; omega + · rw [hi1_val_eq] + have h_sub' : 16 * k.val + ℓ - 16 * k.val = ℓ := by omega + rw [h_sub'] + show ℓ < s1.val.length + have h_s1' : s1.val.length = 16 := h_s1_len + rw [h_s1']; exact hℓ + · rw [h_acc_len] + have hk_15' : k.val ≤ 15 := by omega + have h1 : 16 * k.val ≤ 16 * 15 := Nat.mul_le_mul_left 16 hk_15' + omega + rw [h_get] + have h_sub : (16 * k.val + ℓ) - i1.val = ℓ := by + rw [hi1_val_eq]; omega + rw [h_sub] + -- Outside the window: acc1[n] = acc[n]. + have h_acc1_out : ∀ n : Nat, n < 256 → + (n < 16 * k.val ∨ 16 * (k.val + 1) ≤ n) → + acc1.val[n]! = acc.val[n]! := by + intro n hn hcases + rw [h_acc1_val] + rcases hcases with hlt | hge + · -- n < 16 * k.val = i1.val. + apply List.getElem!_setSlice!_prefix + rw [hi1_val_eq]; exact hlt + · -- 16 * (k.val + 1) ≤ n. + apply List.getElem!_setSlice!_suffix + rw [hi1_val_eq] + have h_s1' : s1.val.length = 16 := h_s1_len + rw [h_s1'] + have h_eq16 : 16 * k.val + 16 = 16 * (k.val + 1) := by ring + rw [h_eq16]; exact hge + -- (9) Body equation: the impl reduces to .ok (cont (..., acc1)). + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_loop.body + (vectortraitsOperationsInst := portable_ops_inst) myself rhs + { start := k, «end» := 16#usize } acc + = .ok (ControlFlow.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc1)) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let t' ← Aeneas.Std.Array.index_usize myself.coefficients k + let t1' ← Aeneas.Std.Array.index_usize rhs.coefficients k + let i1' ← (k * 16#usize : Result Std.Usize) + let i2' ← k + 1#usize + let i3' ← i2' * 16#usize + let (s', index_mut_back) ← + core.Array.Insts.CoreOpsIndexIndexMut.index_mut + (core.Slice.Insts.CoreOpsIndexIndexMut + (core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice + Std.I32)) acc { start := i1', «end» := i3' } + let i4' ← 4#usize * k + let i5' ← 64#usize + i4' + let i6' ← libcrux_iot_ml_kem.polynomial.zeta i5' + let i7' ← 64#usize + i4' + let i8' ← i7' + 1#usize + let i9' ← libcrux_iot_ml_kem.polynomial.zeta i8' + let i10' ← 64#usize + i4' + let i11' ← i10' + 2#usize + let i12' ← libcrux_iot_ml_kem.polynomial.zeta i11' + let i13' ← 64#usize + i4' + let i14' ← i13' + 3#usize + let i15' ← libcrux_iot_ml_kem.polynomial.zeta i14' + let s1' ← + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations.accumulating_ntt_multiply + t' t1' s' i6' i9' i12' i15' + .ok (ControlFlow.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), index_mut_back s1'))) + : Result _) = _ + rw [h_idx_t]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_t1]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi3_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi4_eq]; simp only [Aeneas.Std.bind_tc_ok] + -- rw [hi5_eq] rewrites all four occurrences of `64#usize + i4` to `.ok i5`. + rw [hi5_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hz0_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi8_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hz1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi11_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hz2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi14_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hz3_eq]; simp only [Aeneas.Std.bind_tc_ok] + -- Now we have `vectortraitsOperationsInst.accumulating_ntt_multiply t t1 s z0 z1 z2 z3`. + -- For portable_ops_inst, this reduces definitionally to vector.portable.ntt.acc_ntt_mult. + show ((do + let s1' ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply + t t1 s z0 z1 z2 z3 + .ok (ControlFlow.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), back s1'))) + : Result _) = _ + rw [h_s1_eq] + rfl + apply triple_of_ok_fc h_body + show UseCacheFC.step_post myself rhs acc_init k + (.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc1)) + unfold UseCacheFC.step_post + refine ⟨h_lt, rfl, hs_val_eq, ?_⟩ + show (UseCacheFC.inv myself rhs acc_init s_iter acc1).holds + -- Build the three invariant conjuncts. + have h_inv_pure : + (∀ j : Nat, j < s_iter.val → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (myself.coefficients.val[j]!)) + (lift_chunk_mont (rhs.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + ∧ (∀ j : Nat, s_iter.val ≤ j → j < 16 → ∀ ℓ : Nat, ℓ < 16 → + acc1.val[16 * j + ℓ]! = acc_init.val[16 * j + ℓ]!) + ∧ (∀ n : Nat, n < 256 → + (acc1.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + 2^25) := by + refine ⟨?_, ?_, ?_⟩ + · -- (a) Touched-chunk FC. + intro j hj ℓ hℓ + rw [hs_val_eq] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k: chunk unchanged in acc, FC from inv. + have h_in_range : 16 * j + ℓ < 16 * k.val := by + have h1 : 16 * j + 16 ≤ 16 * k.val := by + have : j + 1 ≤ k.val := by omega + have : 16 * (j + 1) ≤ 16 * k.val := Nat.mul_le_mul_left 16 this + omega + omega + have h_lt_256 : 16 * j + ℓ < 256 := by + have : k.val ≤ 15 := by omega + have : 16 * k.val ≤ 16 * 15 := Nat.mul_le_mul_left 16 this + omega + have h_acc1_eq_acc : acc1.val[16 * j + ℓ]! = acc.val[16 * j + ℓ]! := + h_acc1_out (16 * j + ℓ) h_lt_256 (Or.inl h_in_range) + rw [h_acc1_eq_acc] + exact h_acc_done j hj_lt_k ℓ hℓ + · -- j = k: chunk = s1; unfold L2.8 POST. + subst hj_eq_k + rw [h_acc1_in ℓ hℓ] + -- h_s1_post : ntt_multiply_base_case_post t t1 z0 z1 z2 z3 s s1. + -- = Spec.chunk_reducing_from_i32_array_pure s1 + -- = ntt_multiply_base_case_alg ... (Spec.chunk_reducing_from_i32_array_pure s). + unfold ntt_multiply_base_case_post at h_s1_post + -- Per-lane unfold: (chunk_reducing_from_i32_array_pure s1).val[ℓ] + -- = mont_reduce_pure (lift_fe_int s1.val[ℓ]!.val). + -- And ntt_multiply_base_case_alg = chunk_add_pure ... (...). + have h_lhs_val_eq : + (Spec.chunk_reducing_from_i32_array_pure s1).val[ℓ]! + = Spec.mont_reduce_pure (lift_fe_int (s1.val[ℓ]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + show ((List.range 16).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (s1.val[i]!).val)))[ℓ]! = _ + have h_len : ((List.range 16).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (s1.val[i]!).val))).length = 16 := by simp + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map, List.getElem_range] + have h_rhs_val_eq : + (ntt_multiply_base_case_alg + (lift_chunk_mont t) (lift_chunk_mont t1) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3) + (Spec.chunk_reducing_from_i32_array_pure s)).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Spec.chunk_reducing_from_i32_array_pure s).val[ℓ]!) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont t) (lift_chunk_mont t1) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3)).val[ℓ]!) := by + unfold ntt_multiply_base_case_alg Spec.chunk_add_pure + show ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Spec.chunk_reducing_from_i32_array_pure s).val[i]!) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont t) (lift_chunk_mont t1) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3)).val[i]!)))[ℓ]! = _ + have h_len : ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Spec.chunk_reducing_from_i32_array_pure s).val[i]!) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont t) (lift_chunk_mont t1) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3)).val[i]!))).length = 16 := by simp + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map, List.getElem_range] + have h_s_chunk_val : + (Spec.chunk_reducing_from_i32_array_pure s).val[ℓ]! + = Spec.mont_reduce_pure (lift_fe_int (s.val[ℓ]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + show ((List.range 16).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (s.val[i]!).val)))[ℓ]! = _ + have h_len : ((List.range 16).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (s.val[i]!).val))).length = 16 := by simp + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map, List.getElem_range] + have h_chunk_eq : + Spec.mont_reduce_pure (lift_fe_int (s1.val[ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (s.val[ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont t) (lift_chunk_mont t1) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3)).val[ℓ]!) := by + have h_eq : (Spec.chunk_reducing_from_i32_array_pure s1).val[ℓ]! + = (ntt_multiply_base_case_alg + (lift_chunk_mont t) (lift_chunk_mont t1) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3) + (Spec.chunk_reducing_from_i32_array_pure s)).val[ℓ]! := by + rw [h_s1_post] + rw [h_lhs_val_eq] at h_eq + rw [h_rhs_val_eq] at h_eq + rw [h_s_chunk_val] at h_eq + exact h_eq + rw [h_chunk_eq] + -- Now substitute s.val[ℓ]! = acc_init.val[16*k+ℓ]!. + rw [h_s_lane_init ℓ hℓ] + -- Match zeta indices. + rw [hz0_lift, hz1_lift, hz2_lift, hz3_lift] + rw [hi5_val_eq, hi8_val_eq, hi11_val_eq, hi14_val_eq] + · -- (b) Untouched chunks: j ≥ k+1. + intro j hj_ge hj_lt ℓ hℓ + rw [hs_val_eq] at hj_ge + have h_n_lt_256 : 16 * j + ℓ < 256 := by + have : j ≤ 15 := by omega + have : 16 * j ≤ 16 * 15 := Nat.mul_le_mul_left 16 this + omega + have h_ge_range : 16 * (k.val + 1) ≤ 16 * j + ℓ := by + have : k.val + 1 ≤ j := hj_ge + have : 16 * (k.val + 1) ≤ 16 * j := Nat.mul_le_mul_left 16 this + omega + have h_acc1_eq_acc : acc1.val[16 * j + ℓ]! = acc.val[16 * j + ℓ]! := + h_acc1_out (16 * j + ℓ) h_n_lt_256 (Or.inr h_ge_range) + rw [h_acc1_eq_acc] + exact h_acc_undone j (by omega) hj_lt ℓ hℓ + · -- (c) Universal bound. + intro n hn + by_cases hcase : 16 * k.val ≤ n ∧ n < 16 * (k.val + 1) + · -- Inside the touched window. + obtain ⟨hge, hlt⟩ := hcase + have hn_decomp : n = 16 * k.val + (n - 16 * k.val) := by omega + have hn_off_lt : n - 16 * k.val < 16 := by omega + have h_acc1_n : acc1.val[n]! = s1.val[n - 16 * k.val]! := by + conv_lhs => rw [hn_decomp] + exact h_acc1_in (n - 16 * k.val) hn_off_lt + rw [h_acc1_n] + have h_bnd_at_off := h_s1_bnd_abs (n - 16 * k.val) hn_off_lt + have h_acc_init_n : acc_init.val[16 * k.val + (n - 16 * k.val)]! + = acc_init.val[n]! := by + congr 1; omega + rw [h_acc_init_n] at h_bnd_at_off + exact h_bnd_at_off + · -- Outside the touched window. + have h_outside : n < 16 * k.val ∨ 16 * (k.val + 1) ≤ n := by + by_contra hc + push Not at hc + exact hcase ⟨hc.1, hc.2⟩ + have h_acc1_eq_acc : acc1.val[n]! = acc.val[n]! := + h_acc1_out n hn h_outside + rw [h_acc1_eq_acc] + exact h_acc_bnd_rel n hn + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_loop.body + (vectortraitsOperationsInst := portable_ops_inst) myself rhs + { start := k, «end» := 16#usize } acc + = .ok (ControlFlow.done acc) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_fc h_body + show UseCacheFC.step_post myself rhs acc_init k (.done acc) + unfold UseCacheFC.step_post + show (UseCacheFC.inv myself rhs acc_init 16#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j : Nat, j < (16#usize : Std.Usize).val → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (myself.coefficients.val[j]!)) + (lift_chunk_mont (rhs.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → ∀ ℓ : Nat, ℓ < 16 → + acc.val[16 * j + ℓ]! = acc_init.val[16 * j + ℓ]!) + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + 2^25) := by + refine ⟨?_, ?_, ?_⟩ + · intro j hj ℓ hℓ; rw [h16] at hj + apply h_acc_done j _ ℓ hℓ; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt ℓ hℓ + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt ℓ hℓ; rw [hk_eq]; exact hj_ge + · intro n hn; exact h_acc_bnd_rel n hn + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 4000000 in +/-- L6.3 — `polynomial.PolynomialRingElement.accumulating_ntt_multiply`: + polynomial-level NTT-domain multiply. Wraps L2.8 across all 16 + vector chunks, computing the running sum into a 256-element I32 + accumulator (one degree-2 polynomial multiply per chunk). + + The impl iterates 16 times over + `vectortraitsOperationsInst.accumulating_ntt_multiply`, slicing + the accumulator as `[i*16 .. (i+1)*16]` per iteration and reading + zetas from `polynomial.zeta(64 + 4i + {0,1,2,3})`. + + POST defers algebraic shape to `accumulating_ntt_multiply_poly_post`. + Preconditions: input polys canonical (all coefficients + `natAbs ≤ 3328`), AND each accumulator lane bounded by `2^30` + (propagates to the L2.8 per-chunk PRE — each L2.8 call's `out` + slice is the corresponding 16-lane window into `accumulator`). + + POST adds a relative bound conjunct (`|r[n]| ≤ |accumulator[n]| + + 2^25`) propagating L2.8's relative bound through the 16-iter + chunk loop. Each of the 256 lanes is updated exactly once (one + binomial step per lane), so the per-lane delta is bounded by + a single L2.8 step's growth. Mirrors the inverse-NTT + bound-infra cascade. + + [F*-port: Libcrux_ml_kem.Polynomial.ntt_multiply (Polynomial.fst: + 853-915). WARNING: upstream `lemma_ntt_multiply_chunk_commutes` + (Chunk.fst:1311) is `assume val` — Lean must PROVE the + per-vector wrap (L6.3a sub-unit).] -/ +@[spec] +theorem accumulating_ntt_multiply_poly_fc + (myself rhs : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (accumulator : Std.Array Std.I32 256#usize) + (h_self : ∀ i : Fin 16, ∀ j : Fin 16, + ((myself.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_rhs : ∀ i : Fin 16, ∀ j : Fin 16, + ((rhs.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, (accumulator.val[n.val]!).val.natAbs ≤ 2^30) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply + (vectortraitsOperationsInst := portable_ops_inst) + myself rhs accumulator + ⦃ ⇓ r => ⌜ (∀ n : Fin 256, (r.val[n.val]!).val.natAbs + ≤ (accumulator.val[n.val]!).val.natAbs + 2^25) ∧ + accumulating_ntt_multiply_poly_post + myself rhs accumulator r ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply + have h_vre : libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + = .ok (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + rfl + rw [h_vre]; simp only [Aeneas.Std.bind_tc_ok] + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_loop.body + (vectortraitsOperationsInst := portable_ops_inst) myself rhs iter1 acc1) + (β := UseCacheFC.Acc) + accumulator + 0#usize 16#usize + (UseCacheFC.inv myself rhs accumulator) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_⟩ + · intro j hj; exact absurd hj (Nat.not_lt_zero j) + · intro _ _ _ _ _; trivial + · intro n _; omega) + ?_) + · -- Post entailment: at k = 16, derive the locked POST. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (UseCacheFC.inv myself rhs accumulator 16#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + (∀ j : Nat, j < (16#usize : Std.Usize).val → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (r.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (accumulator.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (myself.coefficients.val[j]!)) + (lift_chunk_mont (rhs.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → ∀ ℓ : Nat, ℓ < 16 → + r.val[16 * j + ℓ]! = accumulator.val[16 * j + ℓ]!) + ∧ (∀ n : Nat, n < 256 → + (r.val[n]!).val.natAbs ≤ (accumulator.val[n]!).val.natAbs + 2^25) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + UseCacheFC.inv] using h_inv_holds + obtain ⟨h_done, _h_undone, h_bnd⟩ := h_inv + refine ⟨?_, ?_⟩ + · intro n; exact h_bnd n.val n.isLt + · unfold accumulating_ntt_multiply_poly_post + intro j hj ℓ hℓ + have h16' : (16#usize : Std.Usize).val = 16 := rfl + exact h_done j (by rw [h16']; exact hj) ℓ hℓ + · -- Step entailment. + intro acc k _h_ge h_le hinv + have h_step := + accumulating_ntt_multiply_poly_step_lemma_fc myself rhs accumulator + h_self h_rhs h_acc_bnd acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : UseCacheFC.step_post myself rhs accumulator k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [UseCacheFC.step_post] using hP + · have hP : UseCacheFC.step_post myself rhs accumulator k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [UseCacheFC.step_post] using hP + +/-! ## §L6.3c — Cache-variant polynomial-level Triple statements. + + Polynomial wrappers around L2.8d (`accumulating_ntt_multiply_fill_cache_fc` + and `accumulating_ntt_multiply_use_cache_fc`). The impl loops over the + 16 chunks, dispatching each through the vector-level cache variant. + + Composition pattern (matrix-row reuse): `_fill_cache(A, B, _, _, cache)` sets + the polynomial cache (16 chunks × 8 cache slots each), then multiple + `_use_cache(A', B, _, _, cache)` calls reuse it with different first + operands and the same `B`. This is the matrix `Aᵀ · r` and `A · s` + pattern in L7.1 / L7.2 / L7.3. + + Cache POST predicate composes with the vector-level + `Spec.ntt_multiply_cache_post` chunk-by-chunk: each of the 16 chunks + of `cache.coefficients` stores per-pair `b·zeta` Mont-reduced products + for that chunk's effective zetas at table positions + `64 + 4j + {0,1,2,3}`. -/ + +/-- Polynomial-level cache POST predicate. For each chunk `j ∈ Fin 16` and + each pair `i ∈ Fin 8`: `cache.coefficients[j].elements[i]` stores the + Mont-reduced product `rhs.coefficients[j].elements[2i+1] · zeta_eff_i` + where the four base zetas for chunk `j` are + `Spec.zeta_at (64 + 4j + {0,1,2,3})`. Composes with the vector-level + `Spec.ntt_multiply_cache_post` per chunk. -/ +noncomputable def accumulating_ntt_multiply_poly_cache_post + (rhs cache : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Prop := + ∀ j : Fin 16, ∀ i : Fin 8, + ((cache.coefficients.val[j.val]!).elements.val[i.val]!).val.natAbs ≤ 3328 + ∧ lift_fe_mont ((cache.coefficients.val[j.val]!).elements.val[i.val]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe_mont ((rhs.coefficients.val[j.val]!).elements.val[2 * i.val + 1]!)) + (Spec.effective_zeta_fe i + (Spec.zeta_at (64 + 4 * j.val)) + (Spec.zeta_at (64 + 4 * j.val + 1)) + (Spec.zeta_at (64 + 4 * j.val + 2)) + (Spec.zeta_at (64 + 4 * j.val + 3))) + +namespace FillCacheFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +abbrev Acc := UseCacheFC.Acc +abbrev Poly := UseCacheFC.Poly + +/-- 5-conjunct invariant for the fill_cache loop. -/ +def inv (myself rhs : Poly) (acc_init : Acc) (cache_init : Poly) : + Std.Usize → Acc → Poly → Result Prop := + fun k acc cache => pure ( + (∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (myself.coefficients.val[j]!)) + (lift_chunk_mont (rhs.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → ∀ ℓ : Nat, ℓ < 16 → + acc.val[16 * j + ℓ]! = acc_init.val[16 * j + ℓ]!) + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + 2^25) + ∧ (∀ j : Nat, j < k.val → + Spec.ntt_multiply_cache_post + (rhs.coefficients.val[j]!) + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * j]! + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * j + 1]! + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * j + 2]! + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * j + 3]! + (cache.coefficients.val[j]!)) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + cache.coefficients.val[j]! = cache_init.coefficients.val[j]!)) + +/-- Step-post for `loop_range_spec_usize` over (acc, cache). -/ +def step_post (myself rhs : Poly) (acc_init : Acc) (cache_init : Poly) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc × Poly) (Acc × Poly)) : + Prop := + match r with + | .cont (iter', acc', cache') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv myself rhs acc_init cache_init iter'.start acc' cache').holds + | .done y => (inv myself rhs acc_init cache_init 16#usize y.1 y.2).holds + +end FillCacheFC + +-- Memory hygiene (rule 1 / SKILL §5.7 Idiom 2). Heavy POST predicates and +-- the namespace's `inv` / `step_post` are made locally irreducible across +-- the step lemma + outer Triple so that elaboration of +-- `apply triple_of_ok_fc h_body` (step) and `apply Std.Do.Triple.of_entails_right` +-- (outer) does not whnf-explode through the 5-conjunct invariant body or +-- the nested `∀ i : Fin 8` cache POST. +section L6_3c_fill_irreducible +attribute [local irreducible] Spec.ntt_multiply_cache_post +attribute [local irreducible] accumulating_ntt_multiply_poly_cache_post +attribute [local irreducible] ntt_multiply_base_case_post +attribute [local irreducible] Spec.chunk_reducing_from_i32_array_pure +attribute [local irreducible] lift_chunk_mont +attribute [local irreducible] Spec.ntt_multiply_pure_no_acc +attribute [local irreducible] ntt_multiply_base_case_alg +attribute [local irreducible] Spec.effective_zeta_fe + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for `_fill_cache` polynomial loop. Mirrors + `accumulating_ntt_multiply_poly_step_lemma_fc` (L6.3 base,) + but threads BOTH `acc` and `cache` through the ControlFlow. -/ +theorem accumulating_ntt_multiply_fill_cache_poly_step_lemma_fc + (myself rhs : FillCacheFC.Poly) (acc_init : FillCacheFC.Acc) + (cache_init : FillCacheFC.Poly) + (h_self : ∀ i : Fin 16, ∀ j : Fin 16, + ((myself.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_rhs : ∀ i : Fin 16, ∀ j : Fin 16, + ((rhs.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, (acc_init.val[n.val]!).val.natAbs ≤ 2^30) + (acc : FillCacheFC.Acc) (cache : FillCacheFC.Poly) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (FillCacheFC.inv myself rhs acc_init cache_init k acc cache).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache_loop.body + (vectortraitsOperationsInst := portable_ops_inst) myself rhs + { start := k, «end» := 16#usize } acc cache + ⦃ ⇓ r => ⌜ FillCacheFC.step_post myself rhs acc_init cache_init k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_acc_len : acc.val.length = 256 := Std.Array.length_eq acc + have h_acc_init_len : acc_init.val.length = 256 := Std.Array.length_eq acc_init + have h_self_coef_len : myself.coefficients.length = 16 := Std.Array.length_eq _ + have h_rhs_coef_len : rhs.coefficients.length = 16 := Std.Array.length_eq _ + have h_cache_coef_len : cache.coefficients.length = 16 := Std.Array.length_eq _ + have h_cache_init_coef_len : cache_init.coefficients.length = 16 := Std.Array.length_eq _ + obtain ⟨h_acc_done, h_acc_undone, h_acc_bnd_rel, h_cache_done, h_cache_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some i = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s_iter, hs_val_eq, h_iter_some⟩ := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + -- (1) t := self.coefficients[k] and t1 := rhs.coefficients[k]. + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + myself.coefficients.val[k.val]! with ht_def + set t1 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + rhs.coefficients.val[k.val]! with ht1_def + have h_idx_t : Aeneas.Std.Array.index_usize myself.coefficients k = .ok t := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq myself.coefficients k + (by rw [h_self_coef_len]; exact hk_16) + have h_idx_t1 : Aeneas.Std.Array.index_usize rhs.coefficients k = .ok t1 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq rhs.coefficients k + (by rw [h_rhs_coef_len]; exact hk_16) + -- (2) i1 := k * 16, i2 := k + 1, i3 := i2 * 16. + have hi1_max : k.val * (16#usize : Std.Usize).val ≤ Std.Usize.max := by + have hk_15 : k.val ≤ 15 := by omega + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hum] + have h1 : k.val * 16 ≤ 15 * 16 := Nat.mul_le_mul_right 16 hk_15 + have : (15 * 16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i1, hi1_eq, hi1_val⟩ := usize_mul_ok_eq_fc k 16#usize hi1_max + have hi1_val_eq : i1.val = 16 * k.val := by + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hi1_val, hum]; omega + have hi2_max : k.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + have hk_15 : k.val ≤ 15 := by omega + have hum : (1#usize : Std.Usize).val = 1 := rfl + rw [hum] + have : (16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i2, hi2_eq, hi2_val⟩ := usize_add_ok_eq_fc k 1#usize hi2_max + have hi2_val_eq : i2.val = k.val + 1 := by + have hum : (1#usize : Std.Usize).val = 1 := rfl + rw [hi2_val, hum] + have hi3_max : i2.val * (16#usize : Std.Usize).val ≤ Std.Usize.max := by + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hum, hi2_val_eq] + have : k.val + 1 ≤ 16 := by omega + have h1 : (k.val + 1) * 16 ≤ 16 * 16 := Nat.mul_le_mul_right 16 this + have : (16 * 16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i3, hi3_eq, hi3_val⟩ := usize_mul_ok_eq_fc i2 16#usize hi3_max + have hi3_val_eq : i3.val = 16 * (k.val + 1) := by + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hi3_val, hi2_val_eq, hum]; omega + -- (3) Sub-slice via Array index_mut RangeUsize. + have h0_le : i1.val ≤ i3.val := by rw [hi1_val_eq, hi3_val_eq]; omega + have hi3_le : i3.val ≤ acc.val.length := by + rw [h_acc_len, hi3_val_eq] + have : k.val + 1 ≤ 16 := by omega + have h1 : 16 * (k.val + 1) ≤ 16 * 16 := Nat.mul_le_mul_left _ this + omega + obtain ⟨s, back, h_imt_eq, h_s_val_eq, h_s_len_eq, h_back_eq⟩ := + array_index_mut_range_ok_eq_fc acc + ({ start := i1, «end» := i3 } : CoreModels.core.ops.range.Range Std.Usize) + h0_le hi3_le + have h_s_len16 : s.length = 16 := by + show s.val.length = 16 + rw [h_s_len_eq] + show i3.val - i1.val = 16 + rw [hi3_val_eq, hi1_val_eq]; omega + have h_s_lane : ∀ ℓ : Nat, ℓ < 16 → + s.val[ℓ]! = acc.val[16 * k.val + ℓ]! := by + intro ℓ hℓ + rw [h_s_val_eq] + have h_idx_lt : i1.val + ℓ < i3.val := by + rw [hi1_val_eq, hi3_val_eq]; omega + have h_bnd : i3.val ≤ acc.val.length ∧ i1.val + ℓ < i3.val := ⟨hi3_le, h_idx_lt⟩ + rw [List.getElem!_slice i1.val i3.val ℓ acc.val h_bnd] + rw [hi1_val_eq] + have h_s_lane_init : ∀ ℓ : Nat, ℓ < 16 → + s.val[ℓ]! = acc_init.val[16 * k.val + ℓ]! := by + intro ℓ hℓ + rw [h_s_lane ℓ hℓ] + exact h_acc_undone k.val (Nat.le_refl _) hk_16 ℓ hℓ + -- (4) Per-lane bound on s (≤ 2^30 from h_acc_bnd). + have h_s_bnd : ∀ k' : Fin 16, (s.val[k'.val]!).val.natAbs ≤ 2^30 := by + intro k' + rw [h_s_lane_init k'.val k'.isLt] + have h_lt : 16 * k.val + k'.val < 256 := by + have : k.val ≤ 15 := by omega + have hk' : k'.val < 16 := k'.isLt + have : 16 * k.val ≤ 16 * 15 := Nat.mul_le_mul_left 16 this + omega + exact h_acc_bnd ⟨16 * k.val + k'.val, h_lt⟩ + -- (3') Cache-chunk extract via `Array.index_mut_usize cache.coefficients k`. + set t2 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + cache.coefficients.val[k.val]! with ht2_def + have h_imt_cache : Aeneas.Std.Array.index_mut_usize cache.coefficients k + = .ok (t2, cache.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + have h_idx : Aeneas.Std.Array.index_usize cache.coefficients k + = .ok (cache.coefficients.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq cache.coefficients k + (by rw [h_cache_coef_len]; exact hk_16) + rw [h_idx]; rfl + -- (5) Read 4 zetas via polynomial.zeta_fc. + have hi4_max : (4#usize : Std.Usize).val * k.val ≤ Std.Usize.max := by + have hk_15 : k.val ≤ 15 := by omega + have hum : (4#usize : Std.Usize).val = 4 := rfl + rw [hum] + have : 4 * k.val ≤ 4 * 15 := Nat.mul_le_mul_left 4 hk_15 + have : (4 * 15 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i4, hi4_eq, hi4_val⟩ := usize_mul_ok_eq_fc 4#usize k hi4_max + have hi4_val_eq : i4.val = 4 * k.val := by + have hum : (4#usize : Std.Usize).val = 4 := rfl + rw [hi4_val, hum] + have hi5_max : (64#usize : Std.Usize).val + i4.val ≤ Std.Usize.max := by + have hum : (64#usize : Std.Usize).val = 64 := rfl + rw [hum, hi4_val_eq] + have hk_15 : k.val ≤ 15 := by omega + have : 4 * k.val ≤ 4 * 15 := Nat.mul_le_mul_left 4 hk_15 + have : (64 + 4 * 15 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i5, hi5_eq, hi5_val⟩ := usize_add_ok_eq_fc 64#usize i4 hi5_max + have hi5_val_eq : i5.val = 64 + 4 * k.val := by + have hum : (64#usize : Std.Usize).val = 64 := rfl + rw [hi5_val, hum, hi4_val_eq] + have hi5_lt_128 : i5.val < 128 := by rw [hi5_val_eq]; omega + obtain ⟨z0, hz0_eq, hz0_post⟩ := + triple_exists_ok_fc (polynomial.zeta_fc i5 hi5_lt_128) + obtain ⟨hz0_val_eq, hz0_bnd, hz0_lift⟩ := hz0_post + have hi8_max : i5.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + have hum : (1#usize : Std.Usize).val = 1 := rfl + rw [hum, hi5_val_eq] + have hk_15 : k.val ≤ 15 := by omega + have : (64 + 4 * 15 + 1 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i8, hi8_eq, hi8_val⟩ := usize_add_ok_eq_fc i5 1#usize hi8_max + have hi8_val_eq : i8.val = 64 + 4 * k.val + 1 := by + have hum : (1#usize : Std.Usize).val = 1 := rfl + rw [hi8_val, hi5_val_eq, hum] + have hi8_lt_128 : i8.val < 128 := by rw [hi8_val_eq]; omega + obtain ⟨z1, hz1_eq, hz1_post⟩ := + triple_exists_ok_fc (polynomial.zeta_fc i8 hi8_lt_128) + obtain ⟨hz1_val_eq, hz1_bnd, hz1_lift⟩ := hz1_post + have hi11_max : i5.val + (2#usize : Std.Usize).val ≤ Std.Usize.max := by + have hum : (2#usize : Std.Usize).val = 2 := rfl + rw [hum, hi5_val_eq] + have hk_15 : k.val ≤ 15 := by omega + have : (64 + 4 * 15 + 2 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i11, hi11_eq, hi11_val⟩ := usize_add_ok_eq_fc i5 2#usize hi11_max + have hi11_val_eq : i11.val = 64 + 4 * k.val + 2 := by + have hum : (2#usize : Std.Usize).val = 2 := rfl + rw [hi11_val, hi5_val_eq, hum] + have hi11_lt_128 : i11.val < 128 := by rw [hi11_val_eq]; omega + obtain ⟨z2, hz2_eq, hz2_post⟩ := + triple_exists_ok_fc (polynomial.zeta_fc i11 hi11_lt_128) + obtain ⟨hz2_val_eq, hz2_bnd, hz2_lift⟩ := hz2_post + have hi14_max : i5.val + (3#usize : Std.Usize).val ≤ Std.Usize.max := by + have hum : (3#usize : Std.Usize).val = 3 := rfl + rw [hum, hi5_val_eq] + have hk_15 : k.val ≤ 15 := by omega + have : (64 + 4 * 15 + 3 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i14, hi14_eq, hi14_val⟩ := usize_add_ok_eq_fc i5 3#usize hi14_max + have hi14_val_eq : i14.val = 64 + 4 * k.val + 3 := by + have hum : (3#usize : Std.Usize).val = 3 := rfl + rw [hi14_val, hi5_val_eq, hum] + have hi14_lt_128 : i14.val < 128 := by rw [hi14_val_eq]; omega + obtain ⟨z3, hz3_eq, hz3_post⟩ := + triple_exists_ok_fc (polynomial.zeta_fc i14 hi14_lt_128) + obtain ⟨hz3_val_eq, hz3_bnd, hz3_lift⟩ := hz3_post + -- (6) Apply L2.8d to get (s1, cache_chunk1). + have h_t_lhs : ∀ j : Fin 16, (t.elements.val[j.val]!).val.natAbs ≤ 3328 := by + intro j; exact h_self ⟨k.val, hk_16⟩ j + have h_t1_rhs : ∀ j : Fin 16, (t1.elements.val[j.val]!).val.natAbs ≤ 3328 := by + intro j; exact h_rhs ⟨k.val, hk_16⟩ j + obtain ⟨p_pair, h_p_eq, h_s1_len, h_s1_bnd, h_s1_post, h_cache_chunk_post, h_cache_chunk_unc⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_fill_cache_fc t t1 s t2 z0 z1 z2 z3 h_s_len16 + h_t_lhs h_t1_rhs hz0_bnd hz1_bnd hz2_bnd hz3_bnd h_s_bnd) + set s1 : Aeneas.Std.Slice Std.I32 := p_pair.1 with hs1_def + set cache_chunk1 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + p_pair.2 with hcc1_def + -- s1's bound vs s lanes. + have h_s1_bnd_abs : ∀ k' : Nat, k' < 16 → + (s1.val[k']!).val.natAbs ≤ (acc_init.val[16 * k.val + k']!).val.natAbs + 2^25 := by + intro k' hk' + have h_step_bnd := h_s1_bnd ⟨k', hk'⟩ + simp only at h_step_bnd + rw [h_s_lane_init k' hk'] at h_step_bnd + exact h_step_bnd + -- (7) Compose acc1 := back s1. + set acc1 : FillCacheFC.Acc := back s1 with hacc1_def + have h_acc1_val : acc1.val = acc.val.setSlice! i1.val s1.val := + h_back_eq s1 (by show s1.val.length = i3.val - i1.val; rw [← h_s_len_eq]; + show s1.length = s.length; rw [h_s1_len, h_s_len16]) + have h_acc1_len : acc1.val.length = 256 := by + rw [h_acc1_val, List.length_setSlice!, h_acc_len] + have h_acc1_in : ∀ ℓ : Nat, ℓ < 16 → + acc1.val[16 * k.val + ℓ]! = s1.val[ℓ]! := by + intro ℓ hℓ + rw [h_acc1_val] + have h_get : (acc.val.setSlice! i1.val s1.val)[16 * k.val + ℓ]! + = s1.val[(16 * k.val + ℓ) - i1.val]! := by + apply List.getElem!_setSlice!_middle + refine ⟨?_, ?_, ?_⟩ + · rw [hi1_val_eq]; omega + · rw [hi1_val_eq] + have h_sub' : 16 * k.val + ℓ - 16 * k.val = ℓ := by omega + rw [h_sub'] + show ℓ < s1.val.length + have h_s1' : s1.val.length = 16 := h_s1_len + rw [h_s1']; exact hℓ + · rw [h_acc_len] + have hk_15' : k.val ≤ 15 := by omega + have h1 : 16 * k.val ≤ 16 * 15 := Nat.mul_le_mul_left 16 hk_15' + omega + rw [h_get] + have h_sub : (16 * k.val + ℓ) - i1.val = ℓ := by + rw [hi1_val_eq]; omega + rw [h_sub] + have h_acc1_out : ∀ n : Nat, n < 256 → + (n < 16 * k.val ∨ 16 * (k.val + 1) ≤ n) → + acc1.val[n]! = acc.val[n]! := by + intro n hn hcases + rw [h_acc1_val] + rcases hcases with hlt | hge + · apply List.getElem!_setSlice!_prefix + rw [hi1_val_eq]; exact hlt + · apply List.getElem!_setSlice!_suffix + rw [hi1_val_eq] + have h_s1' : s1.val.length = 16 := h_s1_len + rw [h_s1'] + have h_eq16 : 16 * k.val + 16 = 16 * (k.val + 1) := by ring + rw [h_eq16]; exact hge + -- (7') Compose cache1 := { coefficients := cache.coefficients.set k cache_chunk1 }. + set cache1 : FillCacheFC.Poly := + { coefficients := cache.coefficients.set k cache_chunk1 } with hcache1_def + have h_cache1_at : cache1.coefficients.val[k.val]! = cache_chunk1 := by + show (cache.coefficients.set k cache_chunk1).val[k.val]! = cache_chunk1 + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq cache.coefficients k k.val cache_chunk1 + ⟨rfl, by rw [h_cache_coef_len]; exact hk_16⟩ + have h_cache1_ne : ∀ j : Nat, j ≠ k.val → + cache1.coefficients.val[j]! = cache.coefficients.val[j]! := by + intro j hj + show (cache.coefficients.set k cache_chunk1).val[j]! = cache.coefficients.val[j]! + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne cache.coefficients k j cache_chunk1 + (fun h => hj h.symm) + -- (9) Body equation. + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache_loop.body + (vectortraitsOperationsInst := portable_ops_inst) myself rhs + { start := k, «end» := 16#usize } acc cache + = .ok (ControlFlow.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc1, cache1)) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let t' ← Aeneas.Std.Array.index_usize myself.coefficients k + let t1' ← Aeneas.Std.Array.index_usize rhs.coefficients k + let i1' ← (k * 16#usize : Result Std.Usize) + let i2' ← k + 1#usize + let i3' ← i2' * 16#usize + let (s', index_mut_back) ← + core.Array.Insts.CoreOpsIndexIndexMut.index_mut + (core.Slice.Insts.CoreOpsIndexIndexMut + (core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice + Std.I32)) acc { start := i1', «end» := i3' } + let (t2', index_mut_back1) ← Aeneas.Std.Array.index_mut_usize cache.coefficients k + let i4' ← 4#usize * k + let i5' ← 64#usize + i4' + let i6' ← libcrux_iot_ml_kem.polynomial.zeta i5' + let i7' ← 64#usize + i4' + let i8' ← i7' + 1#usize + let i9' ← libcrux_iot_ml_kem.polynomial.zeta i8' + let i10' ← 64#usize + i4' + let i11' ← i10' + 2#usize + let i12' ← libcrux_iot_ml_kem.polynomial.zeta i11' + let i13' ← 64#usize + i4' + let i14' ← i13' + 3#usize + let i15' ← libcrux_iot_ml_kem.polynomial.zeta i14' + let p ← + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations.accumulating_ntt_multiply_fill_cache + t' t1' s' t2' i6' i9' i12' i15' + .ok (ControlFlow.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), index_mut_back p.1, + ({ coefficients := index_mut_back1 p.2 } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + : Result _) = _ + rw [h_idx_t]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_t1]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi3_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_cache]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi4_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi5_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hz0_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi8_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hz1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi11_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hz2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi14_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hz3_eq]; simp only [Aeneas.Std.bind_tc_ok] + show ((do + let p ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_fill_cache + t t1 s t2 z0 z1 z2 z3 + .ok (ControlFlow.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), back p.1, + ({ coefficients := cache.coefficients.set k p.2 } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + : Result _) = _ + rw [h_p_eq] + rfl + apply triple_of_ok_fc h_body + show FillCacheFC.step_post myself rhs acc_init cache_init k + (.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc1, cache1)) + unfold FillCacheFC.step_post + refine ⟨h_lt, rfl, hs_val_eq, ?_⟩ + show (FillCacheFC.inv myself rhs acc_init cache_init s_iter acc1 cache1).holds + unfold FillCacheFC.inv + -- Build the five invariant conjuncts. + have h_inv_pure : + (∀ j : Nat, j < s_iter.val → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (myself.coefficients.val[j]!)) + (lift_chunk_mont (rhs.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + ∧ (∀ j : Nat, s_iter.val ≤ j → j < 16 → ∀ ℓ : Nat, ℓ < 16 → + acc1.val[16 * j + ℓ]! = acc_init.val[16 * j + ℓ]!) + ∧ (∀ n : Nat, n < 256 → + (acc1.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + 2^25) + ∧ (∀ j : Nat, j < s_iter.val → + Spec.ntt_multiply_cache_post + (rhs.coefficients.val[j]!) + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * j]! + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * j + 1]! + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * j + 2]! + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * j + 3]! + (cache1.coefficients.val[j]!)) + ∧ (∀ j : Nat, s_iter.val ≤ j → j < 16 → + cache1.coefficients.val[j]! = cache_init.coefficients.val[j]!) := by + refine ⟨?_, ?_, ?_, ?_, ?_⟩ + · -- (a) Touched-chunk FC. + intro j hj ℓ hℓ + rw [hs_val_eq] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · have h_in_range : 16 * j + ℓ < 16 * k.val := by + have h1 : 16 * j + 16 ≤ 16 * k.val := by + have : j + 1 ≤ k.val := by omega + have : 16 * (j + 1) ≤ 16 * k.val := Nat.mul_le_mul_left 16 this + omega + omega + have h_lt_256 : 16 * j + ℓ < 256 := by + have : k.val ≤ 15 := by omega + have : 16 * k.val ≤ 16 * 15 := Nat.mul_le_mul_left 16 this + omega + have h_acc1_eq_acc : acc1.val[16 * j + ℓ]! = acc.val[16 * j + ℓ]! := + h_acc1_out (16 * j + ℓ) h_lt_256 (Or.inl h_in_range) + rw [h_acc1_eq_acc] + exact h_acc_done j hj_lt_k ℓ hℓ + · subst hj_eq_k + rw [h_acc1_in ℓ hℓ] + unfold ntt_multiply_base_case_post at h_s1_post + have h_lhs_val_eq : + (Spec.chunk_reducing_from_i32_array_pure s1).val[ℓ]! + = Spec.mont_reduce_pure (lift_fe_int (s1.val[ℓ]!).val) := + Spec.chunk_reducing_from_i32_array_pure_lane_eq s1 ℓ hℓ + have h_s_chunk_val : + (Spec.chunk_reducing_from_i32_array_pure s).val[ℓ]! + = Spec.mont_reduce_pure (lift_fe_int (s.val[ℓ]!).val) := + Spec.chunk_reducing_from_i32_array_pure_lane_eq s ℓ hℓ + have h_rhs_val_eq : + (ntt_multiply_base_case_alg + (lift_chunk_mont t) (lift_chunk_mont t1) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3) + (Spec.chunk_reducing_from_i32_array_pure s)).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Spec.chunk_reducing_from_i32_array_pure s).val[ℓ]!) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont t) (lift_chunk_mont t1) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3)).val[ℓ]!) := by + unfold ntt_multiply_base_case_alg + exact Spec.chunk_add_pure_lane_eq _ _ ℓ hℓ + have h_chunk_eq : + Spec.mont_reduce_pure (lift_fe_int (s1.val[ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (s.val[ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont t) (lift_chunk_mont t1) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3)).val[ℓ]!) := by + have h_eq : (Spec.chunk_reducing_from_i32_array_pure s1).val[ℓ]! + = (ntt_multiply_base_case_alg + (lift_chunk_mont t) (lift_chunk_mont t1) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3) + (Spec.chunk_reducing_from_i32_array_pure s)).val[ℓ]! := by + rw [h_s1_post] + rw [h_lhs_val_eq] at h_eq + rw [h_rhs_val_eq] at h_eq + rw [h_s_chunk_val] at h_eq + exact h_eq + rw [h_chunk_eq] + rw [h_s_lane_init ℓ hℓ] + rw [hz0_lift, hz1_lift, hz2_lift, hz3_lift] + rw [hi5_val_eq, hi8_val_eq, hi11_val_eq, hi14_val_eq] + · -- (b) Untouched acc chunks. + intro j hj_ge hj_lt ℓ hℓ + rw [hs_val_eq] at hj_ge + have h_n_lt_256 : 16 * j + ℓ < 256 := by + have : j ≤ 15 := by omega + have : 16 * j ≤ 16 * 15 := Nat.mul_le_mul_left 16 this + omega + have h_ge_range : 16 * (k.val + 1) ≤ 16 * j + ℓ := by + have : k.val + 1 ≤ j := hj_ge + have : 16 * (k.val + 1) ≤ 16 * j := Nat.mul_le_mul_left 16 this + omega + have h_acc1_eq_acc : acc1.val[16 * j + ℓ]! = acc.val[16 * j + ℓ]! := + h_acc1_out (16 * j + ℓ) h_n_lt_256 (Or.inr h_ge_range) + rw [h_acc1_eq_acc] + exact h_acc_undone j (by omega) hj_lt ℓ hℓ + · -- (c) Universal acc bound. + intro n hn + by_cases hcase : 16 * k.val ≤ n ∧ n < 16 * (k.val + 1) + · obtain ⟨hge, hlt⟩ := hcase + have hn_decomp : n = 16 * k.val + (n - 16 * k.val) := by omega + have hn_off_lt : n - 16 * k.val < 16 := by omega + have h_acc1_n : acc1.val[n]! = s1.val[n - 16 * k.val]! := by + conv_lhs => rw [hn_decomp] + exact h_acc1_in (n - 16 * k.val) hn_off_lt + rw [h_acc1_n] + have h_bnd_at_off := h_s1_bnd_abs (n - 16 * k.val) hn_off_lt + have h_acc_init_n : acc_init.val[16 * k.val + (n - 16 * k.val)]! + = acc_init.val[n]! := by + congr 1; omega + rw [h_acc_init_n] at h_bnd_at_off + exact h_bnd_at_off + · have h_outside : n < 16 * k.val ∨ 16 * (k.val + 1) ≤ n := by + by_contra hc + push Not at hc + exact hcase ⟨hc.1, hc.2⟩ + have h_acc1_eq_acc : acc1.val[n]! = acc.val[n]! := + h_acc1_out n hn h_outside + rw [h_acc1_eq_acc] + exact h_acc_bnd_rel n hn + · -- (d) Cache touched chunks. + intro j hj + rw [hs_val_eq] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: cache1[j] = cache[j], use h_cache_done. + have hj_ne : j ≠ k.val := by omega + rw [h_cache1_ne j hj_ne] + exact h_cache_done j hj_lt_k + · -- j = k.val: cache1[k.val] = cache_chunk1, use h_cache_chunk_post. + subst hj_eq_k + rw [h_cache1_at] + -- h_cache_chunk_post : Spec.ntt_multiply_cache_post t1 z0 z1 z2 z3 cache_chunk1. + -- z0 = ZETAS_TIMES_MONTGOMERY_R[i5.val]! and i5.val = 64 + 4*k.val, etc. + have hz0_id : z0 = libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * k.val]! := by + rw [hz0_val_eq, hi5_val_eq] + have hz1_id : z1 = libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * k.val + 1]! := by + rw [hz1_val_eq, hi8_val_eq] + have hz2_id : z2 = libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * k.val + 2]! := by + rw [hz2_val_eq, hi11_val_eq] + have hz3_id : z3 = libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * k.val + 3]! := by + rw [hz3_val_eq, hi14_val_eq] + rw [← hz0_id, ← hz1_id, ← hz2_id, ← hz3_id] + exact h_cache_chunk_post + · -- (e) Cache untouched. + intro j hj_ge hj_lt + rw [hs_val_eq] at hj_ge + have hj_ne : j ≠ k.val := by omega + rw [h_cache1_ne j hj_ne] + exact h_cache_undone j (by omega) hj_lt + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache_loop.body + (vectortraitsOperationsInst := portable_ops_inst) myself rhs + { start := k, «end» := 16#usize } acc cache + = .ok (ControlFlow.done (acc, cache)) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_fc h_body + show FillCacheFC.step_post myself rhs acc_init cache_init k (.done (acc, cache)) + unfold FillCacheFC.step_post + show (FillCacheFC.inv myself rhs acc_init cache_init 16#usize acc cache).holds + unfold FillCacheFC.inv + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j : Nat, j < (16#usize : Std.Usize).val → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (myself.coefficients.val[j]!)) + (lift_chunk_mont (rhs.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → ∀ ℓ : Nat, ℓ < 16 → + acc.val[16 * j + ℓ]! = acc_init.val[16 * j + ℓ]!) + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + 2^25) + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → + Spec.ntt_multiply_cache_post + (rhs.coefficients.val[j]!) + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * j]! + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * j + 1]! + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * j + 2]! + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * j + 3]! + (cache.coefficients.val[j]!)) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + cache.coefficients.val[j]! = cache_init.coefficients.val[j]!) := by + refine ⟨?_, ?_, ?_, ?_, ?_⟩ + · intro j hj ℓ hℓ; rw [h16] at hj + apply h_acc_done j _ ℓ hℓ; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt ℓ hℓ + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt ℓ hℓ; rw [hk_eq]; exact hj_ge + · intro n hn; exact h_acc_bnd_rel n hn + · intro j hj; rw [h16] at hj + apply h_cache_done j _; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt + rw [h16] at hj_ge + apply h_cache_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +/-- L6.3c — `polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache`: + polynomial wrapper of `accumulating_ntt_multiply_fill_cache_fc`. Loops + over the 16 chunks; per chunk j it dispatches the L2.8d + `_fill_cache` variant with chunk `j`'s zetas + (`polynomial.zeta (64+4j+{0,1,2,3})`) and the chunk's mutable cache slot + (`cache.coefficients[j]`). + + POST shape mirrors L6.3 (length + relative accumulator bound + + `accumulating_ntt_multiply_poly_post`) AND adds a polynomial-level + cache POST (`accumulating_ntt_multiply_poly_cache_post`) asserting + that each cache chunk stores the per-pair Mont-reduced + `b·zeta` products for its effective zetas. -/ +@[spec] +theorem accumulating_ntt_multiply_fill_cache_poly_fc + (myself rhs cache : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (accumulator : Std.Array Std.I32 256#usize) + (h_self : ∀ i : Fin 16, ∀ j : Fin 16, + ((myself.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_rhs : ∀ i : Fin 16, ∀ j : Fin 16, + ((rhs.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, (accumulator.val[n.val]!).val.natAbs ≤ 2^30) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache + (vectortraitsOperationsInst := portable_ops_inst) + myself rhs accumulator cache + ⦃ ⇓ p => ⌜ (∀ n : Fin 256, (p.1.val[n.val]!).val.natAbs + ≤ (accumulator.val[n.val]!).val.natAbs + 2^25) ∧ + accumulating_ntt_multiply_poly_post + myself rhs accumulator p.1 ∧ + accumulating_ntt_multiply_poly_cache_post rhs p.2 ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache + have h_vre : libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + = .ok (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + rfl + rw [h_vre]; simp only [Aeneas.Std.bind_tc_ok] + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, p) => + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_fill_cache_loop.body + (vectortraitsOperationsInst := portable_ops_inst) myself rhs iter1 p.1 p.2) + (β := FillCacheFC.Acc × FillCacheFC.Poly) + (accumulator, cache) + 0#usize 16#usize + (fun k p => FillCacheFC.inv myself rhs accumulator cache k p.1 p.2) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_, ?_, ?_⟩ + · intro j hj; exact absurd hj (Nat.not_lt_zero j) + · intro _ _ _ _ _; trivial + · intro n _; omega + · intro j hj; exact absurd hj (Nat.not_lt_zero j) + · intro _ _ _; trivial) + ?_) + · -- Post entailment at k = 16: derive the locked POST. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (FillCacheFC.inv myself rhs accumulator cache 16#usize r.1 r.2).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + obtain ⟨h_done, _h_undone, h_bnd, h_cache_done, _h_cache_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_holds + refine ⟨?_, ?_, ?_⟩ + · intro n; exact h_bnd n.val n.isLt + · unfold accumulating_ntt_multiply_poly_post + intro j hj ℓ hℓ + have h16' : (16#usize : Std.Usize).val = 16 := rfl + exact h_done j (by rw [h16']; exact hj) ℓ hℓ + · -- accumulating_ntt_multiply_poly_cache_post: bridge chunk-level cache POST to poly-level. + show accumulating_ntt_multiply_poly_cache_post rhs r.2 + unfold accumulating_ntt_multiply_poly_cache_post + intro j_fin i_fin + have h16' : (16#usize : Std.Usize).val = 16 := rfl + have h_chunk := h_cache_done j_fin.val (by rw [h16']; exact j_fin.isLt) + -- h_chunk : Spec.ntt_multiply_cache_post (rhs.coefficients[j]!) ZETAS[64+4j+0..3]! (r.2.coefficients[j]!) + unfold Spec.ntt_multiply_cache_post at h_chunk + exact h_chunk i_fin + · -- Step entailment. + intro p k _h_ge h_le hinv + have h_step := accumulating_ntt_multiply_fill_cache_poly_step_lemma_fc + myself rhs accumulator cache h_self h_rhs h_acc_bnd p.1 p.2 k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc_cache⟩ | y + · have hP : FillCacheFC.step_post myself rhs accumulator cache k + (.cont (iter', acc_cache.1, acc_cache.2)) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [FillCacheFC.step_post] using hP + · have hP : FillCacheFC.step_post myself rhs accumulator cache k (.done (y.1, y.2)) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [FillCacheFC.step_post] using hP + +end L6_3c_fill_irreducible + +-- Memory hygiene (rule 1 / SKILL §5.7 Idiom 2). Same irreducible-attribute +-- discipline as the L6.3c.fill section, applied to the read-only `_use_cache` +-- step lemma + outer Triple. We REUSE `UseCacheFC.inv` / `UseCacheFC.step_post` +-- (cache is closure-captured, not threaded), so neither is added to the +-- irreducible list (this would break the `simpa` destructure of `h_inv`). +section L6_3c_use_irreducible +attribute [local irreducible] Spec.ntt_multiply_cache_post +attribute [local irreducible] accumulating_ntt_multiply_poly_cache_post +attribute [local irreducible] ntt_multiply_base_case_post +attribute [local irreducible] Spec.chunk_reducing_from_i32_array_pure +attribute [local irreducible] lift_chunk_mont +attribute [local irreducible] Spec.ntt_multiply_pure_no_acc +attribute [local irreducible] ntt_multiply_base_case_alg +attribute [local irreducible] Spec.effective_zeta_fe + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for `_use_cache` polynomial loop. Mirrors + `accumulating_ntt_multiply_poly_step_lemma_fc` (L6.3 base) but accepts + a read-only `cache` parameter (closure-captured in the body) together + with the polynomial-level cache PRE `h_cache`. Cache is unchanged, so + the carrier is `UseCacheFC.Acc` and the invariant is the L6.3 base 3-tuple. -/ +theorem accumulating_ntt_multiply_use_cache_poly_step_lemma_fc + (myself rhs : UseCacheFC.Poly) (cache : UseCacheFC.Poly) + (acc_init : UseCacheFC.Acc) + (h_self : ∀ i : Fin 16, ∀ j : Fin 16, + ((myself.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_rhs : ∀ i : Fin 16, ∀ j : Fin 16, + ((rhs.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, (acc_init.val[n.val]!).val.natAbs ≤ 2^30) + (h_cache : accumulating_ntt_multiply_poly_cache_post rhs cache) + (acc : UseCacheFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (UseCacheFC.inv myself rhs acc_init k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_use_cache_loop.body + (vectortraitsOperationsInst := portable_ops_inst) myself rhs cache + { start := k, «end» := 16#usize } acc + ⦃ ⇓ r => ⌜ UseCacheFC.step_post myself rhs acc_init k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_acc_len : acc.val.length = 256 := Std.Array.length_eq acc + have h_acc_init_len : acc_init.val.length = 256 := Std.Array.length_eq acc_init + have h_self_coef_len : myself.coefficients.length = 16 := Std.Array.length_eq _ + have h_rhs_coef_len : rhs.coefficients.length = 16 := Std.Array.length_eq _ + have h_cache_coef_len : cache.coefficients.length = 16 := Std.Array.length_eq _ + obtain ⟨h_acc_done, h_acc_undone, h_acc_bnd_rel⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_use_cache_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some i = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s_iter, hs_val_eq, h_iter_some⟩ := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + -- (1) t := self.coefficients[k] and t1 := rhs.coefficients[k]. + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + myself.coefficients.val[k.val]! with ht_def + set t1 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + rhs.coefficients.val[k.val]! with ht1_def + have h_idx_t : Aeneas.Std.Array.index_usize myself.coefficients k = .ok t := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq myself.coefficients k + (by rw [h_self_coef_len]; exact hk_16) + have h_idx_t1 : Aeneas.Std.Array.index_usize rhs.coefficients k = .ok t1 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq rhs.coefficients k + (by rw [h_rhs_coef_len]; exact hk_16) + -- (2) i1 := k * 16, i2 := k + 1, i3 := i2 * 16. + have hi1_max : k.val * (16#usize : Std.Usize).val ≤ Std.Usize.max := by + have hk_15 : k.val ≤ 15 := by omega + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hum] + have h1 : k.val * 16 ≤ 15 * 16 := Nat.mul_le_mul_right 16 hk_15 + have : (15 * 16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i1, hi1_eq, hi1_val⟩ := usize_mul_ok_eq_fc k 16#usize hi1_max + have hi1_val_eq : i1.val = 16 * k.val := by + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hi1_val, hum]; omega + have hi2_max : k.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + have hk_15 : k.val ≤ 15 := by omega + have hum : (1#usize : Std.Usize).val = 1 := rfl + rw [hum] + have : (16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i2, hi2_eq, hi2_val⟩ := usize_add_ok_eq_fc k 1#usize hi2_max + have hi2_val_eq : i2.val = k.val + 1 := by + have hum : (1#usize : Std.Usize).val = 1 := rfl + rw [hi2_val, hum] + have hi3_max : i2.val * (16#usize : Std.Usize).val ≤ Std.Usize.max := by + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hum, hi2_val_eq] + have : k.val + 1 ≤ 16 := by omega + have h1 : (k.val + 1) * 16 ≤ 16 * 16 := Nat.mul_le_mul_right 16 this + have : (16 * 16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i3, hi3_eq, hi3_val⟩ := usize_mul_ok_eq_fc i2 16#usize hi3_max + have hi3_val_eq : i3.val = 16 * (k.val + 1) := by + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hi3_val, hi2_val_eq, hum]; omega + -- (3) Sub-slice via Array index_mut RangeUsize. + have h0_le : i1.val ≤ i3.val := by rw [hi1_val_eq, hi3_val_eq]; omega + have hi3_le : i3.val ≤ acc.val.length := by + rw [h_acc_len, hi3_val_eq] + have : k.val + 1 ≤ 16 := by omega + have h1 : 16 * (k.val + 1) ≤ 16 * 16 := Nat.mul_le_mul_left _ this + omega + obtain ⟨s, back, h_imt_eq, h_s_val_eq, h_s_len_eq, h_back_eq⟩ := + array_index_mut_range_ok_eq_fc acc + ({ start := i1, «end» := i3 } : CoreModels.core.ops.range.Range Std.Usize) + h0_le hi3_le + have h_s_len16 : s.length = 16 := by + show s.val.length = 16 + rw [h_s_len_eq] + show i3.val - i1.val = 16 + rw [hi3_val_eq, hi1_val_eq]; omega + have h_s_lane : ∀ ℓ : Nat, ℓ < 16 → + s.val[ℓ]! = acc.val[16 * k.val + ℓ]! := by + intro ℓ hℓ + rw [h_s_val_eq] + have h_idx_lt : i1.val + ℓ < i3.val := by + rw [hi1_val_eq, hi3_val_eq]; omega + have h_bnd : i3.val ≤ acc.val.length ∧ i1.val + ℓ < i3.val := ⟨hi3_le, h_idx_lt⟩ + rw [List.getElem!_slice i1.val i3.val ℓ acc.val h_bnd] + rw [hi1_val_eq] + have h_s_lane_init : ∀ ℓ : Nat, ℓ < 16 → + s.val[ℓ]! = acc_init.val[16 * k.val + ℓ]! := by + intro ℓ hℓ + rw [h_s_lane ℓ hℓ] + exact h_acc_undone k.val (Nat.le_refl _) hk_16 ℓ hℓ + -- (4) Per-lane bound on s (≤ 2^30 from h_acc_bnd). + have h_s_bnd : ∀ k' : Fin 16, (s.val[k'.val]!).val.natAbs ≤ 2^30 := by + intro k' + rw [h_s_lane_init k'.val k'.isLt] + have h_lt : 16 * k.val + k'.val < 256 := by + have : k.val ≤ 15 := by omega + have hk' : k'.val < 16 := k'.isLt + have : 16 * k.val ≤ 16 * 15 := Nat.mul_le_mul_left 16 this + omega + exact h_acc_bnd ⟨16 * k.val + k'.val, h_lt⟩ + -- (3') Cache-chunk extract via `Array.index_usize cache.coefficients k` + -- (single-element READ — `_use_cache` does not mutate the cache). + set t2 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + cache.coefficients.val[k.val]! with ht2_def + have h_idx_t2 : Aeneas.Std.Array.index_usize cache.coefficients k = .ok t2 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq cache.coefficients k + (by rw [h_cache_coef_len]; exact hk_16) + -- (5) The four zetas: derived from the cache PRE at chunk k, without + -- calling `polynomial.zeta_fc` (the impl does NOT read zetas here). + set z0 : Std.I16 := + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * k.val]! with hz0_def + set z1 : Std.I16 := + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * k.val + 1]! with hz1_def + set z2 : Std.I16 := + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * k.val + 2]! with hz2_def + set z3 : Std.I16 := + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[64 + 4 * k.val + 3]! with hz3_def + have hz0_bnd : z0.val.natAbs ≤ 1664 := ZETAS_bound (64 + 4 * k.val) (by omega) + have hz1_bnd : z1.val.natAbs ≤ 1664 := ZETAS_bound (64 + 4 * k.val + 1) (by omega) + have hz2_bnd : z2.val.natAbs ≤ 1664 := ZETAS_bound (64 + 4 * k.val + 2) (by omega) + have hz3_bnd : z3.val.natAbs ≤ 1664 := ZETAS_bound (64 + 4 * k.val + 3) (by omega) + have hz0_lift : lift_fe_mont z0 = Spec.zeta_at (64 + 4 * k.val) := rfl + have hz1_lift : lift_fe_mont z1 = Spec.zeta_at (64 + 4 * k.val + 1) := rfl + have hz2_lift : lift_fe_mont z2 = Spec.zeta_at (64 + 4 * k.val + 2) := rfl + have hz3_lift : lift_fe_mont z3 = Spec.zeta_at (64 + 4 * k.val + 3) := rfl + -- (6) Derive the per-chunk vector-level cache POST from the poly-level PRE. + have h_cache_chunk : Spec.ntt_multiply_cache_post t1 z0 z1 z2 z3 t2 := by + unfold accumulating_ntt_multiply_poly_cache_post at h_cache + unfold Spec.ntt_multiply_cache_post + intro i_fin + have h := h_cache ⟨k.val, hk_16⟩ i_fin + -- h : ((cache.coefficients.val[k.val]!).elements.val[i_fin.val]!).val.natAbs ≤ 3328 + -- ∧ lift_fe_mont (...) = mul_pure (lift_fe_mont rhs[..]) (effective_zeta_fe i_fin ...) + -- where the effective zetas use `Spec.zeta_at (64+4*k.val+m)` — which by + -- `hz_m_lift` (`rfl`) equals `lift_fe_mont z_m`. + exact h + -- (7) Apply L2.8d use_cache to get s1 satisfying ntt_multiply_base_case_post. + have h_t_lhs : ∀ j : Fin 16, (t.elements.val[j.val]!).val.natAbs ≤ 3328 := by + intro j; exact h_self ⟨k.val, hk_16⟩ j + have h_t1_rhs : ∀ j : Fin 16, (t1.elements.val[j.val]!).val.natAbs ≤ 3328 := by + intro j; exact h_rhs ⟨k.val, hk_16⟩ j + obtain ⟨s1, h_s1_eq, h_s1_len, h_s1_bnd, h_s1_post⟩ := + triple_exists_ok_fc + (accumulating_ntt_multiply_use_cache_fc t t1 s t2 z0 z1 z2 z3 h_s_len16 + h_t_lhs h_t1_rhs hz0_bnd hz1_bnd hz2_bnd hz3_bnd h_s_bnd h_cache_chunk) + -- s1's bound vs s lanes (s.val[k'] = acc_init[16k+k']). + have h_s1_bnd_abs : ∀ k' : Nat, k' < 16 → + (s1.val[k']!).val.natAbs ≤ (acc_init.val[16 * k.val + k']!).val.natAbs + 2^25 := by + intro k' hk' + have h_step_bnd := h_s1_bnd ⟨k', hk'⟩ + simp only at h_step_bnd + rw [h_s_lane_init k' hk'] at h_step_bnd + exact h_step_bnd + -- (8) Compose acc1 := back s1. + set acc1 : UseCacheFC.Acc := back s1 with hacc1_def + have h_acc1_val : acc1.val = acc.val.setSlice! i1.val s1.val := + h_back_eq s1 (by show s1.val.length = i3.val - i1.val; rw [← h_s_len_eq]; + show s1.length = s.length; rw [h_s1_len, h_s_len16]) + have h_acc1_len : acc1.val.length = 256 := by + rw [h_acc1_val, List.length_setSlice!, h_acc_len] + have h_acc1_in : ∀ ℓ : Nat, ℓ < 16 → + acc1.val[16 * k.val + ℓ]! = s1.val[ℓ]! := by + intro ℓ hℓ + rw [h_acc1_val] + have h_get : (acc.val.setSlice! i1.val s1.val)[16 * k.val + ℓ]! + = s1.val[(16 * k.val + ℓ) - i1.val]! := by + apply List.getElem!_setSlice!_middle + refine ⟨?_, ?_, ?_⟩ + · rw [hi1_val_eq]; omega + · rw [hi1_val_eq] + have h_sub' : 16 * k.val + ℓ - 16 * k.val = ℓ := by omega + rw [h_sub'] + show ℓ < s1.val.length + have h_s1' : s1.val.length = 16 := h_s1_len + rw [h_s1']; exact hℓ + · rw [h_acc_len] + have hk_15' : k.val ≤ 15 := by omega + have h1 : 16 * k.val ≤ 16 * 15 := Nat.mul_le_mul_left 16 hk_15' + omega + rw [h_get] + have h_sub : (16 * k.val + ℓ) - i1.val = ℓ := by + rw [hi1_val_eq]; omega + rw [h_sub] + have h_acc1_out : ∀ n : Nat, n < 256 → + (n < 16 * k.val ∨ 16 * (k.val + 1) ≤ n) → + acc1.val[n]! = acc.val[n]! := by + intro n hn hcases + rw [h_acc1_val] + rcases hcases with hlt | hge + · apply List.getElem!_setSlice!_prefix + rw [hi1_val_eq]; exact hlt + · apply List.getElem!_setSlice!_suffix + rw [hi1_val_eq] + have h_s1' : s1.val.length = 16 := h_s1_len + rw [h_s1'] + have h_eq16 : 16 * k.val + 16 = 16 * (k.val + 1) := by ring + rw [h_eq16]; exact hge + -- (9) Body equation: the impl reduces to .ok (cont (..., acc1)). + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_use_cache_loop.body + (vectortraitsOperationsInst := portable_ops_inst) myself rhs cache + { start := k, «end» := 16#usize } acc + = .ok (ControlFlow.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc1)) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_use_cache_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let t' ← Aeneas.Std.Array.index_usize myself.coefficients k + let t1' ← Aeneas.Std.Array.index_usize rhs.coefficients k + let i1' ← (k * 16#usize : Result Std.Usize) + let i2' ← k + 1#usize + let i3' ← i2' * 16#usize + let (s', index_mut_back) ← + core.Array.Insts.CoreOpsIndexIndexMut.index_mut + (core.Slice.Insts.CoreOpsIndexIndexMut + (core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice + Std.I32)) acc { start := i1', «end» := i3' } + let t2' ← Aeneas.Std.Array.index_usize cache.coefficients k + let s1' ← + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations.accumulating_ntt_multiply_use_cache + t' t1' s' t2' + .ok (ControlFlow.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), index_mut_back s1'))) + : Result _) = _ + rw [h_idx_t]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_t1]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi3_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_t2]; simp only [Aeneas.Std.bind_tc_ok] + show ((do + let s1' ← + libcrux_iot_ml_kem.vector.portable.ntt.accumulating_ntt_multiply_use_cache + t t1 s t2 + .ok (ControlFlow.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), back s1'))) + : Result _) = _ + rw [h_s1_eq] + rfl + apply triple_of_ok_fc h_body + show UseCacheFC.step_post myself rhs acc_init k + (.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc1)) + unfold UseCacheFC.step_post + refine ⟨h_lt, rfl, hs_val_eq, ?_⟩ + show (UseCacheFC.inv myself rhs acc_init s_iter acc1).holds + -- Build the three invariant conjuncts (same shape as L6.3 base). + have h_inv_pure : + (∀ j : Nat, j < s_iter.val → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc1.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (myself.coefficients.val[j]!)) + (lift_chunk_mont (rhs.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + ∧ (∀ j : Nat, s_iter.val ≤ j → j < 16 → ∀ ℓ : Nat, ℓ < 16 → + acc1.val[16 * j + ℓ]! = acc_init.val[16 * j + ℓ]!) + ∧ (∀ n : Nat, n < 256 → + (acc1.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + 2^25) := by + refine ⟨?_, ?_, ?_⟩ + · -- (a) Touched-chunk FC. + intro j hj ℓ hℓ + rw [hs_val_eq] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · have h_in_range : 16 * j + ℓ < 16 * k.val := by + have h1 : 16 * j + 16 ≤ 16 * k.val := by + have : j + 1 ≤ k.val := by omega + have : 16 * (j + 1) ≤ 16 * k.val := Nat.mul_le_mul_left 16 this + omega + omega + have h_lt_256 : 16 * j + ℓ < 256 := by + have : k.val ≤ 15 := by omega + have : 16 * k.val ≤ 16 * 15 := Nat.mul_le_mul_left 16 this + omega + have h_acc1_eq_acc : acc1.val[16 * j + ℓ]! = acc.val[16 * j + ℓ]! := + h_acc1_out (16 * j + ℓ) h_lt_256 (Or.inl h_in_range) + rw [h_acc1_eq_acc] + exact h_acc_done j hj_lt_k ℓ hℓ + · subst hj_eq_k + rw [h_acc1_in ℓ hℓ] + unfold ntt_multiply_base_case_post at h_s1_post + have h_lhs_val_eq : + (Spec.chunk_reducing_from_i32_array_pure s1).val[ℓ]! + = Spec.mont_reduce_pure (lift_fe_int (s1.val[ℓ]!).val) := + Spec.chunk_reducing_from_i32_array_pure_lane_eq s1 ℓ hℓ + have h_s_chunk_val : + (Spec.chunk_reducing_from_i32_array_pure s).val[ℓ]! + = Spec.mont_reduce_pure (lift_fe_int (s.val[ℓ]!).val) := + Spec.chunk_reducing_from_i32_array_pure_lane_eq s ℓ hℓ + have h_rhs_val_eq : + (ntt_multiply_base_case_alg + (lift_chunk_mont t) (lift_chunk_mont t1) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3) + (Spec.chunk_reducing_from_i32_array_pure s)).val[ℓ]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Spec.chunk_reducing_from_i32_array_pure s).val[ℓ]!) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont t) (lift_chunk_mont t1) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3)).val[ℓ]!) := by + unfold ntt_multiply_base_case_alg + exact Spec.chunk_add_pure_lane_eq _ _ ℓ hℓ + have h_chunk_eq : + Spec.mont_reduce_pure (lift_fe_int (s1.val[ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (s.val[ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont t) (lift_chunk_mont t1) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3)).val[ℓ]!) := by + have h_eq : (Spec.chunk_reducing_from_i32_array_pure s1).val[ℓ]! + = (ntt_multiply_base_case_alg + (lift_chunk_mont t) (lift_chunk_mont t1) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3) + (Spec.chunk_reducing_from_i32_array_pure s)).val[ℓ]! := by + rw [h_s1_post] + rw [h_lhs_val_eq] at h_eq + rw [h_rhs_val_eq] at h_eq + rw [h_s_chunk_val] at h_eq + exact h_eq + rw [h_chunk_eq] + rw [h_s_lane_init ℓ hℓ] + rw [hz0_lift, hz1_lift, hz2_lift, hz3_lift] + · -- (b) Untouched acc chunks. + intro j hj_ge hj_lt ℓ hℓ + rw [hs_val_eq] at hj_ge + have h_n_lt_256 : 16 * j + ℓ < 256 := by + have : j ≤ 15 := by omega + have : 16 * j ≤ 16 * 15 := Nat.mul_le_mul_left 16 this + omega + have h_ge_range : 16 * (k.val + 1) ≤ 16 * j + ℓ := by + have : k.val + 1 ≤ j := hj_ge + have : 16 * (k.val + 1) ≤ 16 * j := Nat.mul_le_mul_left 16 this + omega + have h_acc1_eq_acc : acc1.val[16 * j + ℓ]! = acc.val[16 * j + ℓ]! := + h_acc1_out (16 * j + ℓ) h_n_lt_256 (Or.inr h_ge_range) + rw [h_acc1_eq_acc] + exact h_acc_undone j (by omega) hj_lt ℓ hℓ + · -- (c) Universal acc bound. + intro n hn + by_cases hcase : 16 * k.val ≤ n ∧ n < 16 * (k.val + 1) + · obtain ⟨hge, hlt⟩ := hcase + have hn_decomp : n = 16 * k.val + (n - 16 * k.val) := by omega + have hn_off_lt : n - 16 * k.val < 16 := by omega + have h_acc1_n : acc1.val[n]! = s1.val[n - 16 * k.val]! := by + conv_lhs => rw [hn_decomp] + exact h_acc1_in (n - 16 * k.val) hn_off_lt + rw [h_acc1_n] + have h_bnd_at_off := h_s1_bnd_abs (n - 16 * k.val) hn_off_lt + have h_acc_init_n : acc_init.val[16 * k.val + (n - 16 * k.val)]! + = acc_init.val[n]! := by + congr 1; omega + rw [h_acc_init_n] at h_bnd_at_off + exact h_bnd_at_off + · have h_outside : n < 16 * k.val ∨ 16 * (k.val + 1) ≤ n := by + by_contra hc + push Not at hc + exact hcase ⟨hc.1, hc.2⟩ + have h_acc1_eq_acc : acc1.val[n]! = acc.val[n]! := + h_acc1_out n hn h_outside + rw [h_acc1_eq_acc] + exact h_acc_bnd_rel n hn + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_use_cache_loop.body + (vectortraitsOperationsInst := portable_ops_inst) myself rhs cache + { start := k, «end» := 16#usize } acc + = .ok (ControlFlow.done acc) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_use_cache_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_fc h_body + show UseCacheFC.step_post myself rhs acc_init k (.done acc) + unfold UseCacheFC.step_post + show (UseCacheFC.inv myself rhs acc_init 16#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j : Nat, j < (16#usize : Std.Usize).val → ∀ ℓ : Nat, ℓ < 16 → + Spec.mont_reduce_pure (lift_fe_int (acc.val[16 * j + ℓ]!).val) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (Spec.mont_reduce_pure (lift_fe_int (acc_init.val[16 * j + ℓ]!).val)) + ((Spec.ntt_multiply_pure_no_acc + (lift_chunk_mont (myself.coefficients.val[j]!)) + (lift_chunk_mont (rhs.coefficients.val[j]!)) + (Spec.zeta_at (64 + 4 * j)) + (Spec.zeta_at (64 + 4 * j + 1)) + (Spec.zeta_at (64 + 4 * j + 2)) + (Spec.zeta_at (64 + 4 * j + 3))).val[ℓ]!)) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → ∀ ℓ : Nat, ℓ < 16 → + acc.val[16 * j + ℓ]! = acc_init.val[16 * j + ℓ]!) + ∧ (∀ n : Nat, n < 256 → + (acc.val[n]!).val.natAbs ≤ (acc_init.val[n]!).val.natAbs + 2^25) := by + refine ⟨?_, ?_, ?_⟩ + · intro j hj ℓ hℓ; rw [h16] at hj + apply h_acc_done j _ ℓ hℓ; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt ℓ hℓ + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt ℓ hℓ; rw [hk_eq]; exact hj_ge + · intro n hn; exact h_acc_bnd_rel n hn + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +/-- L6.3c — `polynomial.PolynomialRingElement.accumulating_ntt_multiply_use_cache`: + polynomial wrapper of `accumulating_ntt_multiply_use_cache_fc`. The cache + is read-only here; the PRE asserts the cache satisfies + `accumulating_ntt_multiply_poly_cache_post` (so each chunk's vector-level + cache PRE is dischargeable from the `Spec.ntt_multiply_cache_post` + extraction at that chunk's effective zetas). + + POST identical to L6.3 base (length + relative bound + + `accumulating_ntt_multiply_poly_post`); no cache POST conjunct since + `_use_cache` does not write to the cache. -/ +@[spec] +theorem accumulating_ntt_multiply_use_cache_poly_fc + (myself rhs cache : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (accumulator : Std.Array Std.I32 256#usize) + (h_self : ∀ i : Fin 16, ∀ j : Fin 16, + ((myself.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_rhs : ∀ i : Fin 16, ∀ j : Fin 16, + ((rhs.coefficients.val[i.val]!).elements.val[j.val]!).val.natAbs ≤ 3328) + (h_acc_bnd : ∀ n : Fin 256, (accumulator.val[n.val]!).val.natAbs ≤ 2^30) + (h_cache : accumulating_ntt_multiply_poly_cache_post rhs cache) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_use_cache + (vectortraitsOperationsInst := portable_ops_inst) + myself rhs accumulator cache + ⦃ ⇓ r => ⌜ (∀ n : Fin 256, (r.val[n.val]!).val.natAbs + ≤ (accumulator.val[n.val]!).val.natAbs + 2^25) ∧ + accumulating_ntt_multiply_poly_post + myself rhs accumulator r ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_use_cache + have h_vre : libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + = .ok (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + rfl + rw [h_vre]; simp only [Aeneas.Std.bind_tc_ok] + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_use_cache_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.accumulating_ntt_multiply_use_cache_loop.body + (vectortraitsOperationsInst := portable_ops_inst) myself rhs cache iter1 acc1) + (β := UseCacheFC.Acc) + accumulator + 0#usize 16#usize + (UseCacheFC.inv myself rhs accumulator) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_⟩ + · intro j hj; exact absurd hj (Nat.not_lt_zero j) + · intro _ _ _ _ _; trivial + · intro n _; omega) + ?_) + · -- Post entailment at k = 16: derive the locked POST. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (UseCacheFC.inv myself rhs accumulator 16#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + obtain ⟨h_done, _h_undone, h_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_holds + refine ⟨?_, ?_⟩ + · intro n; exact h_bnd n.val n.isLt + · unfold accumulating_ntt_multiply_poly_post + intro j hj ℓ hℓ + have h16' : (16#usize : Std.Usize).val = 16 := rfl + exact h_done j (by rw [h16']; exact hj) ℓ hℓ + · -- Step entailment. + intro acc k _h_ge h_le hinv + have h_step := accumulating_ntt_multiply_use_cache_poly_step_lemma_fc + myself rhs cache accumulator h_self h_rhs h_acc_bnd h_cache acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : UseCacheFC.step_post myself rhs accumulator k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [UseCacheFC.step_post] using hP + · have hP : UseCacheFC.step_post myself rhs accumulator k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [UseCacheFC.step_post] using hP + +end L6_3c_use_irreducible + + +end libcrux_iot_ml_kem.Polynomial.NttMultiply \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/PolyOps.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/PolyOps.lean new file mode 100644 index 00000000..c2c0d9da --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/PolyOps.lean @@ -0,0 +1,367 @@ +/- + # `Equivalence/L6_PolyOps.lean` — Layer 6 polynomial-composite Triples. + + L6.1: `PolynomialRingElement_poly_barrett_reduce_spec` — wraps L1.3 + `barrett_reduce_spec` across the 16 PortableVectors of a + `PolynomialRingElement`. The impl is a 16-iter loop over + `self.coefficients` that dispatches the trait method + `Vector::barrett_reduce` on each lane. + + Specialised to `Vector := PortableVector` with the concrete + `Libcrux_iot_ml_kemVectorTraitsOperations` instance. The instance's + `barrett_reduce` field (`@[reducible]`) reduces to + `vector.portable.arithmetic.barrett_reduce`, which is L1.3's target. +-/ +import LibcruxIotMlKem.Extraction.Funs +import LibcruxIotMlKem.Util.LoopSpecs +import LibcruxIotMlKem.Vector.Portable.Arithmetic.LoopHelper +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element + +set_option mvcgen.warning false +set_option linter.unusedVariables false +set_option linter.unusedSectionVars false + +namespace libcrux_iot_ml_kem.Polynomial.PolyOps +open libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement +open CoreModels Aeneas Aeneas.Std Result ControlFlow Std.Do +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper + +/-! ## Inhabited instances — needed for `.val[j]!` projections on + `Array PortableVector 16`. Mirror the locally-registered instances + in `L3_NTTDrivers.lean`. Declared `local` to avoid colliding with + L3's identically-named auto-generated instance constants when both + files are imported by the project root. -/ + +local instance instInhabitedPortableVector_l6 : + Inhabited libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + ⟨{ elements := Std.Array.make 16#usize (List.replicate 16 (0#i16 : Std.I16)) + (by simp) }⟩ + +local instance instInhabitedPolynomialRingElement_l6 {Vector : Type} [Inhabited Vector] : + Inhabited (libcrux_iot_ml_kem.polynomial.PolynomialRingElement Vector) := + ⟨{ coefficients := Std.Array.make 16#usize (List.replicate 16 default) (by simp) }⟩ + +/-! ## Local helpers — Triple ↔ Result.ok bridges, pure-prop holds. + +Mirror the `triple_of_ok_l3` / `triple_exists_ok_l3` / `pure_prop_holds_l3` +family used by `L3_NTTDrivers.lean`. Each phase file carries its own copy +with a phase-local suffix to avoid cross-file shadowing. -/ + +private theorem triple_of_ok_l6 {α : Type} {x : Result α} {v : α} + {P : α → Prop} (hx : x = .ok v) (hp : P v) : + ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄ := by + subst hx; simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +private theorem triple_exists_ok_l6 {α : Type} {x : Result α} {P : α → Prop} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) : + ∃ v, x = .ok v ∧ P v := by + match hx : x with + | .ok v => exact ⟨v, rfl, (by subst hx; simpa [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h)⟩ + | .fail _ => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +private theorem pure_prop_holds_l6 {P : Prop} (h : P) : (pure P : Result Prop).holds := by + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, WP.wp]; intro _; exact h + +private theorem of_pure_prop_holds_l6 {P : Prop} + (h : (pure P : Result Prop).holds) : P := by + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, WP.wp] at h; exact h trivial + +/-! ## L6.1 — `PolynomialRingElement_poly_barrett_reduce_spec` + +Driver loop: 16 iterations over `self.coefficients`. Each iteration reads +`self.coefficients[k]` (a `PortableVector`), dispatches +`OpsInst.barrett_reduce` (reduces via `@[reducible]` to +`vector.portable.arithmetic.barrett_reduce`, to which L1.3 applies), and +writes the reduced vector back. + +Per-vector bound `≤ 32767` (L1.3 input) → `≤ 3328` (L1.3 output). + +Loop invariant after `k` iterations (`k.val ∈ [0, 16]`), state `acc`: + - For `j < k.val`, all 16 elements of `acc.coefficients[j]` are + bounded by `3328`. + - For `j ≥ k.val`, `acc.coefficients[j] = re.coefficients[j]` (so the + L1.3 precondition `≤ 32767` is inherited from `h_pre`). -/ + +namespace BarrettReduce + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Result ControlFlow + +/-- Step-local accumulator type. -/ +abbrev Acc := + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- Loop invariant. -/ +def inv + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + (∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 3328) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.coefficients.val[j]! = re.coefficients.val[j]!)) + +/-- Step post (named to keep the `match` constant canonical across sites). -/ +def step_post + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv re iter'.start acc').holds + | .done y => (inv re 16#usize y).holds + +end BarrettReduce + +/-- Per-iteration step lemma: each body call transforms + `acc.coefficients[k]` from a `≤ 32767` PortableVector (via h_pre + + h_acc_undone) to a `≤ 3328` one, leaves other indices untouched. -/ +private theorem poly_barrett_reduce_step_lemma + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 32767) + (acc : BarrettReduce.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_acc_done : ∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 3328) + (h_acc_undone : ∀ j : Nat, k.val ≤ j → j < 16 → + acc.coefficients.val[j]! = re.coefficients.val[j]!) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc + ⦃ ⇓ r => ⌜ BarrettReduce.step_post re k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.coefficients.length = 16 := + Std.Array.length_eq _ + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- Some round = k branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := iter_next_some_eq k h_lt + -- 1) `Array.index_mut_usize acc.coefficients k`. + have h_idx : + Aeneas.Std.Array.index_usize acc.coefficients k + = .ok (acc.coefficients.val[k.val]!) := + array_index_usize_ok_eq acc.coefficients k (by rw [h_coef_len]; exact hk_16) + have h_imt_ok : + Aeneas.Std.Array.index_mut_usize acc.coefficients k + = .ok (acc.coefficients.val[k.val]!, acc.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx] + rfl + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.coefficients.val[k.val]! with ht_def + -- 2) `OpsInst.barrett_reduce t`. The instance's `@[reducible]` field + -- forwards to `vector.portable.arithmetic.barrett_reduce`, to which + -- L1.3 applies. Pre: t's elements ≤ 32767 (it's + -- `re.coefficients[k]` via h_acc_undone + h_pre). + have h_t_eq : t = re.coefficients.val[k.val]! := by + show acc.coefficients.val[k.val]! = re.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_pre k.val hk_16 ℓ hℓ + obtain ⟨t1, h_t1_eq, h_t1_P⟩ := + triple_exists_ok_l6 (barrett_reduce_spec t h_t_bd) + -- h_t1_P : ∀ i : Nat, i < 16 → modq_eq … ∧ (t1.elements[i]).val.natAbs ≤ 3328 + have h_t1_bd : ∀ ℓ : Nat, ℓ < 16 → + (t1.elements.val[ℓ]!).val.natAbs ≤ 3328 := fun ℓ hℓ => (h_t1_P ℓ hℓ).2 + -- Set the next-state accumulator. + set a : Std.Array + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.coefficients.set k t1 with ha_def + set acc' : BarrettReduce.Acc := ({ coefficients := a } : BarrettReduce.Acc) with hacc'_def + -- Compose the whole body into one `.ok` equation. + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc + = .ok (cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp [bind_tc_ok, h_imt_ok] + -- After simp, only the barrett_reduce call remains. The trait + -- instance's field is `@[reducible]` and forwards to + -- `vector.portable.arithmetic.barrett_reduce`; force-reduce via + -- `show`, then close via `h_t1_eq`. + show (do + let t1' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce t + ok (cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + ({ coefficients := acc.coefficients.set k t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_t1_eq] + rfl + apply triple_of_ok_l6 h_body + show BarrettReduce.step_post re k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc')) + unfold BarrettReduce.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + -- Now: invariant at (s, acc'). + apply pure_prop_holds_l6 + -- Two conjuncts of BarrettReduce.inv at (s, acc'). + refine ⟨?_, ?_⟩ + · -- All j < s.val are bounded by 3328. + intro j hj ℓ hℓ + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: unchanged by the set, use h_acc_done. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne : + (acc.coefficients.set k t1)[j]! = (acc.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.coefficients k j t1 h_ne + have h_set_ne_val : + (acc.coefficients.set k t1).val[j]! = acc.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show ((acc.coefficients.set k t1).val[j]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_set_ne_val] + exact h_acc_done j hj_lt_k ℓ hℓ + · -- j = k.val: it's t1. + subst hj_eq_k + have h_lt' : k.val < acc.coefficients.length := by + rw [h_coef_len]; exact hk_16 + have h_set_eq : + (acc.coefficients.set k t1)[k.val]! = t1 := + Aeneas.Std.Array.getElem!_Nat_set_eq acc.coefficients k k.val t1 + ⟨rfl, h_lt'⟩ + have h_set_eq_val : + (acc.coefficients.set k t1).val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_eq + show ((acc.coefficients.set k t1).val[k.val]!).elements.val[ℓ]!.val.natAbs ≤ _ + rw [h_set_eq_val] + exact h_t1_bd ℓ hℓ + · -- All j ≥ s.val are unchanged. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne : + (acc.coefficients.set k t1)[j]! = (acc.coefficients)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.coefficients k j t1 h_ne + have h_set_ne_val : + (acc.coefficients.set k t1).val[j]! = acc.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show (acc.coefficients.set k t1).val[j]! = re.coefficients.val[j]! + rw [h_set_ne_val] + exact h_acc_undone j h_ge' hj_lt + · -- None branch (k ≥ 16). + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + { start := k, «end» := 16#usize } acc + = .ok (done acc) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_l6 h_body + show BarrettReduce.step_post re k (.done acc) + unfold BarrettReduce.step_post + show (BarrettReduce.inv re 16#usize acc).holds + apply pure_prop_holds_l6 + refine ⟨?_, ?_⟩ + · intro j hj ℓ hℓ; rw [h16] at hj + apply h_acc_done j _ ℓ hℓ; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt; rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + +set_option maxHeartbeats 16000000 in +@[spec] +theorem PolynomialRingElement_poly_barrett_reduce_spec + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_pre : ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((re.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 32767) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + re + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → ∀ j : Nat, j < 16 → + ((r.coefficients.val[i]!).elements.val[j]!).val.natAbs ≤ 3328 ⌝ ⦄ := by + -- Reduce the top wrapper to the inner loop. + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce + -- `VECTORS_IN_RING_ELEMENT` reduces to `.ok 16#usize`. + have h_vire : libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + = .ok (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + rfl + rw [h_vire] + simp only [bind_tc_ok] + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop + apply Std.Do.Triple.of_entails_right _ + (loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop.body + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + iter1 acc1) + (β := BarrettReduce.Acc) + re + 0#usize 16#usize + (BarrettReduce.inv re) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (pure_prop_holds_l6 ⟨ + fun j hj _ _ => absurd hj (Nat.not_lt_zero j), + fun _ _ _ => rfl⟩) + ?_) + · -- Post entailment. + rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_done, _h_undone⟩ := of_pure_prop_holds_l6 h + intro i hi j hj + apply h_done i (by rw [show (16#usize : Std.Usize).val = 16 from rfl]; exact hi) j hj + · -- Step lemma application. + intro acc k h_ge h_le hinv + obtain ⟨h_acc_done, h_acc_undone⟩ := of_pure_prop_holds_l6 hinv + have h_step := poly_barrett_reduce_step_lemma re h_pre acc k h_le + h_acc_done h_acc_undone + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : BarrettReduce.step_post re k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [BarrettReduce.step_post] using hP + · have hP : BarrettReduce.step_post re k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [BarrettReduce.step_post] using hP + +end libcrux_iot_ml_kem.Polynomial.PolyOps \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/PolyOpsFc.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/PolyOpsFc.lean new file mode 100644 index 00000000..f786ce94 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/PolyOpsFc.lean @@ -0,0 +1,3077 @@ +/- + # `Polynomial/PolyOpsFc.lean` — FC theorems for §L6.{2,4,5,6,7}. + + Houses reducing_from_i32_array_fc, subtract_reduce_fc, and the + add_*_error_reduce_fc family. Lives separately from PolyOps.lean + (and from PolyOpsFcBarrett.lean which holds §L6.1) to keep the + layering acyclic. +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Spec.Pure +import LibcruxIotMlKem.Spec.ModularArith +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Extraction.Funs +import HacspecMlKem.Extraction.Funs + +set_option mvcgen.warning false +set_option linter.unusedVariables false +set_option linter.unusedSectionVars false + + +/-! ### Extracted from FCTargets.lean (§poly_l6_rest). -/ + +namespace libcrux_iot_ml_kem.Polynomial.PolyOpsFc +open libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec + +namespace ReducingFromI32ArrayFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Step-local accumulator (the mutable `b` poly). -/ +abbrev Acc := + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC loop invariant for `subtract_reduce_fc`. + * (a) Chunks `j < k`: FC equation `lift_chunk acc[j] = chunk_subtract_reduce_pure + (lift_chunk self[j]) (lift_chunk b_init[j])`. + * (b) Chunks `k ≤ j < 16`: `acc[j] = b_init[j]` (unchanged). -/ +def inv + (self b_init : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + (∀ j : Nat, j < k.val → + lift_chunk (acc.coefficients.val[j]!) + = Spec.chunk_subtract_reduce_pure + (lift_chunk (self.coefficients.val[j]!)) + (lift_chunk (b_init.coefficients.val[j]!))) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.coefficients.val[j]! = b_init.coefficients.val[j]!)) + +/-- Step-post for `loop_range_spec_usize`. -/ +def step_post + (self b_init : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv self b_init iter'.start acc').holds + | .done y => (inv self b_init 16#usize y).holds + +end ReducingFromI32ArrayFC + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for `subtract_reduce`. Given a valid loop + state `(acc, k)` with `k.val < 16`, applies the fused + `mont_mul(1441) + sub + negate + barrett` chain to chunk `k.val` of + `acc`, recording the FC equation `lift_chunk acc'[k.val] = + chunk_subtract_reduce_pure (lift_chunk self[k.val]) (lift_chunk + b_init[k.val])` while preserving chunks `j ≠ k.val`. -/ +theorem subtract_reduce_step_lemma_fc + (self b_init : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_self_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((self.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439) + (h_b_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((b_init.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767) + (acc : ReducingFromI32ArrayFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (ReducingFromI32ArrayFC.inv self b_init k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.subtract_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) self + { start := k, «end» := 16#usize } acc + ⦃ ⇓ r => ⌜ ReducingFromI32ArrayFC.step_post self b_init k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.coefficients.length = 16 := + Std.Array.length_eq _ + have h_self_coef_len : self.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_acc_done, h_acc_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.subtract_reduce_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some i = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + -- (1) `index_mut_usize b.coefficients k` → `(t, set_back) = (acc.coefs[k], acc.coefs.set k)`. + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.coefficients.val[k.val]! with ht_def + have h_idx_t : Aeneas.Std.Array.index_usize acc.coefficients k = .ok t := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.coefficients k + (by rw [h_coef_len]; exact hk_16) + have h_imt_t : Aeneas.Std.Array.index_mut_usize acc.coefficients k + = .ok (t, acc.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t]; rfl + -- (1a) `t = b_init.coefficients[k]` (via h_acc_undone at j=k). + have h_t_eq : t = b_init.coefficients.val[k.val]! := by + show acc.coefficients.val[k.val]! = b_init.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_b_bnd k.val hk_16 ℓ hℓ + -- (2) `mont_mul(t, 1441#i16)` → `t1`. Pre: |1441| ≤ 1664 ✓; |t| ≤ 32767 ✓. + have h_c1441_bnd : ((1441#i16 : Std.I16).val.natAbs) ≤ 1664 := by decide + obtain ⟨t1, h_t1_eq, h_t1_lift_mont⟩ := + triple_exists_ok_fc + (montgomery_multiply_by_constant_fc t (1441#i16) h_t_bnd h_c1441_bnd) + -- Also pull the legacy per-element fact for the bound and value. + have h_t1_spec := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.montgomery_multiply_by_constant_spec + t (1441#i16) h_c1441_bnd + obtain ⟨t1', h_t1_eq', h_t1_per⟩ := triple_exists_ok_fc h_t1_spec + have h_t1_same : t1 = t1' := by + have := h_t1_eq.symm.trans h_t1_eq' + cases this; rfl + subst h_t1_same + have h_t1_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t1.elements.val[ℓ]!).val.natAbs ≤ 3328 := by + intro ℓ hℓ; exact (h_t1_per ℓ hℓ).1 + have h_t1_modq : ∀ ℓ : Nat, ℓ < 16 → + ((t1.elements.val[ℓ]!).val * (2 ^ 16 : Int)) % 3329 + = ((t.elements.val[ℓ]!).val * (1441#i16 : Std.I16).val) % 3329 := by + intro ℓ hℓ; exact (h_t1_per ℓ hℓ).2 + -- (3) `a = acc.coefficients.set k t1` (after applying index_mut_back). + set a : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.coefficients.set k t1 with ha_def + -- (4) `index_mut_usize a k` → `(t2, set_back2) = (a[k], a.set k)`. + have h_a_len : a.length = 16 := by simp [ha_def, h_coef_len] + have h_a_k : a.val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + have h_idx_t2 : Aeneas.Std.Array.index_usize a k = .ok (a.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a k + (by rw [h_a_len]; exact hk_16) + have h_imt_t2 : Aeneas.Std.Array.index_mut_usize a k = .ok (t1, a.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t2]; rw [h_a_k]; rfl + -- (5) `index_usize self.coefficients k` → `t3 = self.coefs[k]`. + set t3 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + self.coefficients.val[k.val]! with ht3_def + have h_idx_t3 : Aeneas.Std.Array.index_usize self.coefficients k = .ok t3 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq self.coefficients k + (by rw [h_self_coef_len]; exact hk_16) + have h_t3_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t3.elements.val[ℓ]!).val.natAbs ≤ 29439 := by + intro ℓ hℓ; exact h_self_bnd k.val hk_16 ℓ hℓ + -- (6) `sub t1 t3` → `t4`. Pre: |t1[ℓ] - t3[ℓ]| ≤ 32767. + -- |t1| ≤ 3328, |t3| ≤ 29439, so |t1 - t3| ≤ 3328 + 29439 = 32767 ✓. + have h_sub_bnd : ∀ ℓ : Nat, ℓ < 16 → + ((t1.elements.val[ℓ]!).val - (t3.elements.val[ℓ]!).val : Int).natAbs ≤ 2^15 - 1 := by + intro ℓ hℓ + have hb_t1 := h_t1_bnd ℓ hℓ + have hb_t3 := h_t3_bnd ℓ hℓ + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2] + have h_abs_sub : ((t1.elements.val[ℓ]!).val + - (t3.elements.val[ℓ]!).val : Int).natAbs + ≤ ((t1.elements.val[ℓ]!).val : Int).natAbs + + ((t3.elements.val[ℓ]!).val : Int).natAbs := + Int.natAbs_sub_le _ _ + omega + obtain ⟨t4, h_t4_eq, h_t4_lift⟩ := + triple_exists_ok_fc (sub_fc t1 t3 h_sub_bnd) + -- Pull legacy per-element value: t4[ℓ].val = t1[ℓ].val - t3[ℓ].val. + have h_t4_spec := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.sub_spec t1 t3 h_sub_bnd + obtain ⟨t4', h_t4_eq', h_t4_per⟩ := triple_exists_ok_fc h_t4_spec + have h_t4_same : t4 = t4' := by + have := h_t4_eq.symm.trans h_t4_eq' + cases this; rfl + subst h_t4_same + have h_t4_val : ∀ ℓ : Nat, ℓ < 16 → + (t4.elements.val[ℓ]!).val + = (t1.elements.val[ℓ]!).val - (t3.elements.val[ℓ]!).val := by + intro ℓ hℓ; exact (h_t4_per ℓ hℓ).1 + have h_t4_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t4.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ; exact (h_t4_per ℓ hℓ).2 + -- (7) `a1 = a.set k t4`. + set a1 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + a.set k t4 with ha1_def + have h_a1_len : a1.length = 16 := by simp [ha1_def, h_a_len] + have h_a1_k : a1.val[k.val]! = t4 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq a k k.val t4 + ⟨rfl, by rw [h_a_len]; exact hk_16⟩ + -- (8) `index_mut_usize a1 k` → `(t5, set_back3) = (t4, a1.set k)`. + have h_idx_t5 : Aeneas.Std.Array.index_usize a1 k = .ok (a1.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a1 k + (by rw [h_a1_len]; exact hk_16) + have h_imt_t5 : Aeneas.Std.Array.index_mut_usize a1 k = .ok (t4, a1.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t5]; rw [h_a1_k]; rfl + -- (9) `negate t4` → `t6`. Pre: |t4[ℓ]| ≤ 32767 ≤ 2^15-1 ✓. + have h_neg_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t4.elements.val[ℓ]!).val.natAbs ≤ 2^15 - 1 := by + intro ℓ hℓ + have h_b := h_t4_bnd ℓ hℓ + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2]; exact h_b + obtain ⟨t6, h_t6_eq, h_t6_lift⟩ := + triple_exists_ok_fc (negate_fc t4 h_neg_bnd) + -- Pull legacy per-element BV fact for t6. + have h_t6_spec := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.negate_spec t4 + obtain ⟨t6', h_t6_eq', h_t6_per⟩ := triple_exists_ok_fc h_t6_spec + have h_t6_same : t6 = t6' := by + have := h_t6_eq.symm.trans h_t6_eq' + cases this; rfl + subst h_t6_same + -- We need a bound on t6 for barrett's pre. Derive from negate_spec's BV + -- equality combined with h_t4_bnd: t6.val = -t4.val (under |t4| ≤ 2^15-1). + have h_t6_val : ∀ ℓ : Nat, ℓ < 16 → + (t6.elements.val[ℓ]!).val = -(t4.elements.val[ℓ]!).val := by + intro ℓ hℓ + set xi : Std.I16 := t4.elements.val[ℓ]! with hxi + set ri : Std.I16 := t6.elements.val[ℓ]! with hri + have h_bv : ri.bv = -xi.bv := h_t6_per ℓ hℓ + have h_wsub_bv : + (Aeneas.Std.I16.wrapping_sub (0#i16) xi).bv = -xi.bv := by + rw [Aeneas.Std.I16.wrapping_sub_bv_eq] + simp only [show (0#i16 : Std.I16).bv = (0 : BitVec 16) from rfl] + exact BitVec.zero_sub xi.bv + have h_step1 : ri.val = (Aeneas.Std.I16.wrapping_sub (0#i16) xi).val := by + have h_toInt : (ri.bv).toInt + = (Aeneas.Std.I16.wrapping_sub (0#i16) xi).bv.toInt := by + rw [h_bv, h_wsub_bv] + have h_lhs : (ri.bv).toInt = ri.val := Aeneas.Std.I16.bv_toInt_eq ri + have h_rhs : (Aeneas.Std.I16.wrapping_sub (0#i16) xi).bv.toInt + = (Aeneas.Std.I16.wrapping_sub (0#i16) xi).val := + Aeneas.Std.I16.bv_toInt_eq _ + rw [h_lhs, h_rhs] at h_toInt + exact h_toInt + rw [h_step1, Aeneas.Std.I16.wrapping_sub_val_eq] + have h0 : (0#i16 : Std.I16).val = 0 := by decide + rw [h0] + have h_diff : (0 : Int) - xi.val = -xi.val := by ring + rw [h_diff] + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_abs : xi.val.natAbs ≤ 2^15 - 1 := h_neg_bnd ℓ hℓ + have h_pow : -((2 : Int) ^ (16 - 1)) = -(2^15 : Int) := by decide + rw [h_pow] + omega + · have h_abs : xi.val.natAbs ≤ 2^15 - 1 := h_neg_bnd ℓ hℓ + have h_pow : ((2 : Int) ^ (16 - 1)) = (2^15 : Int) := by decide + rw [h_pow] + omega + have h_t6_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t6.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ + have hv := h_t6_val ℓ hℓ + have hb := h_t4_bnd ℓ hℓ + have h_abs : ((-(t4.elements.val[ℓ]!).val : Int)).natAbs + = ((t4.elements.val[ℓ]!).val : Int).natAbs := Int.natAbs_neg _ + rw [hv, h_abs]; exact hb + -- (10) `a2 = a1.set k t6`. + set a2 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + a1.set k t6 with ha2_def + have h_a2_len : a2.length = 16 := by simp [ha2_def, h_a1_len] + have h_a2_k : a2.val[k.val]! = t6 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq a1 k k.val t6 + ⟨rfl, by rw [h_a1_len]; exact hk_16⟩ + -- (11) `index_mut_usize a2 k` → `(t7, set_back4) = (t6, a2.set k)`. + have h_idx_t7 : Aeneas.Std.Array.index_usize a2 k = .ok (a2.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a2 k + (by rw [h_a2_len]; exact hk_16) + have h_imt_t7 : Aeneas.Std.Array.index_mut_usize a2 k = .ok (t6, a2.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t7]; rw [h_a2_k]; rfl + -- (12) `barrett_reduce t6` → `t8`. Pre: |t6[ℓ]| ≤ 32767 ✓. + obtain ⟨t8, h_t8_eq, h_t8_post⟩ := + triple_exists_ok_fc (barrett_reduce_fc t6 h_t6_bnd) + obtain ⟨h_t8_bnd, h_t8_lift⟩ := h_t8_post + -- (13) Compose acc' = `{ coefficients := a2.set k t8 }`. + set a3 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + a2.set k t8 with ha3_def + set acc' : ReducingFromI32ArrayFC.Acc := { coefficients := a3 } with hacc'_def + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.subtract_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) self + { start := k, «end» := 16#usize } acc + = .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.subtract_reduce_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [Aeneas.Std.bind_tc_ok] + -- The body is now sequentially: index_mut acc k → mont_mul → index_mut → index_usize self k + -- → sub → index_mut → negate → index_mut → barrett. Discharge each. + show (do + let (t', index_mut_back) ← + Aeneas.Std.Array.index_mut_usize acc.coefficients k + let t1' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_by_constant + t' (1441#i16) + let (t2', index_mut_back1) ← + Aeneas.Std.Array.index_mut_usize (index_mut_back t1') k + let t3' ← Aeneas.Std.Array.index_usize self.coefficients k + let t4' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.sub + t2' t3' + let (t5', index_mut_back2) ← + Aeneas.Std.Array.index_mut_usize (index_mut_back1 t4') k + let t6' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.negate t5' + let (t7', index_mut_back3) ← + Aeneas.Std.Array.index_mut_usize (index_mut_back2 t6') k + let t8' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce t7' + .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + ({ coefficients := index_mut_back3 t8' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_imt_t]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_t2]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_t3]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t4_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_t5]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t6_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_t7]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t8_eq] + rfl + apply triple_of_ok_fc h_body + show ReducingFromI32ArrayFC.step_post self b_init k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold ReducingFromI32ArrayFC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (ReducingFromI32ArrayFC.inv self b_init s acc').holds + -- Invariant at (s, acc'). acc'.coefficients = a3 = ((acc.coefs.set k t1).set k t4).set k t6).set k t8. + -- Equivalently: only chunk k changes, to t8. + have h_inv_pure : + (∀ j : Nat, j < s.val → + lift_chunk (acc'.coefficients.val[j]!) + = Spec.chunk_subtract_reduce_pure + (lift_chunk (self.coefficients.val[j]!)) + (lift_chunk (b_init.coefficients.val[j]!))) + ∧ (∀ j : Nat, s.val ≤ j → j < 16 → + acc'.coefficients.val[j]! = b_init.coefficients.val[j]!) := by + refine ⟨?_, ?_⟩ + · -- (a) j < s.val → FC equation at chunk j. + intro j hj + rw [hs_val] at hj + show lift_chunk (((((acc.coefficients.set k t1).set k t4).set k t6).set k t8).val[j]!) = _ + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: chunk j unchanged through all four sets. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set1 : (((((acc.coefficients.set k t1).set k t4).set k t6).set k t8).val[j]!) + = ((((acc.coefficients.set k t1).set k t4).set k t6).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + (((acc.coefficients.set k t1).set k t4).set k t6) k j t8 h_ne + have h_set2 : ((((acc.coefficients.set k t1).set k t4).set k t6).val[j]!) + = (((acc.coefficients.set k t1).set k t4).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + ((acc.coefficients.set k t1).set k t4) k j t6 h_ne + have h_set3 : (((acc.coefficients.set k t1).set k t4).val[j]!) + = ((acc.coefficients.set k t1).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + (acc.coefficients.set k t1) k j t4 h_ne + have h_set4 : ((acc.coefficients.set k t1).val[j]!) + = acc.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.coefficients k j t1 h_ne + rw [h_set1, h_set2, h_set3, h_set4] + exact h_acc_done j hj_lt_k + · -- j = k.val: chunk j = t8, need lift_chunk t8 = chunk_subtract_reduce_pure .... + subst hj_eq_k + have h_set_eq : ((((acc.coefficients.set k t1).set k t4).set k t6).set k t8).val[k.val]! + = t8 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq + (((acc.coefficients.set k t1).set k t4).set k t6) k k.val t8 + ⟨rfl, by simp; exact hk_16⟩ + rw [h_set_eq] + -- Goal: lift_chunk t8 = chunk_subtract_reduce_pure (lift_chunk self[k]) (lift_chunk b_init[k]). + -- We'll prove this by chaining: lift_chunk t8 = lift_chunk t6 (barrett identity) = + -- chunk_neg (lift_chunk t4) = chunk_neg (chunk_sub (lift_chunk t1) (lift_chunk t3)). + -- Then expand: lift_fe t1[ℓ] = mul_pure(lift_fe t[ℓ], lift_fe_mont 1441), + -- and t[ℓ] = b_init[k][ℓ], t3[ℓ] = self[k][ℓ]. + -- The final identity reduces to: neg_pure(sub_pure(mul_pure b z) self) = sub_pure self (mul_pure b z). + rw [h_t8_lift] + -- t8_lift: lift_chunk t8 = chunk_barrett_reduce_pure (lift_chunk t6). + show Spec.chunk_barrett_reduce_pure (lift_chunk t6) + = Spec.chunk_subtract_reduce_pure + (lift_chunk (self.coefficients.val[k.val]!)) + (lift_chunk (b_init.coefficients.val[k.val]!)) + -- Need a pointwise lane equation showing the four-stage chain on lane ℓ + -- equals: barrett(neg(t1[ℓ] - self[ℓ])) ≡ self[ℓ] - mul_pure(b[ℓ], lift_fe_mont 1441) in ZMod q. + unfold Spec.chunk_barrett_reduce_pure Spec.chunk_subtract_reduce_pure + apply Subtype.ext + -- Goal now: .val of LHS = .val of RHS. The .val of Std.Array.make is the list. + change (List.range 16).map (fun i => + Spec.barrett_pure ((lift_chunk t6).val[i]!)) + = (List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + ((lift_chunk (self.coefficients.val[k.val]!)).val[ℓ]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk (b_init.coefficients.val[k.val]!)).val[ℓ]!) + (lift_fe_mont (1441#i16)))) + apply List.ext_getElem + · simp + · intro ℓ hℓ1 _hℓ2 + have hℓ : ℓ < 16 := by + have : ℓ < (List.range 16).length := by simpa using hℓ1 + simpa using this + rw [List.getElem_map, List.getElem_range, + List.getElem_map, List.getElem_range] + -- LHS: Spec.barrett_pure ((lift_chunk t6).val[ℓ]!). + -- RHS: sub_pure (self[k][ℓ]) (mul_pure b[k][ℓ] (lift_fe_mont 1441)). + -- Step A: ((lift_chunk t6).val[ℓ]!) = lift_fe (t6.elements.val[ℓ]!). + have h_t6_elems_len : t6.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length t6 + have h_lc_t6 : ((lift_chunk t6).val[ℓ]!) = lift_fe (t6.elements.val[ℓ]!) := by + unfold lift_chunk + show ((t6.elements.val.map lift_fe)[ℓ]!) = _ + have h_len : (t6.elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_t6_elems_len + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos t6.elements.val ℓ (by rw [h_t6_elems_len]; exact hℓ)] + rw [h_lc_t6] + -- Step B: barrett_pure (lift_fe t6[ℓ]) = lift_fe t6[ℓ] (canonical identity). + rw [barrett_pure_lift_fe] + -- Step C: lift_fe t6[ℓ] = -(lift_fe t4[ℓ]) in FE ⟺ ZMod equation. + -- Use h_t6_val: t6.val[ℓ] = -t4.val[ℓ] (Int). Use h_t4_val: t4.val[ℓ] = t1.val[ℓ] - t3.val[ℓ]. + -- So t6.val[ℓ] = -(t1.val[ℓ] - t3.val[ℓ]) = t3.val[ℓ] - t1.val[ℓ]. + -- Then lift_fe t6 = (t3.val - t1.val : ZMod q). + -- RHS: sub_pure self[k][ℓ] (mul_pure b_init[k][ℓ] (lift_fe_mont 1441)). + -- = self[k][ℓ] - b_init[k][ℓ] * (1441 * 169) in ZMod q. + -- Since self[k][ℓ] = t3.val and t1 satisfies t1.val ≡ t.val * 1441 * 169 ≡ b_init[k].val * 1441 * 169 in ZMod q, + -- lift_fe t6 = t3.val - t1.val ≡ self[k][ℓ] - b[k][ℓ] * 1441 * 169 ✓. + -- Let me build this step by step. + -- (i) lift_fe t6[ℓ] = sub_pure (lift_fe t3[ℓ]) (lift_fe t1[ℓ]) + have hv6 := h_t6_val ℓ hℓ + have hv4 := h_t4_val ℓ hℓ + have h_t6_val_eq : (t6.elements.val[ℓ]!).val + = (t3.elements.val[ℓ]!).val - (t1.elements.val[ℓ]!).val := by + rw [hv6, hv4]; ring + have h_sub_bnd_local : + ((t3.elements.val[ℓ]!).val - (t1.elements.val[ℓ]!).val : Int).natAbs ≤ 2^15 - 1 := by + have h_bt1 := h_t1_bnd ℓ hℓ + have h_bt3 := h_t3_bnd ℓ hℓ + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2] + have h_abs_sub : + ((t3.elements.val[ℓ]!).val - (t1.elements.val[ℓ]!).val : Int).natAbs + ≤ ((t3.elements.val[ℓ]!).val : Int).natAbs + + ((t1.elements.val[ℓ]!).val : Int).natAbs := + Int.natAbs_sub_le _ _ + omega + have h_lift_t6 : + lift_fe (t6.elements.val[ℓ]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (lift_fe (t3.elements.val[ℓ]!)) (lift_fe (t1.elements.val[ℓ]!)) := + lift_fe_sub_pure_eq _ _ _ h_t6_val_eq + rw [h_lift_t6] + -- (ii) lift_fe t1[ℓ] = mul_pure (lift_fe t[ℓ]) (lift_fe_mont 1441). + have h_lift_t1 : + lift_fe (t1.elements.val[ℓ]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe (t.elements.val[ℓ]!)) (lift_fe_mont (1441#i16)) := + lift_fe_mont_mul_pure_eq _ _ _ (h_t1_modq ℓ hℓ) + rw [h_lift_t1] + -- (iii) Now goal: + -- sub_pure (lift_fe t3[ℓ]) (mul_pure (lift_fe t[ℓ]) (lift_fe_mont 1441)) + -- = sub_pure ((lift_chunk self[k]).val[ℓ]!) (mul_pure ((lift_chunk b_init[k]).val[ℓ]!) (lift_fe_mont 1441)). + -- t3 = self[k]; t = b_init[k]; lift_chunk x .val[ℓ]! = lift_fe (x.elements.val[ℓ]!). + have h_self_elems_len : (self.coefficients.val[k.val]!).elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + have h_b_elems_len : (b_init.coefficients.val[k.val]!).elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + have h_lc_self : ((lift_chunk (self.coefficients.val[k.val]!)).val[ℓ]!) + = lift_fe ((self.coefficients.val[k.val]!).elements.val[ℓ]!) := by + unfold lift_chunk + show (((self.coefficients.val[k.val]!).elements.val.map lift_fe)[ℓ]!) = _ + have h_len : + ((self.coefficients.val[k.val]!).elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_self_elems_len + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos (self.coefficients.val[k.val]!).elements.val ℓ + (by rw [h_self_elems_len]; exact hℓ)] + have h_lc_b : ((lift_chunk (b_init.coefficients.val[k.val]!)).val[ℓ]!) + = lift_fe ((b_init.coefficients.val[k.val]!).elements.val[ℓ]!) := by + unfold lift_chunk + show (((b_init.coefficients.val[k.val]!).elements.val.map lift_fe)[ℓ]!) = _ + have h_len : + ((b_init.coefficients.val[k.val]!).elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_b_elems_len + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos (b_init.coefficients.val[k.val]!).elements.val ℓ + (by rw [h_b_elems_len]; exact hℓ)] + rw [h_lc_self, h_lc_b] + -- t3 = self[k], t = b_init[k] (since t = acc[k] = b_init[k] by h_t_eq). + -- Rewrite t3 and t to their respective self[k] and b_init[k] definitions + -- to match the RHS. + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (lift_fe (t3.elements.val[ℓ]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe (t.elements.val[ℓ]!)) (lift_fe_mont (1441#i16))) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (lift_fe ((self.coefficients.val[k.val]!).elements.val[ℓ]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe ((b_init.coefficients.val[k.val]!).elements.val[ℓ]!)) + (lift_fe_mont (1441#i16))) + rw [show t3 = self.coefficients.val[k.val]! from ht3_def, + show t = b_init.coefficients.val[k.val]! from h_t_eq] + · -- (b) s.val ≤ j < 16 → acc'.coefs[j] = b_init.coefs[j]. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + show ((((acc.coefficients.set k t1).set k t4).set k t6).set k t8).val[j]! + = b_init.coefficients.val[j]! + have h_set1 : (((((acc.coefficients.set k t1).set k t4).set k t6).set k t8).val[j]!) + = ((((acc.coefficients.set k t1).set k t4).set k t6).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + (((acc.coefficients.set k t1).set k t4).set k t6) k j t8 h_ne + have h_set2 : ((((acc.coefficients.set k t1).set k t4).set k t6).val[j]!) + = (((acc.coefficients.set k t1).set k t4).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + ((acc.coefficients.set k t1).set k t4) k j t6 h_ne + have h_set3 : (((acc.coefficients.set k t1).set k t4).val[j]!) + = ((acc.coefficients.set k t1).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + (acc.coefficients.set k t1) k j t4 h_ne + have h_set4 : ((acc.coefficients.set k t1).val[j]!) + = acc.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.coefficients k j t1 h_ne + rw [h_set1, h_set2, h_set3, h_set4] + exact h_acc_undone j h_ge' hj_lt + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.subtract_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) self + { start := k, «end» := 16#usize } acc + = .ok (ControlFlow.done acc) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.subtract_reduce_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_fc h_body + show ReducingFromI32ArrayFC.step_post self b_init k (.done acc) + unfold ReducingFromI32ArrayFC.step_post + show (ReducingFromI32ArrayFC.inv self b_init 16#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (acc.coefficients.val[j]!) + = Spec.chunk_subtract_reduce_pure + (lift_chunk (self.coefficients.val[j]!)) + (lift_chunk (b_init.coefficients.val[j]!))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + acc.coefficients.val[j]! = b_init.coefficients.val[j]!) := by + refine ⟨?_, ?_⟩ + · intro j hj; rw [h16] at hj + apply h_acc_done j; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +/-- L6.2 — `subtract_reduce`: per-chunk `negate(mont_mul(b, 1441) - self)` + then barrett. Equivalent in ZMod q to pointwise `self - 512 · b` + (C.4 commute: `1441 · R⁻¹ ≡ 512 mod q`), NOT to hacspec's `self - b`. + + Spec target: custom `Spec.subtract_reduce_pure` modeling the + fused-Mont impl behavior (see §0.5). Mirrors the L6.4/5/6 + `Spec.add_*_reduce_pure` pattern. + + **Preconditions** (load-bearing, beyond the locked True-pre form): + - `h_self_bnd`: per-lane `|self[k][ℓ]| ≤ 29439` (drives `sub`'s overflow bound). + - `h_b_bnd`: per-lane `|b[k][ℓ]| ≤ 32767` (consumed by `mont_mul`'s + legacy precondition; the impl's later `sub` then uses `|t1| ≤ 3328` from + mont's output bound). -/ +@[spec] +theorem subtract_reduce_fc + (self b : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_self_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((self.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439) + (h_b_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((b.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.subtract_reduce + (vectortraitsOperationsInst := portable_ops_inst) self b + ⦃ ⇓ p => ⌜ lift_poly p + = Spec.subtract_reduce_pure (lift_poly self) (lift_poly b) ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.subtract_reduce + -- Resolve `VECTORS_IN_RING_ELEMENT = .ok 16#usize`. + have h_vre : libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + = .ok (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + rfl + rw [h_vre]; simp only [Aeneas.Std.bind_tc_ok] + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.subtract_reduce_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, b1) => + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.subtract_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) self iter1 b1) + (β := ReducingFromI32ArrayFC.Acc) + b + 0#usize 16#usize + (ReducingFromI32ArrayFC.inv self b) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_⟩ + · -- No chunks done yet. + intro j hj; exact absurd hj (Nat.not_lt_zero j) + · -- All chunks unchanged (goal trivializes since acc = b). + intro _ _ _; trivial) + ?_) + · -- Post entailment: at k=16, the invariant gives all 16 FC equations. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (ReducingFromI32ArrayFC.inv self b 16#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (r.coefficients.val[j]!) + = Spec.chunk_subtract_reduce_pure + (lift_chunk (self.coefficients.val[j]!)) + (lift_chunk (b.coefficients.val[j]!))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + r.coefficients.val[j]! = b.coefficients.val[j]!) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + ReducingFromI32ArrayFC.inv] using h_inv_holds + obtain ⟨h_done, _h_undone⟩ := h_inv + -- Build chunks_arr matching the Spec definition, then apply + -- flatten_chunks_eq_lift_poly_fc. + unfold Spec.subtract_reduce_pure + set chunks_arr : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_subtract_reduce_pure + (Spec.chunk_at (lift_poly self) k) + (Spec.chunk_at (lift_poly b) k))) + (by simp) with hchunks_def + have h_chunks_len : chunks_arr.val.length = 16 := by + show ((List.range 16).map _).length = 16 + simp + have h_chunks_get : ∀ k : Nat, (hk : k < 16) → + chunks_arr.val[k]'(by rw [h_chunks_len]; exact hk) + = lift_chunk (r.coefficients.val[k]!) := by + intro k hk + show ((List.range 16).map (fun k => + Spec.chunk_subtract_reduce_pure + (Spec.chunk_at (lift_poly self) k) + (Spec.chunk_at (lift_poly b) k)))[k]'_ = _ + rw [List.getElem_map, List.getElem_range] + rw [chunk_at_lift_poly_fc self k hk, chunk_at_lift_poly_fc b k hk] + exact (h_done k hk).symm + have h_final := flatten_chunks_eq_lift_poly_fc r chunks_arr h_chunks_len h_chunks_get + exact h_final.symm + · -- Step lemma application. + intro acc k _h_ge h_le hinv + have h_step := + subtract_reduce_step_lemma_fc self b h_self_bnd h_b_bnd acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : ReducingFromI32ArrayFC.step_post self b k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [ReducingFromI32ArrayFC.step_post] using hP + · have hP : ReducingFromI32ArrayFC.step_post self b k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [ReducingFromI32ArrayFC.step_post] using hP + +/-! ### L6.3 — `add_to_ring_element` (DOCUMENTED, NO STANDALONE FC). + + The impl-side `PolynomialRingElement.add_to_ring_element` is NOT + a standalone exported op at the impl extraction layer: the impl + fuses "add then reduce" into the `add_*_reduce` family + (`add_error_reduce`, `add_standard_error_reduce`, + `add_message_error_reduce`). + + The hacspec target `polynomial.add_to_ring_element` is exercised + indirectly through the matrix-level FCs (L7.1 / L7.3 use it + inside `multiply_matrix_by_column` and `add_polynomials`); we do + NOT land a separate `add_to_ring_element_fc` Triple here, but we + DO land per-component `add_*_reduce_fc` Triples below (L6.4/5/6) + that cover the impl-side calls. -/ + +/-! ### L6.4.A — Loop scaffolding for `add_error_reduce_fc`. -/ + +namespace AddErrorReduceFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Step-local accumulator (the mutable `self` poly). -/ +abbrev Acc := + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC loop invariant for `add_error_reduce_fc`. + * (a) Chunks `j < k`: FC equation `lift_chunk acc[j] = + chunk_add_error_reduce_pure (lift_chunk self_init[j]) (lift_chunk error[j])`. + * (b) Chunks `k ≤ j < 16`: `acc[j] = self_init[j]` (unchanged). -/ +def inv + (self_init error : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + (∀ j : Nat, j < k.val → + lift_chunk (acc.coefficients.val[j]!) + = Spec.chunk_add_error_reduce_pure + (lift_chunk (self_init.coefficients.val[j]!)) + (lift_chunk (error.coefficients.val[j]!))) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.coefficients.val[j]! = self_init.coefficients.val[j]!)) + +/-- Step-post for `loop_range_spec_usize`. -/ +def step_post + (self_init error : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv self_init error iter'.start acc').holds + | .done y => (inv self_init error 16#usize y).holds + +end AddErrorReduceFC + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for `add_error_reduce`. Given a valid loop + state `(acc, k)` with `k.val < 16`, applies the + `mont_mul(1441) + add(error[k]) + barrett` chain to chunk `k.val` of + `acc`, recording the FC equation `lift_chunk acc'[k.val] = + chunk_add_error_reduce_pure (lift_chunk self_init[k.val]) (lift_chunk + error[k.val])` while preserving chunks `j ≠ k.val`. -/ +theorem add_error_reduce_step_lemma_fc + (self_init error : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_self_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((self_init.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767) + (h_error_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((error.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439) + (acc : AddErrorReduceFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (AddErrorReduceFC.inv self_init error k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) error + { start := k, «end» := 16#usize } acc + ⦃ ⇓ r => ⌜ AddErrorReduceFC.step_post self_init error k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.coefficients.length = 16 := + Std.Array.length_eq _ + have h_error_coef_len : error.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_acc_done, h_acc_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some i = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + -- (1) `index_mut_usize acc.coefficients k` → `(t, set_back) = (acc.coefs[k], acc.coefs.set k)`. + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.coefficients.val[k.val]! with ht_def + have h_idx_t : Aeneas.Std.Array.index_usize acc.coefficients k = .ok t := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.coefficients k + (by rw [h_coef_len]; exact hk_16) + have h_imt_t : Aeneas.Std.Array.index_mut_usize acc.coefficients k + = .ok (t, acc.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t]; rfl + -- (1a) `t = self_init.coefficients[k]` (via h_acc_undone at j=k). + have h_t_eq : t = self_init.coefficients.val[k.val]! := by + show acc.coefficients.val[k.val]! = self_init.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_self_bnd k.val hk_16 ℓ hℓ + -- (2) `mont_mul(t, 1441#i16)` → `t1`. Pre: |1441| ≤ 1664 ✓; |t| ≤ 32767 ✓. + have h_c1441_bnd : ((1441#i16 : Std.I16).val.natAbs) ≤ 1664 := by decide + obtain ⟨t1, h_t1_eq, h_t1_lift_mont⟩ := + triple_exists_ok_fc + (montgomery_multiply_by_constant_fc t (1441#i16) h_t_bnd h_c1441_bnd) + -- Also pull the legacy per-element fact for the bound and value. + have h_t1_spec := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.montgomery_multiply_by_constant_spec + t (1441#i16) h_c1441_bnd + obtain ⟨t1', h_t1_eq', h_t1_per⟩ := triple_exists_ok_fc h_t1_spec + have h_t1_same : t1 = t1' := by + have := h_t1_eq.symm.trans h_t1_eq' + cases this; rfl + subst h_t1_same + have h_t1_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t1.elements.val[ℓ]!).val.natAbs ≤ 3328 := by + intro ℓ hℓ; exact (h_t1_per ℓ hℓ).1 + have h_t1_modq : ∀ ℓ : Nat, ℓ < 16 → + ((t1.elements.val[ℓ]!).val * (2 ^ 16 : Int)) % 3329 + = ((t.elements.val[ℓ]!).val * (1441#i16 : Std.I16).val) % 3329 := by + intro ℓ hℓ; exact (h_t1_per ℓ hℓ).2 + -- (3) `a = acc.coefficients.set k t1`. + set a : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.coefficients.set k t1 with ha_def + -- (4) `index_mut_usize a k` → `(t2, set_back2) = (a[k], a.set k) = (t1, a.set k)`. + have h_a_len : a.length = 16 := by simp [ha_def, h_coef_len] + have h_a_k : a.val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + have h_idx_t2 : Aeneas.Std.Array.index_usize a k = .ok (a.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a k + (by rw [h_a_len]; exact hk_16) + have h_imt_t2 : Aeneas.Std.Array.index_mut_usize a k = .ok (t1, a.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t2]; rw [h_a_k]; rfl + -- (5) `index_usize error.coefficients k` → `t3 = error.coefs[k]`. + set t3 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + error.coefficients.val[k.val]! with ht3_def + have h_idx_t3 : Aeneas.Std.Array.index_usize error.coefficients k = .ok t3 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq error.coefficients k + (by rw [h_error_coef_len]; exact hk_16) + have h_t3_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t3.elements.val[ℓ]!).val.natAbs ≤ 29439 := by + intro ℓ hℓ; exact h_error_bnd k.val hk_16 ℓ hℓ + -- (6) `add t1 t3` → `t4`. Pre: |t1[ℓ] + t3[ℓ]| ≤ 32767. + -- |t1| ≤ 3328, |t3| ≤ 29439, so |t1 + t3| ≤ 3328 + 29439 = 32767 ✓. + have h_add_bnd : ∀ ℓ : Nat, ℓ < 16 → + ((t1.elements.val[ℓ]!).val + (t3.elements.val[ℓ]!).val : Int).natAbs ≤ 2^15 - 1 := by + intro ℓ hℓ + have hb_t1 := h_t1_bnd ℓ hℓ + have hb_t3 := h_t3_bnd ℓ hℓ + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2] + have h_abs_add : ((t1.elements.val[ℓ]!).val + + (t3.elements.val[ℓ]!).val : Int).natAbs + ≤ ((t1.elements.val[ℓ]!).val : Int).natAbs + + ((t3.elements.val[ℓ]!).val : Int).natAbs := + Int.natAbs_add_le _ _ + omega + obtain ⟨t4, h_t4_eq, h_t4_lift⟩ := + triple_exists_ok_fc (add_fc t1 t3 h_add_bnd) + -- Pull legacy per-element value: t4[ℓ].val = t1[ℓ].val + t3[ℓ].val. + have h_t4_spec := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.add_spec t1 t3 h_add_bnd + obtain ⟨t4', h_t4_eq', h_t4_per⟩ := triple_exists_ok_fc h_t4_spec + have h_t4_same : t4 = t4' := by + have := h_t4_eq.symm.trans h_t4_eq' + cases this; rfl + subst h_t4_same + have h_t4_val : ∀ ℓ : Nat, ℓ < 16 → + (t4.elements.val[ℓ]!).val + = (t1.elements.val[ℓ]!).val + (t3.elements.val[ℓ]!).val := by + intro ℓ hℓ; exact (h_t4_per ℓ hℓ).1 + have h_t4_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t4.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ; exact (h_t4_per ℓ hℓ).2 + -- (7) `a1 = a.set k t4`. + set a1 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + a.set k t4 with ha1_def + have h_a1_len : a1.length = 16 := by simp [ha1_def, h_a_len] + have h_a1_k : a1.val[k.val]! = t4 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq a k k.val t4 + ⟨rfl, by rw [h_a_len]; exact hk_16⟩ + -- (8) `index_mut_usize a1 k` → `(t5, set_back3) = (t4, a1.set k)`. + have h_idx_t5 : Aeneas.Std.Array.index_usize a1 k = .ok (a1.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a1 k + (by rw [h_a1_len]; exact hk_16) + have h_imt_t5 : Aeneas.Std.Array.index_mut_usize a1 k = .ok (t4, a1.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t5]; rw [h_a1_k]; rfl + -- (9) `barrett_reduce t4` → `t6`. Pre: |t4[ℓ]| ≤ 32767 ✓. + obtain ⟨t6, h_t6_eq, h_t6_post⟩ := + triple_exists_ok_fc (barrett_reduce_fc t4 h_t4_bnd) + obtain ⟨_h_t6_bnd, h_t6_lift⟩ := h_t6_post + -- (10) Compose acc' = `{ coefficients := a1.set k t6 }`. + set a2 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + a1.set k t6 with ha2_def + set acc' : AddErrorReduceFC.Acc := { coefficients := a2 } with hacc'_def + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) error + { start := k, «end» := 16#usize } acc + = .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [Aeneas.Std.bind_tc_ok] + show (do + let (t', index_mut_back) ← + Aeneas.Std.Array.index_mut_usize acc.coefficients k + let t1' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_by_constant + t' (1441#i16) + let (t2', index_mut_back1) ← + Aeneas.Std.Array.index_mut_usize (index_mut_back t1') k + let t3' ← Aeneas.Std.Array.index_usize error.coefficients k + let t4' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.add t2' t3' + let (t5', index_mut_back2) ← + Aeneas.Std.Array.index_mut_usize (index_mut_back1 t4') k + let t6' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce t5' + .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + ({ coefficients := index_mut_back2 t6' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_imt_t]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_t2]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_t3]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t4_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_t5]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t6_eq] + rfl + apply triple_of_ok_fc h_body + show AddErrorReduceFC.step_post self_init error k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold AddErrorReduceFC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (AddErrorReduceFC.inv self_init error s acc').holds + have h_inv_pure : + (∀ j : Nat, j < s.val → + lift_chunk (acc'.coefficients.val[j]!) + = Spec.chunk_add_error_reduce_pure + (lift_chunk (self_init.coefficients.val[j]!)) + (lift_chunk (error.coefficients.val[j]!))) + ∧ (∀ j : Nat, s.val ≤ j → j < 16 → + acc'.coefficients.val[j]! = self_init.coefficients.val[j]!) := by + refine ⟨?_, ?_⟩ + · -- (a) j < s.val → FC equation at chunk j. + intro j hj + rw [hs_val] at hj + show lift_chunk ((((acc.coefficients.set k t1).set k t4).set k t6).val[j]!) = _ + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: chunk j unchanged through all three sets. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set1 : ((((acc.coefficients.set k t1).set k t4).set k t6).val[j]!) + = (((acc.coefficients.set k t1).set k t4).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + ((acc.coefficients.set k t1).set k t4) k j t6 h_ne + have h_set2 : (((acc.coefficients.set k t1).set k t4).val[j]!) + = ((acc.coefficients.set k t1).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + (acc.coefficients.set k t1) k j t4 h_ne + have h_set3 : ((acc.coefficients.set k t1).val[j]!) + = acc.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.coefficients k j t1 h_ne + rw [h_set1, h_set2, h_set3] + exact h_acc_done j hj_lt_k + · -- j = k.val: chunk j = t6, need lift_chunk t6 = chunk_add_error_reduce_pure .... + subst hj_eq_k + have h_set_eq : ((((acc.coefficients.set k t1).set k t4).set k t6).val[k.val]!) + = t6 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq + ((acc.coefficients.set k t1).set k t4) k k.val t6 + ⟨rfl, by simp; exact hk_16⟩ + rw [h_set_eq] + -- Goal: lift_chunk t6 = chunk_add_error_reduce_pure + -- (lift_chunk self_init[k]) (lift_chunk error[k]). + rw [h_t6_lift] + -- Now: chunk_barrett_reduce_pure (lift_chunk t4) = + -- chunk_add_error_reduce_pure (lift_chunk self_init[k]) (lift_chunk error[k]). + show Spec.chunk_barrett_reduce_pure (lift_chunk t4) + = Spec.chunk_add_error_reduce_pure + (lift_chunk (self_init.coefficients.val[k.val]!)) + (lift_chunk (error.coefficients.val[k.val]!)) + unfold Spec.chunk_barrett_reduce_pure Spec.chunk_add_error_reduce_pure + apply Subtype.ext + change (List.range 16).map (fun i => + Spec.barrett_pure ((lift_chunk t4).val[i]!)) + = (List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk (self_init.coefficients.val[k.val]!)).val[ℓ]!) + (lift_fe_mont (1441#i16))) + ((lift_chunk (error.coefficients.val[k.val]!)).val[ℓ]!)) + apply List.ext_getElem + · simp + · intro ℓ hℓ1 _hℓ2 + have hℓ : ℓ < 16 := by + have : ℓ < (List.range 16).length := by simpa using hℓ1 + simpa using this + rw [List.getElem_map, List.getElem_range, + List.getElem_map, List.getElem_range] + -- LHS: Spec.barrett_pure ((lift_chunk t4).val[ℓ]!). + -- Step A: ((lift_chunk t4).val[ℓ]!) = lift_fe (t4.elements.val[ℓ]!). + have h_t4_elems_len : t4.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length t4 + have h_lc_t4 : ((lift_chunk t4).val[ℓ]!) = lift_fe (t4.elements.val[ℓ]!) := by + unfold lift_chunk + show ((t4.elements.val.map lift_fe)[ℓ]!) = _ + have h_len : (t4.elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_t4_elems_len + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos t4.elements.val ℓ (by rw [h_t4_elems_len]; exact hℓ)] + rw [h_lc_t4] + -- Step B: barrett_pure (lift_fe t4[ℓ]) = lift_fe t4[ℓ]. + rw [barrett_pure_lift_fe] + -- Step C: lift_fe t4[ℓ] = add_pure (lift_fe t1[ℓ]) (lift_fe t3[ℓ]). + have hv4 := h_t4_val ℓ hℓ + have h_lift_t4 : + lift_fe (t4.elements.val[ℓ]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe (t1.elements.val[ℓ]!)) (lift_fe (t3.elements.val[ℓ]!)) := + lift_fe_add_pure_eq _ _ _ hv4 + rw [h_lift_t4] + -- Step D: lift_fe t1[ℓ] = mul_pure (lift_fe t[ℓ]) (lift_fe_mont 1441). + have h_lift_t1 : + lift_fe (t1.elements.val[ℓ]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe (t.elements.val[ℓ]!)) (lift_fe_mont (1441#i16)) := + lift_fe_mont_mul_pure_eq _ _ _ (h_t1_modq ℓ hℓ) + rw [h_lift_t1] + -- Step E: rewrite t and t3 to self_init[k] and error[k] images. + have h_self_elems_len : + (self_init.coefficients.val[k.val]!).elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + have h_err_elems_len : + (error.coefficients.val[k.val]!).elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + have h_lc_self : ((lift_chunk (self_init.coefficients.val[k.val]!)).val[ℓ]!) + = lift_fe ((self_init.coefficients.val[k.val]!).elements.val[ℓ]!) := by + unfold lift_chunk + show (((self_init.coefficients.val[k.val]!).elements.val.map lift_fe)[ℓ]!) = _ + have h_len : + ((self_init.coefficients.val[k.val]!).elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_self_elems_len + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos (self_init.coefficients.val[k.val]!).elements.val ℓ + (by rw [h_self_elems_len]; exact hℓ)] + have h_lc_err : ((lift_chunk (error.coefficients.val[k.val]!)).val[ℓ]!) + = lift_fe ((error.coefficients.val[k.val]!).elements.val[ℓ]!) := by + unfold lift_chunk + show (((error.coefficients.val[k.val]!).elements.val.map lift_fe)[ℓ]!) = _ + have h_len : + ((error.coefficients.val[k.val]!).elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_err_elems_len + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos (error.coefficients.val[k.val]!).elements.val ℓ + (by rw [h_err_elems_len]; exact hℓ)] + rw [h_lc_self, h_lc_err] + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe (t.elements.val[ℓ]!)) (lift_fe_mont (1441#i16))) + (lift_fe (t3.elements.val[ℓ]!)) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe ((self_init.coefficients.val[k.val]!).elements.val[ℓ]!)) + (lift_fe_mont (1441#i16))) + (lift_fe ((error.coefficients.val[k.val]!).elements.val[ℓ]!)) + rw [show t = self_init.coefficients.val[k.val]! from h_t_eq, + show t3 = error.coefficients.val[k.val]! from ht3_def] + · -- (b) s.val ≤ j < 16 → acc'.coefs[j] = self_init.coefs[j]. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + show ((((acc.coefficients.set k t1).set k t4).set k t6).val[j]!) + = self_init.coefficients.val[j]! + have h_set1 : ((((acc.coefficients.set k t1).set k t4).set k t6).val[j]!) + = (((acc.coefficients.set k t1).set k t4).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + ((acc.coefficients.set k t1).set k t4) k j t6 h_ne + have h_set2 : (((acc.coefficients.set k t1).set k t4).val[j]!) + = ((acc.coefficients.set k t1).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + (acc.coefficients.set k t1) k j t4 h_ne + have h_set3 : ((acc.coefficients.set k t1).val[j]!) + = acc.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.coefficients k j t1 h_ne + rw [h_set1, h_set2, h_set3] + exact h_acc_undone j h_ge' hj_lt + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) error + { start := k, «end» := 16#usize } acc + = .ok (ControlFlow.done acc) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_fc h_body + show AddErrorReduceFC.step_post self_init error k (.done acc) + unfold AddErrorReduceFC.step_post + show (AddErrorReduceFC.inv self_init error 16#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (acc.coefficients.val[j]!) + = Spec.chunk_add_error_reduce_pure + (lift_chunk (self_init.coefficients.val[j]!)) + (lift_chunk (error.coefficients.val[j]!))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + acc.coefficients.val[j]! = self_init.coefficients.val[j]!) := by + refine ⟨?_, ?_⟩ + · intro j hj; rw [h16] at hj + apply h_acc_done j; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +/-- L6.4 — `add_error_reduce`: `self · (R/128) + error` then barrett. + Returns `(re, scratch)` tuple; we project on `re`. + + **Preconditions** (load-bearing, beyond the locked True-pre form): + - `h_self_bnd`: per-lane `|self[k][ℓ]| ≤ 32767` (consumed by `mont_mul`'s + legacy precondition; the impl's later `add` then uses `|t1| ≤ 3328` from + mont's output bound). + - `h_error_bnd`: per-lane `|error[k][ℓ]| ≤ 29439` (drives `add`'s overflow + bound: |t1| + |error| ≤ 3328 + 29439 = 32767). -/ +@[spec] +theorem add_error_reduce_fc + (self error : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_self_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((self.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767) + (h_error_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((error.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce + (vectortraitsOperationsInst := portable_ops_inst) self error + ⦃ ⇓ p => ⌜ lift_poly p = Spec.add_error_reduce_pure (lift_poly self) (lift_poly error) ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce + -- Resolve `VECTORS_IN_RING_ELEMENT = .ok 16#usize`. + have h_vre : libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + = .ok (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + rfl + rw [h_vre]; simp only [Aeneas.Std.bind_tc_ok] + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, self1) => + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_error_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) error iter1 self1) + (β := AddErrorReduceFC.Acc) + self + 0#usize 16#usize + (AddErrorReduceFC.inv self error) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_⟩ + · intro j hj; exact absurd hj (Nat.not_lt_zero j) + · intro _ _ _; trivial) + ?_) + · -- Post entailment: at k=16, the invariant gives all 16 FC equations. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (AddErrorReduceFC.inv self error 16#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (r.coefficients.val[j]!) + = Spec.chunk_add_error_reduce_pure + (lift_chunk (self.coefficients.val[j]!)) + (lift_chunk (error.coefficients.val[j]!))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + r.coefficients.val[j]! = self.coefficients.val[j]!) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + AddErrorReduceFC.inv] using h_inv_holds + obtain ⟨h_done, _h_undone⟩ := h_inv + -- Build chunks_arr matching the Spec definition, then apply + -- flatten_chunks_eq_lift_poly_fc. + unfold Spec.add_error_reduce_pure + set chunks_arr : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_add_error_reduce_pure + (Spec.chunk_at (lift_poly self) k) + (Spec.chunk_at (lift_poly error) k))) + (by simp) with hchunks_def + have h_chunks_len : chunks_arr.val.length = 16 := by + show ((List.range 16).map _).length = 16 + simp + have h_chunks_get : ∀ k : Nat, (hk : k < 16) → + chunks_arr.val[k]'(by rw [h_chunks_len]; exact hk) + = lift_chunk (r.coefficients.val[k]!) := by + intro k hk + show ((List.range 16).map (fun k => + Spec.chunk_add_error_reduce_pure + (Spec.chunk_at (lift_poly self) k) + (Spec.chunk_at (lift_poly error) k)))[k]'_ = _ + rw [List.getElem_map, List.getElem_range] + rw [chunk_at_lift_poly_fc self k hk, chunk_at_lift_poly_fc error k hk] + exact (h_done k hk).symm + have h_final := flatten_chunks_eq_lift_poly_fc r chunks_arr h_chunks_len h_chunks_get + exact h_final.symm + · -- Step entailment: per-iteration step lemma. + intro acc k _h_ge h_le hinv + have h_step := + add_error_reduce_step_lemma_fc self error h_self_bnd h_error_bnd acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : AddErrorReduceFC.step_post self error k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [AddErrorReduceFC.step_post] using hP + · have hP : AddErrorReduceFC.step_post self error k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [AddErrorReduceFC.step_post] using hP + +/-! ### L6.5.A — Loop scaffolding for `add_standard_error_reduce_fc`. -/ + +namespace AddStandardErrorReduceFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Step-local accumulator (the mutable `self` poly). -/ +abbrev Acc := + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC loop invariant for `add_standard_error_reduce_fc`. + * (a) Chunks `j < k`: FC equation `lift_chunk acc[j] = + chunk_add_standard_error_reduce_pure (lift_chunk self_init[j]) + (lift_chunk error[j])`. + * (b) Chunks `k ≤ j < 16`: `acc[j] = self_init[j]` (unchanged). -/ +def inv + (self_init error : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + (∀ j : Nat, j < k.val → + lift_chunk (acc.coefficients.val[j]!) + = Spec.chunk_add_standard_error_reduce_pure + (lift_chunk (self_init.coefficients.val[j]!)) + (lift_chunk (error.coefficients.val[j]!))) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.coefficients.val[j]! = self_init.coefficients.val[j]!)) + +/-- Step-post for `loop_range_spec_usize`. -/ +def step_post + (self_init error : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv self_init error iter'.start acc').holds + | .done y => (inv self_init error 16#usize y).holds + +end AddStandardErrorReduceFC + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for `add_standard_error_reduce`. Given a + valid loop state `(acc, k)` with `k.val < 16`, applies the + `mont_mul(1353) + add(error[k]) + barrett` chain (via the + `to_standard_domain` wrapper) to chunk `k.val` of `acc`, recording the + FC equation `lift_chunk acc'[k.val] = + chunk_add_standard_error_reduce_pure (lift_chunk self_init[k.val]) + (lift_chunk error[k.val])` while preserving chunks `j ≠ k.val`. -/ +theorem add_standard_error_reduce_step_lemma_fc + (self_init error : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_self_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((self_init.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767) + (h_error_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((error.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439) + (acc : AddStandardErrorReduceFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (AddStandardErrorReduceFC.inv self_init error k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) error + { start := k, «end» := 16#usize } acc + ⦃ ⇓ r => ⌜ AddStandardErrorReduceFC.step_post self_init error k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.coefficients.length = 16 := + Std.Array.length_eq _ + have h_error_coef_len : error.coefficients.length = 16 := + Std.Array.length_eq _ + have h_RR_unfold : + libcrux_iot_ml_kem.vector.traits.MONTGOMERY_R_SQUARED_MOD_FIELD_MODULUS + = (1353#i16 : Std.I16) := by + unfold libcrux_iot_ml_kem.vector.traits.MONTGOMERY_R_SQUARED_MOD_FIELD_MODULUS + rfl + obtain ⟨h_acc_done, h_acc_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some i = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + -- (1) `index_mut_usize acc.coefficients k` → `(t, set_back) = (acc.coefs[k], acc.coefs.set k)`. + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.coefficients.val[k.val]! with ht_def + have h_idx_t : Aeneas.Std.Array.index_usize acc.coefficients k = .ok t := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.coefficients k + (by rw [h_coef_len]; exact hk_16) + have h_imt_t : Aeneas.Std.Array.index_mut_usize acc.coefficients k + = .ok (t, acc.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t]; rfl + -- (1a) `t = self_init.coefficients[k]` (via h_acc_undone at j=k). + have h_t_eq : t = self_init.coefficients.val[k.val]! := by + show acc.coefficients.val[k.val]! = self_init.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_self_bnd k.val hk_16 ℓ hℓ + -- (2) `mont_mul(t, 1353#i16)` → `t1`. Pre: |1353| ≤ 1664 ✓; |t| ≤ 32767 ✓. + have h_c1353_bnd : ((1353#i16 : Std.I16).val.natAbs) ≤ 1664 := by decide + obtain ⟨t1, h_t1_eq, h_t1_lift_mont⟩ := + triple_exists_ok_fc + (montgomery_multiply_by_constant_fc t (1353#i16) h_t_bnd h_c1353_bnd) + -- Also pull the legacy per-element fact for the bound and value. + have h_t1_spec := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.montgomery_multiply_by_constant_spec + t (1353#i16) h_c1353_bnd + obtain ⟨t1', h_t1_eq', h_t1_per⟩ := triple_exists_ok_fc h_t1_spec + have h_t1_same : t1 = t1' := by + have := h_t1_eq.symm.trans h_t1_eq' + cases this; rfl + subst h_t1_same + have h_t1_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t1.elements.val[ℓ]!).val.natAbs ≤ 3328 := by + intro ℓ hℓ; exact (h_t1_per ℓ hℓ).1 + have h_t1_modq : ∀ ℓ : Nat, ℓ < 16 → + ((t1.elements.val[ℓ]!).val * (2 ^ 16 : Int)) % 3329 + = ((t.elements.val[ℓ]!).val * (1353#i16 : Std.I16).val) % 3329 := by + intro ℓ hℓ; exact (h_t1_per ℓ hℓ).2 + -- (3) `a = acc.coefficients.set k t1`. + set a : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.coefficients.set k t1 with ha_def + -- (4) `index_mut_usize a k` → `(t2, set_back2) = (a[k], a.set k) = (t1, a.set k)`. + have h_a_len : a.length = 16 := by simp [ha_def, h_coef_len] + have h_a_k : a.val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + have h_idx_t2 : Aeneas.Std.Array.index_usize a k = .ok (a.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a k + (by rw [h_a_len]; exact hk_16) + have h_imt_t2 : Aeneas.Std.Array.index_mut_usize a k = .ok (t1, a.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t2]; rw [h_a_k]; rfl + -- (5) `index_usize error.coefficients k` → `t3 = error.coefs[k]`. + set t3 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + error.coefficients.val[k.val]! with ht3_def + have h_idx_t3 : Aeneas.Std.Array.index_usize error.coefficients k = .ok t3 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq error.coefficients k + (by rw [h_error_coef_len]; exact hk_16) + have h_t3_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t3.elements.val[ℓ]!).val.natAbs ≤ 29439 := by + intro ℓ hℓ; exact h_error_bnd k.val hk_16 ℓ hℓ + -- (6) `add t1 t3` → `t4`. Pre: |t1[ℓ] + t3[ℓ]| ≤ 32767. + -- |t1| ≤ 3328, |t3| ≤ 29439, so |t1 + t3| ≤ 3328 + 29439 = 32767 ✓. + have h_add_bnd : ∀ ℓ : Nat, ℓ < 16 → + ((t1.elements.val[ℓ]!).val + (t3.elements.val[ℓ]!).val : Int).natAbs ≤ 2^15 - 1 := by + intro ℓ hℓ + have hb_t1 := h_t1_bnd ℓ hℓ + have hb_t3 := h_t3_bnd ℓ hℓ + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2] + have h_abs_add : ((t1.elements.val[ℓ]!).val + + (t3.elements.val[ℓ]!).val : Int).natAbs + ≤ ((t1.elements.val[ℓ]!).val : Int).natAbs + + ((t3.elements.val[ℓ]!).val : Int).natAbs := + Int.natAbs_add_le _ _ + omega + obtain ⟨t4, h_t4_eq, h_t4_lift⟩ := + triple_exists_ok_fc (add_fc t1 t3 h_add_bnd) + -- Pull legacy per-element value: t4[ℓ].val = t1[ℓ].val + t3[ℓ].val. + have h_t4_spec := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.add_spec t1 t3 h_add_bnd + obtain ⟨t4', h_t4_eq', h_t4_per⟩ := triple_exists_ok_fc h_t4_spec + have h_t4_same : t4 = t4' := by + have := h_t4_eq.symm.trans h_t4_eq' + cases this; rfl + subst h_t4_same + have h_t4_val : ∀ ℓ : Nat, ℓ < 16 → + (t4.elements.val[ℓ]!).val + = (t1.elements.val[ℓ]!).val + (t3.elements.val[ℓ]!).val := by + intro ℓ hℓ; exact (h_t4_per ℓ hℓ).1 + have h_t4_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t4.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ; exact (h_t4_per ℓ hℓ).2 + -- (7) `a1 = a.set k t4`. + set a1 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + a.set k t4 with ha1_def + have h_a1_len : a1.length = 16 := by simp [ha1_def, h_a_len] + have h_a1_k : a1.val[k.val]! = t4 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq a k k.val t4 + ⟨rfl, by rw [h_a_len]; exact hk_16⟩ + -- (8) `index_mut_usize a1 k` → `(t5, set_back3) = (t4, a1.set k)`. + have h_idx_t5 : Aeneas.Std.Array.index_usize a1 k = .ok (a1.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a1 k + (by rw [h_a1_len]; exact hk_16) + have h_imt_t5 : Aeneas.Std.Array.index_mut_usize a1 k = .ok (t4, a1.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t5]; rw [h_a1_k]; rfl + -- (9) `barrett_reduce t4` → `t6`. Pre: |t4[ℓ]| ≤ 32767 ✓. + obtain ⟨t6, h_t6_eq, h_t6_post⟩ := + triple_exists_ok_fc (barrett_reduce_fc t4 h_t4_bnd) + obtain ⟨_h_t6_bnd, h_t6_lift⟩ := h_t6_post + -- (10) Compose acc' = `{ coefficients := a1.set k t6 }`. + set a2 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + a1.set k t6 with ha2_def + set acc' : AddStandardErrorReduceFC.Acc := { coefficients := a2 } with hacc'_def + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) error + { start := k, «end» := 16#usize } acc + = .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce_loop.body + simp only [libcrux_iot_ml_kem.vector.traits.to_standard_domain, h_RR_unfold] + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [Aeneas.Std.bind_tc_ok] + show (do + let (t', index_mut_back) ← + Aeneas.Std.Array.index_mut_usize acc.coefficients k + let t1' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_by_constant + t' (1353#i16) + let (t2', index_mut_back1) ← + Aeneas.Std.Array.index_mut_usize (index_mut_back t1') k + let t3' ← Aeneas.Std.Array.index_usize error.coefficients k + let t4' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.add t2' t3' + let (t5', index_mut_back2) ← + Aeneas.Std.Array.index_mut_usize (index_mut_back1 t4') k + let t6' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce t5' + .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + ({ coefficients := index_mut_back2 t6' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_imt_t]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_t2]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_t3]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t4_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_t5]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t6_eq] + rfl + apply triple_of_ok_fc h_body + show AddStandardErrorReduceFC.step_post self_init error k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold AddStandardErrorReduceFC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (AddStandardErrorReduceFC.inv self_init error s acc').holds + have h_inv_pure : + (∀ j : Nat, j < s.val → + lift_chunk (acc'.coefficients.val[j]!) + = Spec.chunk_add_standard_error_reduce_pure + (lift_chunk (self_init.coefficients.val[j]!)) + (lift_chunk (error.coefficients.val[j]!))) + ∧ (∀ j : Nat, s.val ≤ j → j < 16 → + acc'.coefficients.val[j]! = self_init.coefficients.val[j]!) := by + refine ⟨?_, ?_⟩ + · -- (a) j < s.val → FC equation at chunk j. + intro j hj + rw [hs_val] at hj + show lift_chunk ((((acc.coefficients.set k t1).set k t4).set k t6).val[j]!) = _ + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: chunk j unchanged through all three sets. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set1 : ((((acc.coefficients.set k t1).set k t4).set k t6).val[j]!) + = (((acc.coefficients.set k t1).set k t4).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + ((acc.coefficients.set k t1).set k t4) k j t6 h_ne + have h_set2 : (((acc.coefficients.set k t1).set k t4).val[j]!) + = ((acc.coefficients.set k t1).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + (acc.coefficients.set k t1) k j t4 h_ne + have h_set3 : ((acc.coefficients.set k t1).val[j]!) + = acc.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.coefficients k j t1 h_ne + rw [h_set1, h_set2, h_set3] + exact h_acc_done j hj_lt_k + · -- j = k.val: chunk j = t6, need lift_chunk t6 = chunk_add_standard_error_reduce_pure .... + subst hj_eq_k + have h_set_eq : ((((acc.coefficients.set k t1).set k t4).set k t6).val[k.val]!) + = t6 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq + ((acc.coefficients.set k t1).set k t4) k k.val t6 + ⟨rfl, by simp; exact hk_16⟩ + rw [h_set_eq] + -- Goal: lift_chunk t6 = chunk_add_standard_error_reduce_pure + -- (lift_chunk self_init[k]) (lift_chunk error[k]). + rw [h_t6_lift] + -- Now: chunk_barrett_reduce_pure (lift_chunk t4) = + -- chunk_add_standard_error_reduce_pure (lift_chunk self_init[k]) (lift_chunk error[k]). + show Spec.chunk_barrett_reduce_pure (lift_chunk t4) + = Spec.chunk_add_standard_error_reduce_pure + (lift_chunk (self_init.coefficients.val[k.val]!)) + (lift_chunk (error.coefficients.val[k.val]!)) + unfold Spec.chunk_barrett_reduce_pure Spec.chunk_add_standard_error_reduce_pure + apply Subtype.ext + change (List.range 16).map (fun i => + Spec.barrett_pure ((lift_chunk t4).val[i]!)) + = (List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk (self_init.coefficients.val[k.val]!)).val[ℓ]!) + (lift_fe_mont (1353#i16))) + ((lift_chunk (error.coefficients.val[k.val]!)).val[ℓ]!)) + apply List.ext_getElem + · simp + · intro ℓ hℓ1 _hℓ2 + have hℓ : ℓ < 16 := by + have : ℓ < (List.range 16).length := by simpa using hℓ1 + simpa using this + rw [List.getElem_map, List.getElem_range, + List.getElem_map, List.getElem_range] + -- LHS: Spec.barrett_pure ((lift_chunk t4).val[ℓ]!). + -- Step A: ((lift_chunk t4).val[ℓ]!) = lift_fe (t4.elements.val[ℓ]!). + have h_t4_elems_len : t4.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length t4 + have h_lc_t4 : ((lift_chunk t4).val[ℓ]!) = lift_fe (t4.elements.val[ℓ]!) := by + unfold lift_chunk + show ((t4.elements.val.map lift_fe)[ℓ]!) = _ + have h_len : (t4.elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_t4_elems_len + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos t4.elements.val ℓ (by rw [h_t4_elems_len]; exact hℓ)] + rw [h_lc_t4] + -- Step B: barrett_pure (lift_fe t4[ℓ]) = lift_fe t4[ℓ]. + rw [barrett_pure_lift_fe] + -- Step C: lift_fe t4[ℓ] = add_pure (lift_fe t1[ℓ]) (lift_fe t3[ℓ]). + have hv4 := h_t4_val ℓ hℓ + have h_lift_t4 : + lift_fe (t4.elements.val[ℓ]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe (t1.elements.val[ℓ]!)) (lift_fe (t3.elements.val[ℓ]!)) := + lift_fe_add_pure_eq _ _ _ hv4 + rw [h_lift_t4] + -- Step D: lift_fe t1[ℓ] = mul_pure (lift_fe t[ℓ]) (lift_fe_mont 1353). + have h_lift_t1 : + lift_fe (t1.elements.val[ℓ]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe (t.elements.val[ℓ]!)) (lift_fe_mont (1353#i16)) := + lift_fe_mont_mul_pure_eq _ _ _ (h_t1_modq ℓ hℓ) + rw [h_lift_t1] + -- Step E: rewrite t and t3 to self_init[k] and error[k] images. + have h_self_elems_len : + (self_init.coefficients.val[k.val]!).elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + have h_err_elems_len : + (error.coefficients.val[k.val]!).elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + have h_lc_self : ((lift_chunk (self_init.coefficients.val[k.val]!)).val[ℓ]!) + = lift_fe ((self_init.coefficients.val[k.val]!).elements.val[ℓ]!) := by + unfold lift_chunk + show (((self_init.coefficients.val[k.val]!).elements.val.map lift_fe)[ℓ]!) = _ + have h_len : + ((self_init.coefficients.val[k.val]!).elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_self_elems_len + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos (self_init.coefficients.val[k.val]!).elements.val ℓ + (by rw [h_self_elems_len]; exact hℓ)] + have h_lc_err : ((lift_chunk (error.coefficients.val[k.val]!)).val[ℓ]!) + = lift_fe ((error.coefficients.val[k.val]!).elements.val[ℓ]!) := by + unfold lift_chunk + show (((error.coefficients.val[k.val]!).elements.val.map lift_fe)[ℓ]!) = _ + have h_len : + ((error.coefficients.val[k.val]!).elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_err_elems_len + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos (error.coefficients.val[k.val]!).elements.val ℓ + (by rw [h_err_elems_len]; exact hℓ)] + rw [h_lc_self, h_lc_err] + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe (t.elements.val[ℓ]!)) (lift_fe_mont (1353#i16))) + (lift_fe (t3.elements.val[ℓ]!)) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe ((self_init.coefficients.val[k.val]!).elements.val[ℓ]!)) + (lift_fe_mont (1353#i16))) + (lift_fe ((error.coefficients.val[k.val]!).elements.val[ℓ]!)) + rw [show t = self_init.coefficients.val[k.val]! from h_t_eq, + show t3 = error.coefficients.val[k.val]! from ht3_def] + · -- (b) s.val ≤ j < 16 → acc'.coefs[j] = self_init.coefs[j]. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + show ((((acc.coefficients.set k t1).set k t4).set k t6).val[j]!) + = self_init.coefficients.val[j]! + have h_set1 : ((((acc.coefficients.set k t1).set k t4).set k t6).val[j]!) + = (((acc.coefficients.set k t1).set k t4).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + ((acc.coefficients.set k t1).set k t4) k j t6 h_ne + have h_set2 : (((acc.coefficients.set k t1).set k t4).val[j]!) + = ((acc.coefficients.set k t1).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + (acc.coefficients.set k t1) k j t4 h_ne + have h_set3 : ((acc.coefficients.set k t1).val[j]!) + = acc.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.coefficients k j t1 h_ne + rw [h_set1, h_set2, h_set3] + exact h_acc_undone j h_ge' hj_lt + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) error + { start := k, «end» := 16#usize } acc + = .ok (ControlFlow.done acc) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_fc h_body + show AddStandardErrorReduceFC.step_post self_init error k (.done acc) + unfold AddStandardErrorReduceFC.step_post + show (AddStandardErrorReduceFC.inv self_init error 16#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (acc.coefficients.val[j]!) + = Spec.chunk_add_standard_error_reduce_pure + (lift_chunk (self_init.coefficients.val[j]!)) + (lift_chunk (error.coefficients.val[j]!))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + acc.coefficients.val[j]! = self_init.coefficients.val[j]!) := by + refine ⟨?_, ?_⟩ + · intro j hj; rw [h16] at hj + apply h_acc_done j; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +/-- L6.5 — `add_standard_error_reduce`: `self · R² + error` then barrett. + Used to take an inverse-NTT result back to "standard domain". + + **Preconditions** (load-bearing, beyond the locked True-pre form): + - `h_self_bnd`: per-lane `|self[k][ℓ]| ≤ 32767` (consumed by `mont_mul`'s + legacy precondition; the impl's later `add` then uses `|t1| ≤ 3328` from + mont's output bound). + - `h_error_bnd`: per-lane `|error[k][ℓ]| ≤ 29439` (drives `add`'s overflow + bound: |t1| + |error| ≤ 3328 + 29439 = 32767). -/ +@[spec] +theorem add_standard_error_reduce_fc + (self error : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_self_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((self.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767) + (h_error_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((error.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 29439) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + (vectortraitsOperationsInst := portable_ops_inst) self error + ⦃ ⇓ p => ⌜ lift_poly p + = Spec.add_standard_error_reduce_pure (lift_poly self) (lift_poly error) ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce + -- Resolve `VECTORS_IN_RING_ELEMENT = .ok 16#usize`. + have h_vre : libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + = .ok (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + rfl + rw [h_vre]; simp only [Aeneas.Std.bind_tc_ok] + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, self1) => + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_standard_error_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) error iter1 self1) + (β := AddStandardErrorReduceFC.Acc) + self + 0#usize 16#usize + (AddStandardErrorReduceFC.inv self error) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_⟩ + · intro j hj; exact absurd hj (Nat.not_lt_zero j) + · intro _ _ _; trivial) + ?_) + · -- Post entailment: at k=16, the invariant gives all 16 FC equations. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (AddStandardErrorReduceFC.inv self error 16#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (r.coefficients.val[j]!) + = Spec.chunk_add_standard_error_reduce_pure + (lift_chunk (self.coefficients.val[j]!)) + (lift_chunk (error.coefficients.val[j]!))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + r.coefficients.val[j]! = self.coefficients.val[j]!) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + AddStandardErrorReduceFC.inv] using h_inv_holds + obtain ⟨h_done, _h_undone⟩ := h_inv + -- Build chunks_arr matching the Spec definition, then apply + -- flatten_chunks_eq_lift_poly_fc. + unfold Spec.add_standard_error_reduce_pure + set chunks_arr : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at (lift_poly self) k) + (Spec.chunk_at (lift_poly error) k))) + (by simp) with hchunks_def + have h_chunks_len : chunks_arr.val.length = 16 := by + show ((List.range 16).map _).length = 16 + simp + have h_chunks_get : ∀ k : Nat, (hk : k < 16) → + chunks_arr.val[k]'(by rw [h_chunks_len]; exact hk) + = lift_chunk (r.coefficients.val[k]!) := by + intro k hk + show ((List.range 16).map (fun k => + Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at (lift_poly self) k) + (Spec.chunk_at (lift_poly error) k)))[k]'_ = _ + rw [List.getElem_map, List.getElem_range] + rw [chunk_at_lift_poly_fc self k hk, chunk_at_lift_poly_fc error k hk] + exact (h_done k hk).symm + have h_final := flatten_chunks_eq_lift_poly_fc r chunks_arr h_chunks_len h_chunks_get + exact h_final.symm + · -- Step entailment: per-iteration step lemma. + intro acc k _h_ge h_le hinv + have h_step := + add_standard_error_reduce_step_lemma_fc self error h_self_bnd h_error_bnd acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : AddStandardErrorReduceFC.step_post self error k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [AddStandardErrorReduceFC.step_post] using hP + · have hP : AddStandardErrorReduceFC.step_post self error k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [AddStandardErrorReduceFC.step_post] using hP + +/-! ### L6.6.A — Loop scaffolding for `add_message_error_reduce_fc`. + + Unlike L6.4/L6.5 (single-poly Acc), this loop carries a 2-tuple + `(result_acc, scratch)` because the impl reads/writes both per + iteration. The FC equation lives entirely on `result_acc`; `scratch` + is unconstrained at iteration boundaries (`scratch_15 = self[15] + + message[15]` at exit, but the FC theorem only projects `p.1`). -/ + +namespace AddMessageErrorReduceFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Step-local accumulator: `(result, scratch)`. -/ +abbrev Acc := + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC loop invariant for `add_message_error_reduce_fc`. + * (a) Chunks `j < k`: FC equation on `acc.1` against + `chunk_add_message_error_reduce_pure (self_init[j]) (message_init[j]) (result_init[j])`. + * (b) Chunks `k ≤ j < 16`: `acc.1[j] = result_init[j]` (unchanged). + The `scratch` component `acc.2` is unconstrained. -/ +def inv + (self_init message_init result_init : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + (∀ j : Nat, j < k.val → + lift_chunk (acc.1.coefficients.val[j]!) + = Spec.chunk_add_message_error_reduce_pure + (lift_chunk (self_init.coefficients.val[j]!)) + (lift_chunk (message_init.coefficients.val[j]!)) + (lift_chunk (result_init.coefficients.val[j]!))) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.1.coefficients.val[j]! = result_init.coefficients.val[j]!)) + +/-- Step-post for `loop_range_spec_usize`. -/ +def step_post + (self_init message_init result_init : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv self_init message_init result_init iter'.start acc').holds + | .done y => (inv self_init message_init result_init 16#usize y).holds + +end AddMessageErrorReduceFC + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for `add_message_error_reduce`. Given a + valid loop state `((result_acc, _scratch_acc), k)` with `k.val < 16`, + applies the `mont_mul(1441) + add(self+message) + barrett` chain to + chunk `k.val` of `result_acc`, recording the FC equation + `lift_chunk acc'.1[k.val] = + chunk_add_message_error_reduce_pure + (lift_chunk self_init[k.val]) + (lift_chunk message_init[k.val]) + (lift_chunk result_init[k.val])` + while preserving chunks `j ≠ k.val`. The scratch slot is overwritten + each iteration with `self[k] + message[k]` and remains unconstrained + by the invariant. -/ +theorem add_message_error_reduce_step_lemma_fc + (self_init message_init result_init : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_result_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((result_init.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767) + (h_sum_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + (((self_init.coefficients.val[chunk]!).elements.val[ℓ]!).val + + ((message_init.coefficients.val[chunk]!).elements.val[ℓ]!).val + : Int).natAbs ≤ 29439) + (acc : AddMessageErrorReduceFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (AddMessageErrorReduceFC.inv self_init message_init result_init k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_message_error_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) self_init message_init + { start := k, «end» := 16#usize } acc.1 acc.2 + ⦃ ⇓ r => ⌜ AddMessageErrorReduceFC.step_post self_init message_init result_init k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.1.coefficients.length = 16 := + Std.Array.length_eq _ + have h_self_coef_len : self_init.coefficients.length = 16 := + Std.Array.length_eq _ + have h_msg_coef_len : message_init.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_acc_done, h_acc_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_message_error_reduce_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some i = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + -- (1) `index_mut_usize acc.1.coefficients k` → `(t, set_back) = + -- (acc.1.coefs[k], acc.1.coefs.set k)`. + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.1.coefficients.val[k.val]! with ht_def + have h_idx_t : Aeneas.Std.Array.index_usize acc.1.coefficients k = .ok t := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.1.coefficients k + (by rw [h_coef_len]; exact hk_16) + have h_imt_t : Aeneas.Std.Array.index_mut_usize acc.1.coefficients k + = .ok (t, acc.1.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t]; rfl + -- (1a) `t = result_init.coefficients[k]` (via h_acc_undone at j=k). + have h_t_eq : t = result_init.coefficients.val[k.val]! := by + show acc.1.coefficients.val[k.val]! = result_init.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_result_bnd k.val hk_16 ℓ hℓ + -- (2) `mont_mul(t, 1441#i16)` → `t1`. Pre: |1441| ≤ 1664 ✓; |t| ≤ 32767 ✓. + have h_c1441_bnd : ((1441#i16 : Std.I16).val.natAbs) ≤ 1664 := by decide + obtain ⟨t1, h_t1_eq, h_t1_lift_mont⟩ := + triple_exists_ok_fc + (montgomery_multiply_by_constant_fc t (1441#i16) h_t_bnd h_c1441_bnd) + have h_t1_spec := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.montgomery_multiply_by_constant_spec + t (1441#i16) h_c1441_bnd + obtain ⟨t1', h_t1_eq', h_t1_per⟩ := triple_exists_ok_fc h_t1_spec + have h_t1_same : t1 = t1' := by + have := h_t1_eq.symm.trans h_t1_eq' + cases this; rfl + subst h_t1_same + have h_t1_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t1.elements.val[ℓ]!).val.natAbs ≤ 3328 := by + intro ℓ hℓ; exact (h_t1_per ℓ hℓ).1 + have h_t1_modq : ∀ ℓ : Nat, ℓ < 16 → + ((t1.elements.val[ℓ]!).val * (2 ^ 16 : Int)) % 3329 + = ((t.elements.val[ℓ]!).val * (1441#i16 : Std.I16).val) % 3329 := by + intro ℓ hℓ; exact (h_t1_per ℓ hℓ).2 + -- (3) `scratch1 = self_init.coefs[k]`. + set scratch1 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + self_init.coefficients.val[k.val]! with hscratch1_def + have h_idx_scratch1 : Aeneas.Std.Array.index_usize self_init.coefficients k + = .ok scratch1 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq self_init.coefficients k + (by rw [h_self_coef_len]; exact hk_16) + -- (4) `t2 = message_init.coefs[k]`. + set t2 : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + message_init.coefficients.val[k.val]! with ht2_def + have h_idx_t2 : Aeneas.Std.Array.index_usize message_init.coefficients k = .ok t2 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq message_init.coefficients k + (by rw [h_msg_coef_len]; exact hk_16) + -- (5) `add scratch1 t2 = scratch2`. Pre: |scratch1 + t2| ≤ 32767; + -- from `h_sum_bnd` ≤ 29439 ≤ 32767. + have h_scratch2_pre : ∀ ℓ : Nat, ℓ < 16 → + ((scratch1.elements.val[ℓ]!).val + (t2.elements.val[ℓ]!).val : Int).natAbs + ≤ 2^15 - 1 := by + intro ℓ hℓ + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2] + have h_sum := h_sum_bnd k.val hk_16 ℓ hℓ + show (((self_init.coefficients.val[k.val]!).elements.val[ℓ]!).val + + ((message_init.coefficients.val[k.val]!).elements.val[ℓ]!).val + : Int).natAbs ≤ 32767 + omega + obtain ⟨scratch2, h_scratch2_eq, h_scratch2_lift⟩ := + triple_exists_ok_fc (add_fc scratch1 t2 h_scratch2_pre) + have h_scratch2_spec := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.add_spec scratch1 t2 h_scratch2_pre + obtain ⟨scratch2', h_scratch2_eq', h_scratch2_per⟩ := triple_exists_ok_fc h_scratch2_spec + have h_scratch2_same : scratch2 = scratch2' := by + have := h_scratch2_eq.symm.trans h_scratch2_eq' + cases this; rfl + subst h_scratch2_same + have h_scratch2_val : ∀ ℓ : Nat, ℓ < 16 → + (scratch2.elements.val[ℓ]!).val + = (scratch1.elements.val[ℓ]!).val + (t2.elements.val[ℓ]!).val := by + intro ℓ hℓ; exact (h_scratch2_per ℓ hℓ).1 + have h_scratch2_bnd : ∀ ℓ : Nat, ℓ < 16 → + (scratch2.elements.val[ℓ]!).val.natAbs ≤ 29439 := by + intro ℓ hℓ + rw [h_scratch2_val ℓ hℓ] + exact h_sum_bnd k.val hk_16 ℓ hℓ + -- (6) `a = acc.1.coefficients.set k t1`. + set a : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.1.coefficients.set k t1 with ha_def + have h_a_len : a.length = 16 := by simp [ha_def, h_coef_len] + have h_a_k : a.val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.1.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + -- (7) `index_mut_usize a k` → `(t3, _) = (t1, a.set k)`. + have h_idx_t3 : Aeneas.Std.Array.index_usize a k = .ok (a.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a k + (by rw [h_a_len]; exact hk_16) + have h_imt_t3 : Aeneas.Std.Array.index_mut_usize a k = .ok (t1, a.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t3]; rw [h_a_k]; rfl + -- (8) `add t1 scratch2 = t4`. Pre: |t1 + scratch2| ≤ 32767; + -- |t1| ≤ 3328, |scratch2| ≤ 29439 ⇒ |t1 + scratch2| ≤ 32767 ✓. + have h_t4_pre : ∀ ℓ : Nat, ℓ < 16 → + ((t1.elements.val[ℓ]!).val + (scratch2.elements.val[ℓ]!).val : Int).natAbs + ≤ 2^15 - 1 := by + intro ℓ hℓ + have hb_t1 := h_t1_bnd ℓ hℓ + have hb_s2 := h_scratch2_bnd ℓ hℓ + have h_p2 : (2 : Nat)^15 - 1 = 32767 := by decide + rw [h_p2] + have h_abs_add : ((t1.elements.val[ℓ]!).val + + (scratch2.elements.val[ℓ]!).val : Int).natAbs + ≤ ((t1.elements.val[ℓ]!).val : Int).natAbs + + ((scratch2.elements.val[ℓ]!).val : Int).natAbs := + Int.natAbs_add_le _ _ + omega + obtain ⟨t4, h_t4_eq, h_t4_lift⟩ := + triple_exists_ok_fc (add_fc t1 scratch2 h_t4_pre) + have h_t4_spec := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.add_spec t1 scratch2 h_t4_pre + obtain ⟨t4', h_t4_eq', h_t4_per⟩ := triple_exists_ok_fc h_t4_spec + have h_t4_same : t4 = t4' := by + have := h_t4_eq.symm.trans h_t4_eq' + cases this; rfl + subst h_t4_same + have h_t4_val : ∀ ℓ : Nat, ℓ < 16 → + (t4.elements.val[ℓ]!).val + = (t1.elements.val[ℓ]!).val + (scratch2.elements.val[ℓ]!).val := by + intro ℓ hℓ; exact (h_t4_per ℓ hℓ).1 + have h_t4_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t4.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ; exact (h_t4_per ℓ hℓ).2 + -- (9) `a1 = a.set k t4`. + set a1 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + a.set k t4 with ha1_def + have h_a1_len : a1.length = 16 := by simp [ha1_def, h_a_len] + have h_a1_k : a1.val[k.val]! = t4 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq a k k.val t4 + ⟨rfl, by rw [h_a_len]; exact hk_16⟩ + -- (10) `index_mut_usize a1 k` → `(t5, _) = (t4, a1.set k)`. + have h_idx_t5 : Aeneas.Std.Array.index_usize a1 k = .ok (a1.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq a1 k + (by rw [h_a1_len]; exact hk_16) + have h_imt_t5 : Aeneas.Std.Array.index_mut_usize a1 k = .ok (t4, a1.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t5]; rw [h_a1_k]; rfl + -- (11) `barrett_reduce t4 = t6`. Pre: |t4[ℓ]| ≤ 32767 ✓. + obtain ⟨t6, h_t6_eq, h_t6_post⟩ := + triple_exists_ok_fc (barrett_reduce_fc t4 h_t4_bnd) + obtain ⟨_h_t6_bnd, h_t6_lift⟩ := h_t6_post + -- (12) Compose acc'.1 = `{ coefficients := a1.set k t6 }`, acc'.2 = scratch2. + set a2 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + a1.set k t6 with ha2_def + set acc1' : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + { coefficients := a2 } with hacc1'_def + set acc' : AddMessageErrorReduceFC.Acc := (acc1', scratch2) with hacc'_def + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_message_error_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) self_init message_init + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_message_error_reduce_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [Aeneas.Std.bind_tc_ok] + show (do + let (t', index_mut_back) ← + Aeneas.Std.Array.index_mut_usize acc.1.coefficients k + let t1' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_by_constant + t' (1441#i16) + let scratch1' ← Aeneas.Std.Array.index_usize self_init.coefficients k + let t2' ← Aeneas.Std.Array.index_usize message_init.coefficients k + let scratch2' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.add scratch1' t2' + let (t3', index_mut_back1) ← + Aeneas.Std.Array.index_mut_usize (index_mut_back t1') k + let t4' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.add t3' scratch2' + let (t5', index_mut_back2) ← + Aeneas.Std.Array.index_mut_usize (index_mut_back1 t4') k + let t6' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce t5' + .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + ({ coefficients := index_mut_back2 t6' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector), + scratch2'))) + = _ + rw [h_imt_t]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_scratch1]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_t2]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_scratch2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_t3]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t4_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_t5]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t6_eq] + rfl + apply triple_of_ok_fc h_body + show AddMessageErrorReduceFC.step_post self_init message_init result_init k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold AddMessageErrorReduceFC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (AddMessageErrorReduceFC.inv self_init message_init result_init s acc').holds + have h_inv_pure : + (∀ j : Nat, j < s.val → + lift_chunk (acc'.1.coefficients.val[j]!) + = Spec.chunk_add_message_error_reduce_pure + (lift_chunk (self_init.coefficients.val[j]!)) + (lift_chunk (message_init.coefficients.val[j]!)) + (lift_chunk (result_init.coefficients.val[j]!))) + ∧ (∀ j : Nat, s.val ≤ j → j < 16 → + acc'.1.coefficients.val[j]! = result_init.coefficients.val[j]!) := by + refine ⟨?_, ?_⟩ + · -- (a) j < s.val → FC equation at chunk j. + intro j hj + rw [hs_val] at hj + show lift_chunk ((((acc.1.coefficients.set k t1).set k t4).set k t6).val[j]!) = _ + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: chunk j unchanged through all three sets. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set1 : ((((acc.1.coefficients.set k t1).set k t4).set k t6).val[j]!) + = (((acc.1.coefficients.set k t1).set k t4).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + ((acc.1.coefficients.set k t1).set k t4) k j t6 h_ne + have h_set2 : (((acc.1.coefficients.set k t1).set k t4).val[j]!) + = ((acc.1.coefficients.set k t1).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + (acc.1.coefficients.set k t1) k j t4 h_ne + have h_set3 : ((acc.1.coefficients.set k t1).val[j]!) + = acc.1.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.1.coefficients k j t1 h_ne + rw [h_set1, h_set2, h_set3] + exact h_acc_done j hj_lt_k + · -- j = k.val: chunk j = t6. Need: + -- lift_chunk t6 = chunk_add_message_error_reduce_pure + -- (lift_chunk self_init[k]) (lift_chunk message_init[k]) + -- (lift_chunk result_init[k]). + subst hj_eq_k + have h_set_eq : ((((acc.1.coefficients.set k t1).set k t4).set k t6).val[k.val]!) + = t6 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq + ((acc.1.coefficients.set k t1).set k t4) k k.val t6 + ⟨rfl, by simp; exact hk_16⟩ + rw [h_set_eq] + rw [h_t6_lift] + show Spec.chunk_barrett_reduce_pure (lift_chunk t4) + = Spec.chunk_add_message_error_reduce_pure + (lift_chunk (self_init.coefficients.val[k.val]!)) + (lift_chunk (message_init.coefficients.val[k.val]!)) + (lift_chunk (result_init.coefficients.val[k.val]!)) + unfold Spec.chunk_barrett_reduce_pure Spec.chunk_add_message_error_reduce_pure + apply Subtype.ext + change (List.range 16).map (fun i => + Spec.barrett_pure ((lift_chunk t4).val[i]!)) + = (List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((lift_chunk (result_init.coefficients.val[k.val]!)).val[ℓ]!) + (lift_fe_mont (1441#i16))) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((lift_chunk (self_init.coefficients.val[k.val]!)).val[ℓ]!) + ((lift_chunk (message_init.coefficients.val[k.val]!)).val[ℓ]!))) + apply List.ext_getElem + · simp + · intro ℓ hℓ1 _hℓ2 + have hℓ : ℓ < 16 := by + have : ℓ < (List.range 16).length := by simpa using hℓ1 + simpa using this + rw [List.getElem_map, List.getElem_range, + List.getElem_map, List.getElem_range] + -- LHS: Spec.barrett_pure ((lift_chunk t4).val[ℓ]!). + -- Step A: ((lift_chunk t4).val[ℓ]!) = lift_fe (t4.elements.val[ℓ]!). + have h_t4_elems_len : t4.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length t4 + have h_lc_t4 : ((lift_chunk t4).val[ℓ]!) = lift_fe (t4.elements.val[ℓ]!) := by + unfold lift_chunk + show ((t4.elements.val.map lift_fe)[ℓ]!) = _ + have h_len : (t4.elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_t4_elems_len + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos t4.elements.val ℓ (by rw [h_t4_elems_len]; exact hℓ)] + rw [h_lc_t4] + -- Step B: barrett_pure (lift_fe t4[ℓ]) = lift_fe t4[ℓ]. + rw [barrett_pure_lift_fe] + -- Step C: lift_fe t4[ℓ] = add_pure (lift_fe t1[ℓ]) (lift_fe scratch2[ℓ]). + have hv4 := h_t4_val ℓ hℓ + have h_lift_t4 : + lift_fe (t4.elements.val[ℓ]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe (t1.elements.val[ℓ]!)) + (lift_fe (scratch2.elements.val[ℓ]!)) := + lift_fe_add_pure_eq _ _ _ hv4 + rw [h_lift_t4] + -- Step D: lift_fe t1[ℓ] = mul_pure (lift_fe t[ℓ]) (lift_fe_mont 1441). + have h_lift_t1 : + lift_fe (t1.elements.val[ℓ]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe (t.elements.val[ℓ]!)) (lift_fe_mont (1441#i16)) := + lift_fe_mont_mul_pure_eq _ _ _ (h_t1_modq ℓ hℓ) + rw [h_lift_t1] + -- Step E: lift_fe scratch2[ℓ] = add_pure (lift_fe scratch1[ℓ]) (lift_fe t2[ℓ]). + have hv_s2 := h_scratch2_val ℓ hℓ + have h_lift_s2 : + lift_fe (scratch2.elements.val[ℓ]!) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe (scratch1.elements.val[ℓ]!)) + (lift_fe (t2.elements.val[ℓ]!)) := + lift_fe_add_pure_eq _ _ _ hv_s2 + rw [h_lift_s2] + -- Step F: rewrite t, scratch1, t2 to result_init[k], self_init[k], + -- message_init[k] images via lift_chunk projection. + have h_result_elems_len : + (result_init.coefficients.val[k.val]!).elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + have h_self_elems_len : + (self_init.coefficients.val[k.val]!).elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + have h_msg_elems_len : + (message_init.coefficients.val[k.val]!).elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + have h_lc_result : ((lift_chunk (result_init.coefficients.val[k.val]!)).val[ℓ]!) + = lift_fe ((result_init.coefficients.val[k.val]!).elements.val[ℓ]!) := by + unfold lift_chunk + show (((result_init.coefficients.val[k.val]!).elements.val.map lift_fe)[ℓ]!) = _ + have h_len : + ((result_init.coefficients.val[k.val]!).elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_result_elems_len + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos (result_init.coefficients.val[k.val]!).elements.val ℓ + (by rw [h_result_elems_len]; exact hℓ)] + have h_lc_self : ((lift_chunk (self_init.coefficients.val[k.val]!)).val[ℓ]!) + = lift_fe ((self_init.coefficients.val[k.val]!).elements.val[ℓ]!) := by + unfold lift_chunk + show (((self_init.coefficients.val[k.val]!).elements.val.map lift_fe)[ℓ]!) = _ + have h_len : + ((self_init.coefficients.val[k.val]!).elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_self_elems_len + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos (self_init.coefficients.val[k.val]!).elements.val ℓ + (by rw [h_self_elems_len]; exact hℓ)] + have h_lc_msg : ((lift_chunk (message_init.coefficients.val[k.val]!)).val[ℓ]!) + = lift_fe ((message_init.coefficients.val[k.val]!).elements.val[ℓ]!) := by + unfold lift_chunk + show (((message_init.coefficients.val[k.val]!).elements.val.map lift_fe)[ℓ]!) = _ + have h_len : + ((message_init.coefficients.val[k.val]!).elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_msg_elems_len + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map] + rw [getElem!_pos (message_init.coefficients.val[k.val]!).elements.val ℓ + (by rw [h_msg_elems_len]; exact hℓ)] + rw [h_lc_result, h_lc_self, h_lc_msg] + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe (t.elements.val[ℓ]!)) (lift_fe_mont (1441#i16))) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe (scratch1.elements.val[ℓ]!)) + (lift_fe (t2.elements.val[ℓ]!))) + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe ((result_init.coefficients.val[k.val]!).elements.val[ℓ]!)) + (lift_fe_mont (1441#i16))) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe ((self_init.coefficients.val[k.val]!).elements.val[ℓ]!)) + (lift_fe ((message_init.coefficients.val[k.val]!).elements.val[ℓ]!))) + rw [show t = result_init.coefficients.val[k.val]! from h_t_eq, + show scratch1 = self_init.coefficients.val[k.val]! from hscratch1_def, + show t2 = message_init.coefficients.val[k.val]! from ht2_def] + · -- (b) s.val ≤ j < 16 → acc'.1.coefs[j] = result_init.coefs[j]. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + show ((((acc.1.coefficients.set k t1).set k t4).set k t6).val[j]!) + = result_init.coefficients.val[j]! + have h_set1 : ((((acc.1.coefficients.set k t1).set k t4).set k t6).val[j]!) + = (((acc.1.coefficients.set k t1).set k t4).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + ((acc.1.coefficients.set k t1).set k t4) k j t6 h_ne + have h_set2 : (((acc.1.coefficients.set k t1).set k t4).val[j]!) + = ((acc.1.coefficients.set k t1).val[j]!) := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne + (acc.1.coefficients.set k t1) k j t4 h_ne + have h_set3 : ((acc.1.coefficients.set k t1).val[j]!) + = acc.1.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.1.coefficients k j t1 h_ne + rw [h_set1, h_set2, h_set3] + exact h_acc_undone j h_ge' hj_lt + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_message_error_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) self_init message_init + { start := k, «end» := 16#usize } acc.1 acc.2 + = .ok (ControlFlow.done acc) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_message_error_reduce_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_fc h_body + show AddMessageErrorReduceFC.step_post self_init message_init result_init k (.done acc) + unfold AddMessageErrorReduceFC.step_post + show (AddMessageErrorReduceFC.inv self_init message_init result_init 16#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (acc.1.coefficients.val[j]!) + = Spec.chunk_add_message_error_reduce_pure + (lift_chunk (self_init.coefficients.val[j]!)) + (lift_chunk (message_init.coefficients.val[j]!)) + (lift_chunk (result_init.coefficients.val[j]!))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + acc.1.coefficients.val[j]! = result_init.coefficients.val[j]!) := by + refine ⟨?_, ?_⟩ + · intro j hj; rw [h16] at hj + apply h_acc_done j; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +/-- L6.6 — `add_message_error_reduce`: combines `self · (R/128)` with + `result + message` then barrett. Returns `(re, scratch)` tuple; we + project on `re`. + + **Preconditions** (load-bearing, beyond the locked True-pre form): + - `h_result_bnd`: per-lane `|result[k][ℓ]| ≤ 32767` (consumed by + `mont_mul`'s legacy precondition). + - `h_sum_bnd`: per-lane `|self[k][ℓ] + message[k][ℓ]| ≤ 29439` (drives + both `add`s: first `self + message` directly; then `t1 + scratch2` + with |t1| ≤ 3328 + |scratch2| ≤ 29439 = 32767). -/ +@[spec] +theorem add_message_error_reduce_fc + (self message result : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (scratch : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_result_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((result.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767) + (h_sum_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + (((self.coefficients.val[chunk]!).elements.val[ℓ]!).val + + ((message.coefficients.val[chunk]!).elements.val[ℓ]!).val + : Int).natAbs ≤ 29439) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_message_error_reduce + (vectortraitsOperationsInst := portable_ops_inst) self message result scratch + ⦃ ⇓ p => ⌜ lift_poly p.1 + = Spec.add_message_error_reduce_pure + (lift_poly self) (lift_poly message) (lift_poly result) ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_message_error_reduce + have h_vre : libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + = .ok (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + rfl + rw [h_vre]; simp only [Aeneas.Std.bind_tc_ok] + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_message_error_reduce_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.add_message_error_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) self message + iter1 acc1.1 acc1.2) + (β := AddMessageErrorReduceFC.Acc) + (result, scratch) + 0#usize 16#usize + (AddMessageErrorReduceFC.inv self message result) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_⟩ + · intro j hj; exact absurd hj (Nat.not_lt_zero j) + · intro _ _ _; trivial) + ?_) + · -- Post entailment. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (AddMessageErrorReduceFC.inv self message result 16#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (r.1.coefficients.val[j]!) + = Spec.chunk_add_message_error_reduce_pure + (lift_chunk (self.coefficients.val[j]!)) + (lift_chunk (message.coefficients.val[j]!)) + (lift_chunk (result.coefficients.val[j]!))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + r.1.coefficients.val[j]! = result.coefficients.val[j]!) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + AddMessageErrorReduceFC.inv] using h_inv_holds + obtain ⟨h_done, _h_undone⟩ := h_inv + unfold Spec.add_message_error_reduce_pure + set chunks_arr : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_add_message_error_reduce_pure + (Spec.chunk_at (lift_poly self) k) + (Spec.chunk_at (lift_poly message) k) + (Spec.chunk_at (lift_poly result) k))) + (by simp) with hchunks_def + have h_chunks_len : chunks_arr.val.length = 16 := by + show ((List.range 16).map _).length = 16 + simp + have h_chunks_get : ∀ k : Nat, (hk : k < 16) → + chunks_arr.val[k]'(by rw [h_chunks_len]; exact hk) + = lift_chunk (r.1.coefficients.val[k]!) := by + intro k hk + show ((List.range 16).map (fun k => + Spec.chunk_add_message_error_reduce_pure + (Spec.chunk_at (lift_poly self) k) + (Spec.chunk_at (lift_poly message) k) + (Spec.chunk_at (lift_poly result) k)))[k]'_ = _ + rw [List.getElem_map, List.getElem_range] + rw [chunk_at_lift_poly_fc self k hk, chunk_at_lift_poly_fc message k hk, + chunk_at_lift_poly_fc result k hk] + exact (h_done k hk).symm + have h_final := flatten_chunks_eq_lift_poly_fc r.1 chunks_arr h_chunks_len h_chunks_get + exact h_final.symm + · -- Step entailment. + intro acc k _h_ge h_le hinv + have h_step := + add_message_error_reduce_step_lemma_fc self message result + h_result_bnd h_sum_bnd acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : AddMessageErrorReduceFC.step_post self message result k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [AddMessageErrorReduceFC.step_post] using hP + · have hP : AddMessageErrorReduceFC.step_post self message result k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [AddMessageErrorReduceFC.step_post] using hP + +namespace SubtractReduceFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Step-local accumulator (the mutable `out` poly). -/ +abbrev Acc := + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC loop invariant for `poly_reducing_from_i32_array_fc` (per-lane). + * (a) Chunks `j < k`, lanes `ℓ < 16`: per-lane mont equality + `lift_fe_mont acc[j][ℓ] = Spec.mont_reduce_pure (lift_fe_int a[16*j+ℓ])`. + * (b) Chunks `k ≤ j < 16`: `acc[j] = out_init[j]` (unchanged). + * (c) Chunks `j < k`, lanes `ℓ < 16`: per-lane I16 bound + `|acc[j][ℓ]| ≤ 4993`, propagated from + `reducing_from_i32_array_fc`'s strengthened POST. Used by L6.5 + (`add_standard_error_reduce_fc`) via `4993 ≤ 32767`. + Using a per-lane invariant avoids reasoning about sub-slice + equality at the chunk level (`lift_chunk_mont` vs sub-slice + `Spec.chunk_reducing_from_i32_array_pure`). -/ +def inv (a : Slice Std.I32) (out_init : Acc) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + (∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + lift_fe_mont ((acc.coefficients.val[j]!).elements.val[ℓ]!) + = Spec.mont_reduce_pure (lift_fe_int (a.val[16 * j + ℓ]!).val)) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.coefficients.val[j]! = out_init.coefficients.val[j]!) + ∧ (∀ j : Nat, j < k.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 4993)) + +/-- Step-post for `loop_range_spec_usize`. -/ +def step_post (a : Slice Std.I32) (out_init : Acc) (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv a out_init iter'.start acc').holds + | .done y => (inv a out_init 16#usize y).holds + +end SubtractReduceFC + +/-- Sub-slice extraction `.ok`-form. Given a slice `a` and a range + `[start, end] ⊆ [0, a.length]`, the `core_models`-level slice index + succeeds with `s.val = a.val.slice start.val end.val` and + `s.val.length = end.val - start.val`. -/ +theorem slice_index_range_ok_eq_fc + {T : Type} (a : Slice T) (r : CoreModels.core.ops.range.Range Std.Usize) + (h0 : r.start.val ≤ r.end.val) (h1 : r.end.val ≤ a.val.length) : + ∃ s : Slice T, + core.Slice.Insts.CoreOpsIndexIndex.index + (core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice T) + a r = .ok s + ∧ s.val = a.val.slice r.start.val r.end.val + ∧ s.val.length = r.end.val - r.start.val := by + have hT := libcrux_iot_ml_kem.Util.SliceSpecs.core_models_Slice_Insts_index_RangeUsize_spec + (T := T) a r h0 h1 + obtain ⟨s, h_eq, h_post⟩ := triple_exists_ok_fc hT + exact ⟨s, h_eq, h_post.1, h_post.2⟩ + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for `poly.reducing_from_i32_array`. + Given a valid loop state `(acc, k)` with `k.val < 16`, the impl: + (i) extracts the sub-slice `s = a[16k..16(k+1)]`, + (ii) overwrites chunk `k` of `acc` with + `vector.reducing_from_i32_array s _` (the prior chunk value is + discarded by the vector op), + yielding the per-lane mont equality + `lift_fe_mont acc'[k][ℓ] = Spec.mont_reduce_pure (lift_fe_int a[16k+ℓ])` + for all ℓ, while preserving chunks `j ≠ k`. -/ +theorem poly_reducing_from_i32_array_step_lemma_fc + (a : Slice Std.I32) (out_init : SubtractReduceFC.Acc) + (hlen : a.length = 256) + (hbound : ∀ i : Nat, i < 256 → (a.val[i]!).val.natAbs ≤ 2^16 * 3328) + (acc : SubtractReduceFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (SubtractReduceFC.inv a out_init k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array_loop.body + (vectortraitsOperationsInst := portable_ops_inst) a + { start := k, «end» := 16#usize } acc + ⦃ ⇓ r => ⌜ SubtractReduceFC.step_post a out_init k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_a_len : a.val.length = 256 := hlen + have h_coef_len : acc.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_acc_done, h_acc_undone, h_acc_bnd⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some i = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s_iter, hs_val, h_iter_some⟩ := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + -- (1) i1 := k * 16. + have hi1_max : k.val * (16#usize : Std.Usize).val ≤ Std.Usize.max := by + have hk_15 : k.val ≤ 15 := by omega + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hum] + have h1 : k.val * 16 ≤ 15 * 16 := Nat.mul_le_mul_right 16 hk_15 + have : (15 * 16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i1, hi1_eq, hi1_val⟩ := usize_mul_ok_eq_fc k 16#usize hi1_max + have hi1_val_eq : i1.val = 16 * k.val := by + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hi1_val, hum]; omega + -- (2) i2 := k + 1. + have hi2_max : k.val + (1#usize : Std.Usize).val ≤ Std.Usize.max := by + have hk_15 : k.val ≤ 15 := by omega + have hum : (1#usize : Std.Usize).val = 1 := rfl + rw [hum] + have : (16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i2, hi2_eq, hi2_val⟩ := usize_add_ok_eq_fc k 1#usize hi2_max + have hi2_val_eq : i2.val = k.val + 1 := by + have hum : (1#usize : Std.Usize).val = 1 := rfl + rw [hi2_val, hum] + -- (3) i3 := i2 * 16. + have hi3_max : i2.val * (16#usize : Std.Usize).val ≤ Std.Usize.max := by + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hum, hi2_val_eq] + have : k.val + 1 ≤ 16 := by omega + have h1 : (k.val + 1) * 16 ≤ 16 * 16 := Nat.mul_le_mul_right 16 this + have : (16 * 16 : Nat) ≤ Std.Usize.max := by scalar_tac + omega + obtain ⟨i3, hi3_eq, hi3_val⟩ := usize_mul_ok_eq_fc i2 16#usize hi3_max + have hi3_val_eq : i3.val = 16 * (k.val + 1) := by + have hum : (16#usize : Std.Usize).val = 16 := rfl + rw [hi3_val, hi2_val_eq, hum]; omega + -- (4) Sub-slice extraction. + have h0_le : i1.val ≤ i3.val := by rw [hi1_val_eq, hi3_val_eq]; omega + have hi3_le : i3.val ≤ a.val.length := by + rw [h_a_len, hi3_val_eq] + have : k.val + 1 ≤ 16 := by omega + have h1 : 16 * (k.val + 1) ≤ 16 * 16 := Nat.mul_le_mul_left _ this + omega + obtain ⟨s, h_s_eq, h_s_val, h_s_len⟩ := + slice_index_range_ok_eq_fc a { start := i1, «end» := i3 } h0_le hi3_le + have h_s_len16 : s.length = 16 := by + show s.val.length = 16 + rw [h_s_len] + show i3.val - i1.val = 16 + rw [hi3_val_eq, hi1_val_eq]; omega + -- (4a) Per-lane lookup: s.val[ℓ]! = a.val[16*k + ℓ]!. + have h_s_lane : ∀ ℓ : Nat, ℓ < 16 → + s.val[ℓ]! = a.val[16 * k.val + ℓ]! := by + intro ℓ hℓ + rw [h_s_val] + have h_idx_lt : i1.val + ℓ < i3.val := by + rw [hi1_val_eq, hi3_val_eq]; omega + have h_end_le : i3.val ≤ a.val.length := hi3_le + have h_bnd : i3.val ≤ a.val.length ∧ i1.val + ℓ < i3.val := ⟨h_end_le, h_idx_lt⟩ + rw [List.getElem!_slice i1.val i3.val ℓ a.val h_bnd] + rw [hi1_val_eq] + -- (4b) Per-lane bound for the sub-slice (consumed by `reducing_from_i32_array_fc`). + have h_s_bnd : ∀ ℓ : Nat, ℓ < 16 → + (s.val[ℓ]!).val.natAbs ≤ 2^16 * 3328 := by + intro ℓ hℓ + rw [h_s_lane ℓ hℓ] + apply hbound (16 * k.val + ℓ) + have : k.val ≤ 15 := by omega + have : 16 * k.val ≤ 16 * 15 := Nat.mul_le_mul_left 16 this + omega + -- (5) `index_mut_usize acc.coefficients k` → `(t, set_back) = (acc.coefs[k], acc.coefs.set k)`. + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.coefficients.val[k.val]! with ht_def + have h_idx_t : Aeneas.Std.Array.index_usize acc.coefficients k = .ok t := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.coefficients k + (by rw [h_coef_len]; exact hk_16) + have h_imt_t : Aeneas.Std.Array.index_mut_usize acc.coefficients k + = .ok (t, acc.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t]; rfl + -- (6) Apply `reducing_from_i32_array_fc` to get `t1` with the chunk FC equation + -- AND the per-lane I16 bound `|t1.elements[ℓ]| ≤ 4993` (used by L6.7's + -- strengthened POST conjunct (c) at chunk k). + obtain ⟨t1, h_t1_eq, h_t1_lift, h_t1_bnd⟩ := + triple_exists_ok_fc (reducing_from_i32_array_fc s t h_s_len16 h_s_bnd) + -- (7) Compose `a1 := acc.coefficients.set k t1`. + set a1 : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.coefficients.set k t1 with ha1_def + have h_a1_len : a1.length = 16 := by simp [ha1_def, h_coef_len] + have h_a1_k : a1.val[k.val]! = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + -- (8) Compose acc'. + set acc' : SubtractReduceFC.Acc := { coefficients := a1 } with hacc'_def + -- (9) Body equation. + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array_loop.body + (vectortraitsOperationsInst := portable_ops_inst) a + { start := k, «end» := 16#usize } acc + = .ok (ControlFlow.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let i1' ← (k * 16#usize : Result Std.Usize) + let i2' ← k + 1#usize + let i3' ← i2' * 16#usize + let s' ← + core.Slice.Insts.CoreOpsIndexIndex.index + (core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice + Std.I32) a { start := i1', «end» := i3' } + let (t', index_mut_back) ← + Aeneas.Std.Array.index_mut_usize acc.coefficients k + let t1' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.reducing_from_i32_array s' t' + .ok (ControlFlow.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + ({ coefficients := index_mut_back t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + : Result (ControlFlow _ _)) + = _ + rw [hi1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [hi3_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_s_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_imt_t] + simp only [Aeneas.Std.bind_tc_ok] + show ((do + let t1' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.reducing_from_i32_array s t + .ok (ControlFlow.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + ({ coefficients := acc.coefficients.set k t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + : Result _) + = _ + rw [h_t1_eq] + rfl + apply triple_of_ok_fc h_body + show SubtractReduceFC.step_post a out_init k + (.cont (({ start := s_iter, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold SubtractReduceFC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (SubtractReduceFC.inv a out_init s_iter acc').holds + have h_inv_pure : + (∀ j : Nat, j < s_iter.val → ∀ ℓ : Nat, ℓ < 16 → + lift_fe_mont ((acc'.coefficients.val[j]!).elements.val[ℓ]!) + = Spec.mont_reduce_pure (lift_fe_int (a.val[16 * j + ℓ]!).val)) + ∧ (∀ j : Nat, s_iter.val ≤ j → j < 16 → + acc'.coefficients.val[j]! = out_init.coefficients.val[j]!) + ∧ (∀ j : Nat, j < s_iter.val → ∀ ℓ : Nat, ℓ < 16 → + ((acc'.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 4993) := by + refine ⟨?_, ?_, ?_⟩ + · -- (a) j < s_iter.val = k+1 → per-lane FC. + intro j hj ℓ hℓ + rw [hs_val] at hj + show lift_fe_mont + (((acc.coefficients.set k t1).val[j]!).elements.val[ℓ]!) + = Spec.mont_reduce_pure (lift_fe_int (a.val[16 * j + ℓ]!).val) + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: chunk unchanged. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set : ((acc.coefficients.set k t1).val[j]!) + = acc.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.coefficients k j t1 h_ne + rw [h_set] + exact h_acc_done j hj_lt_k ℓ hℓ + · -- j = k.val: chunk = t1; pull per-lane FC out of `h_t1_lift`. + subst hj_eq_k + have h_set_eq : ((acc.coefficients.set k t1).val[k.val]!) = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + rw [h_set_eq] + -- Extract per-lane via `lift_chunk_mont t1 = Spec.chunk_reducing_from_i32_array_pure s`. + have h_t1_elems_len : t1.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length t1 + -- From `lift_chunk_mont`'s .val list form and Subtype.ext on the chunk eq. + have h_chunk_val : + t1.elements.val.map lift_fe_mont + = (List.range 16).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (s.val[i]!).val)) := by + have h_unfold : (lift_chunk_mont t1).val + = (Spec.chunk_reducing_from_i32_array_pure s).val := by + rw [h_t1_lift] + unfold lift_chunk_mont Spec.chunk_reducing_from_i32_array_pure at h_unfold + exact h_unfold + -- Use `getElem!` form to dodge motive-not-type-correct issues. + have h_lhs_get : + (t1.elements.val.map lift_fe_mont)[ℓ]! + = lift_fe_mont (t1.elements.val[ℓ]!) := by + have h_len : (t1.elements.val.map lift_fe_mont).length = 16 := by + rw [List.length_map]; exact h_t1_elems_len + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map] + congr 1 + rw [getElem!_pos t1.elements.val ℓ (by rw [h_t1_elems_len]; exact hℓ)] + have h_rhs_get : + ((List.range 16).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (s.val[i]!).val)))[ℓ]! + = Spec.mont_reduce_pure (lift_fe_int (s.val[ℓ]!).val) := by + have h_len : ((List.range 16).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (s.val[i]!).val))).length = 16 := by + simp + rw [getElem!_pos _ ℓ (by rw [h_len]; exact hℓ)] + rw [List.getElem_map, List.getElem_range] + have h_lane : + lift_fe_mont (t1.elements.val[ℓ]!) + = Spec.mont_reduce_pure (lift_fe_int (s.val[ℓ]!).val) := by + rw [← h_lhs_get, h_chunk_val, h_rhs_get] + rw [h_lane] + -- Substitute s.val[ℓ]! = a.val[16*k.val + ℓ]!. + rw [h_s_lane ℓ hℓ] + · -- (b) s_iter.val ≤ j < 16 → acc'.coefs[j] = out_init.coefs[j]. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + show ((acc.coefficients.set k t1).val[j]!) + = out_init.coefficients.val[j]! + have h_set : ((acc.coefficients.set k t1).val[j]!) + = acc.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.coefficients k j t1 h_ne + rw [h_set] + exact h_acc_undone j h_ge' hj_lt + · -- (c) j < s_iter.val = k+1 → per-lane I16 bound `|acc'[j][ℓ]| ≤ 4993`. + intro j hj ℓ hℓ + rw [hs_val] at hj + show ((((acc.coefficients.set k t1).val[j]!).elements.val[ℓ]!).val.natAbs ≤ 4993) + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: chunk unchanged, bound inherited from (c) at step k. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set : ((acc.coefficients.set k t1).val[j]!) + = acc.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.coefficients k j t1 h_ne + rw [h_set] + exact h_acc_bnd j hj_lt_k ℓ hℓ + · -- j = k.val: chunk = t1; bound comes from `h_t1_bnd`. + subst hj_eq_k + have h_set_eq : ((acc.coefficients.set k t1).val[k.val]!) = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + rw [h_set_eq] + exact h_t1_bnd ℓ hℓ + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array_loop.body + (vectortraitsOperationsInst := portable_ops_inst) a + { start := k, «end» := 16#usize } acc + = .ok (ControlFlow.done acc) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_fc h_body + show SubtractReduceFC.step_post a out_init k (.done acc) + unfold SubtractReduceFC.step_post + show (SubtractReduceFC.inv a out_init 16#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j : Nat, j < (16#usize : Std.Usize).val → ∀ ℓ : Nat, ℓ < 16 → + lift_fe_mont ((acc.coefficients.val[j]!).elements.val[ℓ]!) + = Spec.mont_reduce_pure (lift_fe_int (a.val[16 * j + ℓ]!).val)) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + acc.coefficients.val[j]! = out_init.coefficients.val[j]!) + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → ∀ ℓ : Nat, ℓ < 16 → + ((acc.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 4993) := by + refine ⟨?_, ?_, ?_⟩ + · intro j hj ℓ hℓ; rw [h16] at hj + apply h_acc_done j _ ℓ hℓ; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + · intro j hj ℓ hℓ; rw [h16] at hj + apply h_acc_bnd j _ ℓ hℓ; rw [hk_eq]; exact hj + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +/-- L6.7 — poly-level `reducing_from_i32_array`. Returns a fresh poly + from an `i32` slice via 16 chunkwise `reducing_from_i32_array` calls. + + **Preconditions** (load-bearing, beyond the locked True-pre form): + - `hlen`: `a.length = 256` (the impl reads `a[16k..16(k+1)]` for k ∈ 0..16). + - `hbound`: per-lane `|a[i]| ≤ 2^16 * 3328` (consumed by the chunk-level + `reducing_from_i32_array_fc` precondition). -/ +@[spec] +theorem poly_reducing_from_i32_array_fc + (a : Slice Std.I32) + (out : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hlen : a.length = 256) + (hbound : ∀ i : Nat, i < 256 → (a.val[i]!).val.natAbs ≤ 2^16 * 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + (vectortraitsOperationsInst := portable_ops_inst) a out + ⦃ ⇓ p => ⌜ lift_poly_mont p = Spec.poly_reducing_from_i32_array_pure a + ∧ (∀ j : Nat, j < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((p.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 4993) ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array + have h_vre : libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + = .ok (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + rfl + rw [h_vre]; simp only [Aeneas.Std.bind_tc_ok] + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, out1) => + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.reducing_from_i32_array_loop.body + (vectortraitsOperationsInst := portable_ops_inst) a iter1 out1) + (β := SubtractReduceFC.Acc) + out + 0#usize 16#usize + (SubtractReduceFC.inv a out) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_, ?_⟩ + · intro j hj; exact absurd hj (Nat.not_lt_zero j) + · intro _ _ _; trivial + · intro j hj; exact absurd hj (Nat.not_lt_zero j)) + ?_) + · -- Post entailment: at k=16, invariant gives per-lane FC for all 256 indices, + -- AND the per-lane I16 bound `|r[j][ℓ]| ≤ 4993` for all j < 16. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (SubtractReduceFC.inv a out 16#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + (∀ j : Nat, j < (16#usize : Std.Usize).val → ∀ ℓ : Nat, ℓ < 16 → + lift_fe_mont ((r.coefficients.val[j]!).elements.val[ℓ]!) + = Spec.mont_reduce_pure (lift_fe_int (a.val[16 * j + ℓ]!).val)) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + r.coefficients.val[j]! = out.coefficients.val[j]!) + ∧ (∀ j : Nat, j < (16#usize : Std.Usize).val → ∀ ℓ : Nat, ℓ < 16 → + ((r.coefficients.val[j]!).elements.val[ℓ]!).val.natAbs ≤ 4993) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + SubtractReduceFC.inv] using h_inv_holds + obtain ⟨h_done, _h_undone, h_bnd⟩ := h_inv + refine ⟨?_, ?_⟩ + · -- Goal: `lift_poly_mont r = Spec.poly_reducing_from_i32_array_pure a`. + -- Both sides are 256-lane `Std.Array.make` constructions; reduce to + -- list equality via `Subtype.ext` and per-index via `List.ext_getElem`. + unfold lift_poly_mont Spec.poly_reducing_from_i32_array_pure + apply Subtype.ext + show (List.range 256).map (fun j => + lift_fe_mont (r.coefficients.val[j / 16]!).elements.val[j % 16]!) + = (List.range 256).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (a.val[i]!).val)) + apply List.ext_getElem + · simp + · intro j hj1 _hj2 + have hj : j < 256 := by + have : j < ((List.range 256).map (fun j' => + lift_fe_mont (r.coefficients.val[j' / 16]!).elements.val[j' % 16]!)).length := hj1 + simpa using this + have h_div_lt : j / 16 < 16 := Nat.div_lt_iff_lt_mul (by decide : 0 < 16) |>.mpr hj + have h_mod_lt : j % 16 < 16 := Nat.mod_lt _ (by decide : 0 < 16) + have h_decomp : 16 * (j / 16) + j % 16 = j := by + have := Nat.div_add_mod j 16 + omega + have h16' : (16#usize : Std.Usize).val = 16 := rfl + have h_inv_at := h_done (j / 16) (by rw [h16']; exact h_div_lt) (j % 16) h_mod_lt + simp only [List.getElem_map, List.getElem_range] + rw [h_inv_at, h_decomp] + · -- Goal: per-lane I16 bound for all j < 16, ℓ < 16. + intro j hj ℓ hℓ + have h16' : (16#usize : Std.Usize).val = 16 := rfl + exact h_bnd j (by rw [h16']; exact hj) ℓ hℓ + · -- Step entailment: per-iteration step lemma. + intro acc k _h_ge h_le hinv + have h_step := + poly_reducing_from_i32_array_step_lemma_fc a out hlen hbound acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : SubtractReduceFC.step_post a out k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [SubtractReduceFC.step_post] using hP + · have hP : SubtractReduceFC.step_post a out k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [SubtractReduceFC.step_post] using hP + + +end libcrux_iot_ml_kem.Polynomial.PolyOpsFc \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/PolyOpsFcBarrett.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/PolyOpsFcBarrett.lean new file mode 100644 index 00000000..7b7bef75 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Polynomial/PolyOpsFcBarrett.lean @@ -0,0 +1,448 @@ +/- + # `Polynomial/PolyOpsFcBarrett.lean` — FC theorems for `polynomial.rs` ops. + + Houses the §L6.{1,2,4,5,6,7} FC obligations from the original + `FCTargets.lean`. Lives separately from `Polynomial/PolyOps.lean` + to break a dependency cycle (the existing `PolyOps.lean` is imported + by `Polynomial/NttDrivers.lean`, while the FC theorems here depend + on `Polynomial/NttDrivers.lean` and feed into `Ntt.lean`'s + `ntt_binomially_sampled_ring_element_fc`). +-/ +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Spec.Pure +import LibcruxIotMlKem.Spec.ModularArith +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Extraction.Funs +import HacspecMlKem.Extraction.Funs + +set_option mvcgen.warning false +set_option linter.unusedVariables false +set_option linter.unusedSectionVars false + + +/-! ### Extracted from FCTargets.lean (§poly_l6_1). -/ + +namespace libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett +open libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec + +/-! ## §L6 — poly-level ops (6 theorems). -/ + +/-! ### L6.1.A — Loop scaffolding for `poly_barrett_reduce_fc`. + + FC invariant for the 16-iter chunk-loop. Each iteration `i ∈ 0..16` + applies `barrett_reduce` to chunk `i` of `self`, leaving chunks + `j ≠ i` untouched. The chunk-level closure for chunk `i` is + `lift_chunk acc'[i] = Spec.chunk_barrett_reduce_pure + (lift_chunk self[i])` + where `Spec.chunk_barrett_reduce_pure` lifts `Spec.barrett_pure` + pointwise across 16 lanes. -/ + +namespace BarrettReduceFC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +/-- Step-local accumulator (the mutable poly being barrett-reduced). -/ +abbrev Acc := + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector + +/-- FC loop invariant for `poly_barrett_reduce_fc`. + * (a) Chunks `j < k`: FC equation `lift_chunk acc[j] = + chunk_barrett_reduce_pure (lift_chunk self[j])`. + * (b) Chunks `k ≤ j < 16`: `acc[j] = self[j]` (unchanged). -/ +def inv + (self : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → Acc → Result Prop := + fun k acc => pure ( + (∀ j : Nat, j < k.val → + lift_chunk (acc.coefficients.val[j]!) + = Spec.chunk_barrett_reduce_pure + (lift_chunk (self.coefficients.val[j]!))) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.coefficients.val[j]! = self.coefficients.val[j]!)) + +/-- Step-post for `loop_range_spec_usize`. -/ +def step_post + (self : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) × Acc) Acc) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (inv self iter'.start acc').holds + | .done y => (inv self 16#usize y).holds + +end BarrettReduceFC + +set_option maxHeartbeats 16000000 in +/-- Per-iteration FC step lemma for `poly_barrett_reduce`. Given a valid + loop state `(acc, k)` with `k.val < 16`, applies `barrett_reduce` to + chunk `k.val` of `acc`, recording the FC equation + `lift_chunk acc'[k.val] = chunk_barrett_reduce_pure + (lift_chunk self[k.val])` + while preserving chunks `j ≠ k.val`. -/ +theorem poly_barrett_reduce_step_lemma_fc + (self : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((self.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767) + (acc : BarrettReduceFC.Acc) + (k : Std.Usize) (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (BarrettReduceFC.inv self k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc + ⦃ ⇓ r => ⌜ BarrettReduceFC.step_post self k r ⌝ ⦄ := by + have h16 : (16#usize : Std.Usize).val = 16 := rfl + have h_coef_len : acc.coefficients.length = 16 := + Std.Array.length_eq _ + obtain ⟨h_acc_done, h_acc_undone⟩ := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- `Some i = k` branch. + have hk_16 : k.val < 16 := by rw [h16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + -- (1) `index_mut_usize acc.coefficients k` → `(t, set_back) = (acc.coefs[k], acc.coefs.set k)`. + set t : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + acc.coefficients.val[k.val]! with ht_def + have h_idx_t : Aeneas.Std.Array.index_usize acc.coefficients k = .ok t := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.coefficients k + (by rw [h_coef_len]; exact hk_16) + have h_imt_t : Aeneas.Std.Array.index_mut_usize acc.coefficients k + = .ok (t, acc.coefficients.set k) := by + unfold Aeneas.Std.Array.index_mut_usize + rw [h_idx_t]; rfl + -- (1a) `t = self.coefficients[k]` (via h_acc_undone at j=k). + have h_t_eq : t = self.coefficients.val[k.val]! := by + show acc.coefficients.val[k.val]! = self.coefficients.val[k.val]! + exact h_acc_undone k.val (Nat.le_refl _) hk_16 + have h_t_bnd : ∀ ℓ : Nat, ℓ < 16 → + (t.elements.val[ℓ]!).val.natAbs ≤ 32767 := by + intro ℓ hℓ + rw [h_t_eq]; exact h_bnd k.val hk_16 ℓ hℓ + -- (2) `barrett_reduce t` → `t1`. Pre: |t[ℓ]| ≤ 32767 ✓. + obtain ⟨t1, h_t1_eq, h_t1_post⟩ := + triple_exists_ok_fc (barrett_reduce_fc t h_t_bnd) + obtain ⟨_h_t1_bnd, h_t1_lift⟩ := h_t1_post + -- (3) Compose acc' = `{ coefficients := acc.coefs.set k t1 }`. + set a : Std.Array libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector 16#usize := + acc.coefficients.set k t1 with ha_def + set acc' : BarrettReduceFC.Acc := { coefficients := a } with hacc'_def + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc + = .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [Aeneas.Std.bind_tc_ok] + show (do + let (t', index_mut_back) ← + Aeneas.Std.Array.index_mut_usize acc.coefficients k + let t1' ← + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce t' + .ok (ControlFlow.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + ({ coefficients := index_mut_back t1' } + : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)))) + = _ + rw [h_imt_t]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_t1_eq] + rfl + apply triple_of_ok_fc h_body + show BarrettReduceFC.step_post self k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), acc')) + unfold BarrettReduceFC.step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (BarrettReduceFC.inv self s acc').holds + -- Invariant at (s, acc'): only chunk k changes (to t1). + have h_inv_pure : + (∀ j : Nat, j < s.val → + lift_chunk (acc'.coefficients.val[j]!) + = Spec.chunk_barrett_reduce_pure + (lift_chunk (self.coefficients.val[j]!))) + ∧ (∀ j : Nat, s.val ≤ j → j < 16 → + acc'.coefficients.val[j]! = self.coefficients.val[j]!) := by + refine ⟨?_, ?_⟩ + · -- (a) j < s.val → FC equation at chunk j. + intro j hj + rw [hs_val] at hj + show lift_chunk ((acc.coefficients.set k t1).val[j]!) = _ + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k.val: chunk j unchanged. + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set : ((acc.coefficients.set k t1).val[j]!) + = acc.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.coefficients k j t1 h_ne + rw [h_set] + exact h_acc_done j hj_lt_k + · -- j = k.val: chunk j = t1. + subst hj_eq_k + have h_set : ((acc.coefficients.set k t1).val[k.val]!) = t1 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_eq acc.coefficients k k.val t1 + ⟨rfl, by rw [h_coef_len]; exact hk_16⟩ + rw [h_set] + -- Goal: lift_chunk t1 = chunk_barrett_reduce_pure (lift_chunk self[k]). + -- From h_t1_lift: lift_chunk t1 = chunk_barrett_reduce_pure (lift_chunk t). + -- And t = self.coefficients[k] (h_t_eq). + rw [h_t1_lift, h_t_eq] + · -- (b) s.val ≤ j < 16 → acc'.coefs[j] = self.coefs[j]. + intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + show ((acc.coefficients.set k t1).val[j]!) = self.coefficients.val[j]! + have h_set : ((acc.coefficients.set k t1).val[j]!) + = acc.coefficients.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using + Aeneas.Std.Array.getElem!_Nat_set_ne acc.coefficients k j t1 h_ne + rw [h_set] + exact h_acc_undone j h_ge' hj_lt + show (pure _ : Result Prop).holds + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + · -- `None` branch: k ≥ 16, done. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) + { start := k, «end» := 16#usize } acc + = .ok (ControlFlow.done acc) := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop.body + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_fc h_body + show BarrettReduceFC.step_post self k (.done acc) + unfold BarrettReduceFC.step_post + show (BarrettReduceFC.inv self 16#usize acc).holds + show (pure _ : Result Prop).holds + have h_inv_pure : + (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (acc.coefficients.val[j]!) + = Spec.chunk_barrett_reduce_pure + (lift_chunk (self.coefficients.val[j]!))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + acc.coefficients.val[j]! = self.coefficients.val[j]!) := by + refine ⟨?_, ?_⟩ + · intro j hj; rw [h16] at hj + apply h_acc_done j; rw [hk_eq]; exact hj + · intro j hj_ge hj_lt + rw [h16] at hj_ge + apply h_acc_undone j _ hj_lt; rw [hk_eq]; exact hj_ge + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] using h_inv_pure + +set_option maxHeartbeats 16000000 in +/-- L6.1 — `poly_barrett_reduce`: 16-chunk loop applying `barrett_reduce` + per chunk. Spec target: hacspec `polynomial.poly_barrett_reduce`. + + **Preconditions** (load-bearing, beyond the locked True-pre form): + - `h_bnd`: per-lane `|self[chunk][ℓ]| ≤ 32767` (consumed by + `barrett_reduce_fc`'s legacy precondition). + + Proof sketch: + 1. Unfold `VECTORS_IN_RING_ELEMENT = .ok 16#usize`. + 2. Apply `loop_range_spec_usize` with invariant `BarrettReduceFC.inv`. + 3. Per-iter step lemma (above) closes the body via `barrett_reduce_fc`. + 4. Post-entailment: at `k=16`, each chunk satisfies the FC equation + `lift_chunk r.coefs[k] = chunk_barrett_reduce_pure (lift_chunk self.coefs[k])`. + Build `chunks_arr` from this and use `flatten_chunks_eq_lift_poly_fc` to + get `flatten_chunks chunks_arr = lift_poly r`. Then bridge to the + hacspec post via `poly_barrett_reduce_eq_ok` (canonical lanes: `barrett_pure` + is identity on `lift_fe` images, so the pure projection coincides with + `flatten_chunks chunks_arr` pointwise). -/ +@[spec] +theorem poly_barrett_reduce_fc + (self : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_bnd : ∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((self.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 32767) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce + (vectortraitsOperationsInst := portable_ops_inst) self + ⦃ ⇓ p => ⌜ hacspec_ml_kem.polynomial.poly_barrett_reduce (lift_poly self) + = .ok (lift_poly p) ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce + -- Resolve `VECTORS_IN_RING_ELEMENT = .ok 16#usize`. + have h_vre : libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + = .ok (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.polynomial.VECTORS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.constants.COEFFICIENTS_IN_RING_ELEMENT + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + rfl + rw [h_vre]; simp only [Aeneas.Std.bind_tc_ok] + unfold libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, acc1) => + libcrux_iot_ml_kem.polynomial.PolynomialRingElement.poly_barrett_reduce_loop.body + (vectortraitsOperationsInst := portable_ops_inst) iter1 acc1) + (β := BarrettReduceFC.Acc) + self + 0#usize 16#usize + (BarrettReduceFC.inv self) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (by + show (pure _ : Result Prop).holds + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp] + intro _ + refine ⟨?_, ?_⟩ + · -- No chunks done yet. + intro j hj; exact absurd hj (Nat.not_lt_zero j) + · -- All chunks unchanged (acc = self) — but goal trivializes since acc = self + -- and the second conjunct's body is `acc.coefs[j]! = self.coefs[j]!` which is rfl. + intro _ _ _; trivial) + ?_) + · -- Post entailment: at k=16, the invariant gives all 16 FC equations. + rw [PostCond.entails_noThrow] + intro r hh + have h_inv_holds : (BarrettReduceFC.inv self 16#usize r).holds := by + simpa [PostCond.noThrow, Std.Do.SPred.down_pure] using hh + have h_inv : + (∀ j : Nat, j < (16#usize : Std.Usize).val → + lift_chunk (r.coefficients.val[j]!) + = Spec.chunk_barrett_reduce_pure + (lift_chunk (self.coefficients.val[j]!))) + ∧ (∀ j : Nat, (16#usize : Std.Usize).val ≤ j → j < 16 → + r.coefficients.val[j]! = self.coefficients.val[j]!) := by + simpa [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, + BarrettReduceFC.inv] using h_inv_holds + obtain ⟨h_done, _h_undone⟩ := h_inv + -- Build chunks_arr matching `chunk_barrett_reduce_pure (chunk_at (lift_poly self) k)`, + -- then apply `flatten_chunks_eq_lift_poly_fc` to get `flatten_chunks chunks_arr = lift_poly r`. + set chunks_arr : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_barrett_reduce_pure + (Spec.chunk_at (lift_poly self) k))) + (by simp) with hchunks_def + have h_chunks_len : chunks_arr.val.length = 16 := by + show ((List.range 16).map _).length = 16 + simp + have h_chunks_get : ∀ k : Nat, (hk : k < 16) → + chunks_arr.val[k]'(by rw [h_chunks_len]; exact hk) + = lift_chunk (r.coefficients.val[k]!) := by + intro k hk + show ((List.range 16).map (fun k => + Spec.chunk_barrett_reduce_pure + (Spec.chunk_at (lift_poly self) k)))[k]'_ = _ + rw [List.getElem_map, List.getElem_range] + rw [chunk_at_lift_poly_fc self k hk] + exact (h_done k hk).symm + have h_flat := flatten_chunks_eq_lift_poly_fc r chunks_arr h_chunks_len h_chunks_get + -- Bridge to hacspec via `poly_barrett_reduce_eq_ok` + canonical-identity. + rw [libcrux_iot_ml_kem.Spec.Pure.polynomial.poly_barrett_reduce_eq_ok] + rw [libcrux_iot_ml_kem.Spec.Pure.polynomial.poly_barrett_reduce_pure_id_of_canonical + (lift_poly self) (lift_poly_lanes_canonical self)] + -- Goal: .ok (lift_poly self) = .ok (lift_poly r). Reduce via congrArg. + congr 1 + -- Goal: lift_poly self = lift_poly r. Chain through flatten_chunks. + rw [← h_flat] + -- Goal: lift_poly self = flatten_chunks chunks_arr. + apply Subtype.ext + unfold Spec.flatten_chunks lift_poly + show (List.range 256).map (fun j => + lift_fe (self.coefficients.val[j / 16]!).elements.val[j % 16]!) + = (List.range 256).map (fun j => (chunks_arr.val[j / 16]!).val[j % 16]!) + apply List.ext_getElem + · simp + · intro j hj1 _hj2 + have hj : j < 256 := by + have : j < ((List.range 256).map _).length := hj1 + simpa using this + have h_div_lt : j / 16 < 16 := Nat.div_lt_iff_lt_mul (by decide : 0 < 16) |>.mpr hj + have h_mod_lt : j % 16 < 16 := Nat.mod_lt _ (by decide : 0 < 16) + rw [List.getElem_map, List.getElem_map, List.getElem_range] + -- Pull chunks_arr.val[j/16]! through the definition (Std.Array.make's .val[!]). + have h_chunks_at : + chunks_arr.val[j / 16]! + = Spec.chunk_barrett_reduce_pure + (Spec.chunk_at (lift_poly self) (j / 16)) := by + rw [getElem!_pos chunks_arr.val (j / 16) (by rw [h_chunks_len]; exact h_div_lt)] + show ((List.range 16).map (fun k => + Spec.chunk_barrett_reduce_pure + (Spec.chunk_at (lift_poly self) k)))[j / 16]'_ = _ + rw [List.getElem_map, List.getElem_range] + rw [h_chunks_at] + -- Unfold chunk_barrett_reduce_pure: lane-wise barrett_pure. + unfold Spec.chunk_barrett_reduce_pure + -- The Std.Array.make's .val is the underlying list directly. + show lift_fe (self.coefficients.val[j / 16]!).elements.val[j % 16]! + = (((List.range 16).map (fun i => + Spec.barrett_pure ((Spec.chunk_at (lift_poly self) (j / 16)).val[i]!)))[j % 16]!) + rw [getElem!_pos ((List.range 16).map (fun i => + Spec.barrett_pure ((Spec.chunk_at (lift_poly self) (j / 16)).val[i]!))) + (j % 16) (by simp; exact h_mod_lt)] + rw [List.getElem_map, List.getElem_range] + -- chunk_at (lift_poly self) (j/16) = lift_chunk (self.coefs[j/16]!). + rw [chunk_at_lift_poly_fc self (j / 16) h_div_lt] + -- (lift_chunk x).val[j%16]! = lift_fe (x.elements.val[j%16]!). + have h_self_elems_len : (self.coefficients.val[j / 16]!).elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + have h_lc_self : ((lift_chunk (self.coefficients.val[j / 16]!)).val[j % 16]!) + = lift_fe ((self.coefficients.val[j / 16]!).elements.val[j % 16]!) := by + unfold lift_chunk + show (((self.coefficients.val[j / 16]!).elements.val.map lift_fe)[j % 16]!) = _ + have h_len : + ((self.coefficients.val[j / 16]!).elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_self_elems_len + rw [getElem!_pos ((self.coefficients.val[j / 16]!).elements.val.map lift_fe) + (j % 16) (by rw [h_len]; exact h_mod_lt)] + rw [List.getElem_map] + rw [getElem!_pos (self.coefficients.val[j / 16]!).elements.val (j % 16) + (by rw [h_self_elems_len]; exact h_mod_lt)] + rw [h_lc_self] + rw [barrett_pure_lift_fe] + · -- Step lemma application. + intro acc k _h_ge h_le hinv + have h_step := poly_barrett_reduce_step_lemma_fc self h_bnd acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : BarrettReduceFC.step_post self k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [BarrettReduceFC.step_post] using hP + · have hP : BarrettReduceFC.step_post self k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [BarrettReduceFC.step_post] using hP + + +end libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/README.md b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/README.md new file mode 100644 index 00000000..014fbc1c --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/README.md @@ -0,0 +1,256 @@ +# ML-KEM matrix-arithmetic core: impl ↔ spec equivalence + +This directory contains the Lean 4 proof that the Rust implementation of +ML-KEM's **matrix-arithmetic core** in `libcrux-iot/ml-kem/src/` +computes the same functions as the hacspec-style specification in +`https://github.com/cryspen/libcrux`. Both sides are auto-extracted via the +`cargo hax into aeneas-lean` pipeline; this directory then proves their +functional-correctness (FC) equivalence. + +The four top-level results are the arithmetic heart of ML-KEM +key-generation, encryption, and decryption: `matrix.compute_As_plus_e`, +`matrix.compute_vector_u`, `matrix.compute_ring_element_v`, and `matrix.compute_message`. +The surrounding glue (XOF expansion, rejection sampling, (de)serialization, compression) is **not** proven +here — see [Assumptions](#assumptions-trust-boundary) for the precise +trust boundary. + +## Matrix-level theorems + +All four main results are `mvcgen` Triples of the form +`⦃ True ⦄ ⦃ ⇓ p => ⌜ (lift args…) = .ok (lift p…) ⌝ ⦄` +— i.e. they link the Aeneas-extracted impl to the hacspec spec through a `lift` bridge. +The `lift` bridge accounts for different representations of the input/output data: +The impl uses potentially non-canonical values mod 3329, +stores coefficients in the Montgomery domain, and +stores ring elements as 16 SIMD-shaped chunks of 16 lanes each. +In contrast, the spec uses canonical representations, plain coefficients, +and a flat array of 256 field elements. + +### L7.1 — key generation: `Â · ŝ + ê` + +[`Matrix/ComputeAsPlusE.lean`](Matrix/ComputeAsPlusE.lean) — `libcrux_iot_ml_kem.Matrix.ComputeAsPlusE.compute_As_plus_e_fc`: + +```lean +⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_As_plus_e + (vectortraitsOperationsInst := portable_ops_inst) + t_as_ntt matrix_A s_as_ntt error_as_ntt s_cache accumulator +⦃ ⇓ p => ⌜ hacspec_ml_kem.matrix.compute_As_plus_e + (lift_matrix_from_slice matrix_A K) + (lift_vec s_as_ntt) (lift_vec error_as_ntt) + = .ok (lift_vec p.1) ⌝ ⦄ +``` + +The impl's `compute_As_plus_e`, lifted, equals +the hacspec `compute_As_plus_e`. The matrix is read from a **stored** +array, so this theorem is fully +axiom-clean. + +### L7.2 — encryption: `Âᵀ · r̂ + ê₁` + +[`Matrix/ComputeVectorU/FC.lean`](Matrix/ComputeVectorU/FC.lean) — `libcrux_iot_ml_kem.Matrix.ComputeVectorU.FC.compute_vector_u_fc`: + +```lean +⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_vector_u + K (vectortraitsOperationsInst := portable_ops_inst) hash_functionsHashInst + matrix_entry seed r_as_ntt error_1 result scratch cache accumulator +⦃ ⇓ p => ⌜ hacspec_ml_kem.matrix.compute_vector_u + (lift_matrix_from_seed seed K) + (lift_vec_slice r_as_ntt K) + (lift_vec_slice error_1 K) + = .ok (lift_vec_slice p.2.1 K) ⌝ ⦄ +``` +The impl's `compute_vector_u`, lifted, equals +the hacspec `compute_vector_u`. Here the matrix is +**sampled on the fly** from `seed` (`lift_matrix_from_seed`), so this +theorem is conditional on the matrix-sampling leaf axiom **A1** (see +[Assumptions](#assumptions-trust-boundary)). + +### L7.3 — encryption: `t̂ · r̂ + e₂ + Decompress(message)` + +[`Matrix/ComputeRingElementV/FC.lean`](Matrix/ComputeRingElementV/FC.lean) — `libcrux_iot_ml_kem.Matrix.ComputeRingElementV.FC.compute_ring_element_v_fc`: + +```lean +⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_ring_element_v + K (vectortraitsOperationsInst := portable_ops_inst) + public_key t_as_ntt_entry r_as_ntt error_2 message result scratch + cache accumulator +⦃ ⇓ p => ⌜ hacspec_ml_kem.matrix.compute_ring_element_v + (lift_t_as_ntt_from_public_key public_key K) + (lift_vec_slice r_as_ntt K) + (lift_poly error_2) (lift_poly message) + = .ok (lift_poly p.2.1) ⌝ ⦄ +``` + +The impl's `compute_ring_element_v`, lifted, equals +the hacspec `compute_ring_element_v`. The first vector `t̂` is **deserialized** +from the public key (`lift_t_as_ntt_from_public_key`), so this theorem +is conditional on the deserialization leaf axiom **A2** (see +[Assumptions](#assumptions-trust-boundary)). + +### L7.4 — decryption: `NTT⁻¹(v̂ − ŝ · û)` + +[`Matrix/ComputeMessage/FC.lean`](Matrix/ComputeMessage/FC.lean) — `libcrux_iot_ml_kem.Matrix.ComputeMessage.FC.compute_message_fc`: + +```lean +⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.compute_message + (vectortraitsOperationsInst := portable_ops_inst) + v secret_as_ntt u_as_ntt result scratch accumulator +⦃ ⇓ p => ⌜ hacspec_ml_kem.matrix.compute_message + (lift_poly v) + (lift_vec secret_as_ntt) (lift_vec u_as_ntt) + = .ok (lift_poly p.1) ⌝ ⦄ +``` + +The impl's `compute_message`, lifted, equals +the hacspec `compute_message`. All inputs are passed-in polynomials, +so this theorem is fully axiom-clean. + +## Polynomial-level theorems + +The four matrix-level theorems above are assembled from a stack of +**polynomial-level** FC theorems — each over a single ring element +(`PolynomialRingElement` = 256 coefficients) — stated and proven in +the files listed below. Unlike L7.2/L7.3, **none** of these depend on +non-standard axioms. + +The polynomial-level theorems **do not use the hacspec implementation** +but use a pure Lean reference that reimplements the hacspec functions. + +### Number-theoretic transform operations + +| Theorem | impl function | what it does | +|---------|---------------|--------------| +| `libcrux_iot_ml_kem.Ntt.ntt_binomially_sampled_ring_element_fc` ([`Ntt.lean`](Ntt.lean)) | `ntt.ntt_binomially_sampled_ring_element` | forward NTT | +| `libcrux_iot_ml_kem.InvertNtt.invert_ntt_montgomery_fc` ([`InvertNtt.lean`](InvertNtt.lean)) | `invert_ntt.invert_ntt_montgomery` | inverse NTT | +| `libcrux_iot_ml_kem.Polynomial.NttMultiply.accumulating_ntt_multiply_fc` ([`Polynomial/NttMultiply.lean`](Polynomial/NttMultiply.lean)) | `vector.portable.ntt.accumulating_ntt_multiply` | pointwise NTT multiplication | + +### Reduction, error, and message combination + +The poly-level arithmetic that finishes each ML-KEM step. + +| Theorem | impl function | what it does | +|---------|-----------|--------------| +| `libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett.poly_barrett_reduce_fc` ([`Polynomial/PolyOpsFcBarrett.lean`](Polynomial/PolyOpsFcBarrett.lean)) | `polynomial.PolynomialRingElement.poly_barrett_reduce` | Barrett-reduce all 256 lanes to canonical residues | +| `libcrux_iot_ml_kem.Polynomial.PolyOpsFc.poly_reducing_from_i32_array_fc` ([`Polynomial/PolyOpsFc.lean`](Polynomial/PolyOpsFc.lean)) | `polynomial.PolynomialRingElement.reducing_from_i32_array` | Montgomery-reduce an `i32[256]` accumulator into a ring element | +| `libcrux_iot_ml_kem.Polynomial.PolyOpsFc.subtract_reduce_fc` ([`Polynomial/PolyOpsFc.lean`](Polynomial/PolyOpsFc.lean)) | `polynomial.PolynomialRingElement.subtract_reduce` | subtract two ring elements, then Barrett-reduce (decryption tail, L7.4) | +| `libcrux_iot_ml_kem.Polynomial.PolyOpsFc.add_error_reduce_fc` ([`Polynomial/PolyOpsFc.lean`](Polynomial/PolyOpsFc.lean)) | `polynomial.PolynomialRingElement.add_error_reduce` | add an error polynomial (impl's `1441`-Montgomery multiply), Barrett-reduce | +| `libcrux_iot_ml_kem.Polynomial.PolyOpsFc.add_standard_error_reduce_fc` ([`Polynomial/PolyOpsFc.lean`](Polynomial/PolyOpsFc.lean)) | `polynomial.PolynomialRingElement.add_standard_error_reduce` | add a standard error polynomial (`R`-Montgomery multiply), Barrett-reduce (keygen tail) | +| `libcrux_iot_ml_kem.Polynomial.PolyOpsFc.add_message_error_reduce_fc` ([`Polynomial/PolyOpsFc.lean`](Polynomial/PolyOpsFc.lean)) | `polynomial.PolynomialRingElement.add_message_error_reduce` | add error + message to the (`1441`-multiplied) result, Barrett-reduce (L7.3 tail) | + +## Assumptions (trust boundary) + +The four matrix-arithmetic theorems above are **complete proofs** modulo +the assumptions below. Read this section as the precise statement of what +is *trusted* rather than *proven*. + +### Standard Lean axioms + +Every theorem depends on Lean's three standard axioms: `propext`, +`Classical.choice`, `Quot.sound`. + +### Per-theorem axiom status + +| Theorem | Standard | Leaf axiom | +|---------|----------|------------| +| L7.1 `Matrix.ComputeAsPlusE.compute_As_plus_e_fc` | ✓ | — (fully clean) | +| L7.2 `Matrix.ComputeVectorU.FC.compute_vector_u_fc` | ✓ | **A1** `Sampling.sample_matrix_entry_fc` | +| L7.3 `Matrix.ComputeRingElementV.FC.compute_ring_element_v_fc` | ✓ | **A2** `Serialize.deserialize_to_reduced_ring_element_fc` | +| L7.4 `Matrix.ComputeMessage.FC.compute_message_fc` | ✓ | — (fully clean) | + +### The two deferred-leaf axioms (A1 / A2) + +- **A1** `libcrux_iot_ml_kem.Sampling.sample_matrix_entry_fc` (stated in + [`Sampling.lean`](Sampling.lean)) — characterizes one on-the-fly matrix + entry: running the impl's XOF + rejection-sampling chain on `(seed, i, j)` + produces the `(i, j)` entry of `lift_matrix_from_seed seed K` (row-major), + with every coefficient in `[0, 3328]`. + +- **A2** `libcrux_iot_ml_kem.Serialize.deserialize_to_reduced_ring_element_fc` + (stated in [`Serialize.lean`](Serialize.lean)) — characterizes one + 384-byte public-key chunk: running the impl's 16-iteration + `deserialize_12 + cond_subtract_3329` loop on chunk `i` produces + `(lift_t_as_ntt_from_public_key public_key K).val[i]!`, coefficients in + `[0, 3328]`. + +These are largly orthogonal to the matrix arithmetic, +which is why we omitted its verification. + +## Proof architecture + +### The lift bridge + +The impl works over `PortableVector`-backed `i16`/`i32` coefficients in +the (signed, possibly non-canonical) **Montgomery** domain; the hacspec +works over `parameters.FieldElement` (a `u16` wrapping `ZMod 3329`). The +lift family (in [`Spec/Lift.lean`](Spec/Lift.lean), namespace +`libcrux_iot_ml_kem.Spec.Lift`) maps impl values to canonical spec values. + +### Hierarchy (L0 → L7) + +The proof is structured into layers L0 to L7: + +| Layer | Content | +|-------|---------| +| **L0** | field-element arithmetic (`add`/`sub`/`mul`/`barrett`-reduce in `ZMod 3329`) | +| **L1** | per-vector-element ops (the `PortableVector` lane primitives) | +| **L2** | NTT butterfly layer steps (forward + inverse) | +| **L3** | NTT drivers (full forward/inverse NTT over the 7 layers) | +| **L4** | [*not verified*: sampling / compression] | +| **L5** | [*not verified*: (de)serialization] | +| **L6** | poly-level ops: barrett-reduce, subtract-reduce, add-error-reduce, add-message-error-reduce, reducing-from-`i32`-array | +| **L7** | the matrix-level targets above | + + +## Reproduction + +### Prerequisites + +- For running the proofs: + - Lean 4 toolchain `leanprover/lean4:v4.30.0-rc2` (pinned in `lean-toolchain`). + - Hacspec ML-KEM spec from https://github.com/cryspen/libcrux at commit `a4cfb1ebf26431b2ee81f0dc19383158aaf397b7` +- For extraction: + - Hax at commit `ffdf432705d409b62ec025d253a340234b59766f` + (not publicly available yet, https://github.com/cryspen/hax-evit) + with the corresponding charon/aeneas versions: + - Charon at https://github.com/AeneasVerif/charon/releases/tag/nightly-2026.06.02 + - Aeneas at https://github.com/cryspen/aeneas/releases/tag/nightly-2026.06.04 + — note: the `aeneas-pin` file in hax-evit at this commit names tag + `nightly-2026.06.03`, but commit `8d2077c` (the SHA the binary + must report) actually ships in `nightly-2026.06.04`. Use the + `06.04` release. + +### Verifying the Lean proof + +From `libcrux-iot/ml-kem/proofs/aeneas-lean/`: + +```bash +lake exe cache get +lake build +``` + +### Cross-spec regression (Rust) + +We have a couple of Rust tests in place as a first sanity check that +implementation and specification agree: + +```bash +cargo test --tests cross_spec +``` + +This catches mismatches at the Rust level before they propagate into Lean proof failures. + +### Extraction from Rust into Lean + +```bash +# Spec side (from a checkout of cryspen/libcrux): +cd specs/ml-kem/ +./hax_aeneas.py + +# Impl side: +cd libcrux-iot/ml-kem/ +./hax_aeneas.py +``` diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Sampling.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Sampling.lean new file mode 100644 index 00000000..db4dcb9d --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Sampling.lean @@ -0,0 +1,52 @@ +/- + # `Sampling.lean` — deferred leaf FC axiom for matrix sampling. + + We verify the matrix operations without verifying the underlying + rejection sampling. `sample_matrix_entry_fc` axiomatises the contract + that the impl `matrix.sample_matrix_entry` implements + `Spec.sample_matrix_A_pure`'s `(i, j)`-th entry with canonical + coefficients. +-/ + +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply + +set_option mvcgen.warning false +set_option linter.unusedVariables false + +namespace libcrux_iot_ml_kem.Sampling +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec + +/-- prepends `(i, j)` to the 32-byte seed, runs `sample_from_xof` + (rejection sampling on uniform [0, 2^12-1)), then `from_i16_array` + into `out`. Result has |coeff| ≤ 3328 (rejection sampling discards + values ≥ 3329). -/ +@[spec] +axiom sample_matrix_entry_fc + {Hasher : Type} + (hash_functionsHashInst : libcrux_iot_ml_kem.hash_functions.Hash Hasher) + (out : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (seed : Slice Std.U8) (i j K : Std.Usize) + (h_seed_len : seed.length = 32) + (h_i : i.val < K.val) (h_j : j.val < K.val) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.matrix.sample_matrix_entry + (vectortraitsOperationsInst := portable_ops_inst) + hash_functionsHashInst out seed i j + ⦃ ⇓ p => ⌜ lift_poly p = (lift_matrix_from_seed seed K).val[i.val]!.val[j.val]! + ∧ (∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((p.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 3328) ⌝ ⦄ + +end libcrux_iot_ml_kem.Sampling \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Serialize.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Serialize.lean new file mode 100644 index 00000000..88a4e7a6 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Serialize.lean @@ -0,0 +1,55 @@ +/- + # `Serialize.lean` — deferred leaf FC axiom for public-key deserialization. + + We verify the matrix operations without verifying the underlying + `serialize.deserialize_to_reduced_ring_element` ring-element decoding. + This axiom pins the contract: each 384-byte chunk of the public key + deserialises to the `i`-th ring element of + `Spec.t_as_ntt_from_public_key_pure` with canonical coefficients. +-/ + +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element +import LibcruxIotMlKem.Vector.Portable.Ntt +import LibcruxIotMlKem.Ntt +import LibcruxIotMlKem.InvertNtt +import LibcruxIotMlKem.Polynomial.NttDrivers +import LibcruxIotMlKem.Polynomial.PolyOps +import LibcruxIotMlKem.Polynomial.PolyOpsFcBarrett +import LibcruxIotMlKem.Polynomial.PolyOpsFc +import LibcruxIotMlKem.Polynomial.NttMultiply + +set_option mvcgen.warning false +set_option linter.unusedVariables false + +namespace libcrux_iot_ml_kem.Serialize +open libcrux_iot_ml_kem.InvertNtt libcrux_iot_ml_kem.Ntt libcrux_iot_ml_kem.Polynomial.NttMultiply libcrux_iot_ml_kem.Polynomial.PolyOpsFc libcrux_iot_ml_kem.Polynomial.PolyOpsFcBarrett libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement libcrux_iot_ml_kem.Vector.Portable.Ntt +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec + +/-- 16-iteration loop over 24-byte sub-chunks of `serialized` + (one `BYTES_PER_RING_ELEMENT = 384` byte ring-element). Each chunk + extracts 16 packed 12-bit coefficients via `deserialize_12`, then + `cond_subtract_3329` produces canonical residues [0, 3328]. -/ +@[spec] +axiom deserialize_to_reduced_ring_element_fc + (public_key : Slice Std.U8) (K : Std.Usize) + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (i : Std.Usize) + (h_pk_len : public_key.length = K.val * 384) + (h_i : i.val < K.val) + (chunk_bytes : Slice Std.U8) + (h_chunk_len : chunk_bytes.length = 384) + (h_chunk_eq : ∀ ℓ : Nat, ℓ < 384 → + chunk_bytes.val[ℓ]! = public_key.val[i.val * 384 + ℓ]!) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.serialize.deserialize_to_reduced_ring_element + (vectortraitsOperationsInst := portable_ops_inst) + chunk_bytes re + ⦃ ⇓ p => ⌜ lift_poly p = (lift_t_as_ntt_from_public_key public_key K).val[i.val]! + ∧ (∀ chunk : Nat, chunk < 16 → ∀ ℓ : Nat, ℓ < 16 → + ((p.coefficients.val[chunk]!).elements.val[ℓ]!).val.natAbs ≤ 3328) ⌝ ⦄ + +end libcrux_iot_ml_kem.Serialize \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec.lean new file mode 100644 index 00000000..ce0f5adb --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec.lean @@ -0,0 +1,600 @@ +/- + # `Spec.lean` — pure-Lean intermediate spec for ML-KEM. + + Mirrors `HacspecMlKem`'s `ntt`, `multiply_ntts`, `compress`, …, + but works on `MontPoly = Vector (ZMod 3329) 256` so that + `ring`/`field_simp` close the algebraic commute lemmas without + poking raw `% 3329` arithmetic. + + No `Aeneas.Std.Result`, no `mvcgen`, no impl-side Triple + obligations — only `Vector`, `ZMod 3329`, and the `bit_` + function signatures whose algebraic equivalence to the hacspec + spec lives. + + Design notes: + - `MontPoly := Vector (ZMod 3329) 256` is the algebraic working type. + The parallel `SpecPoly := Vector parameters.FieldElement 256` lives + below. + - `vector.traits.Operations` has no `repr` field; concrete impls + (e.g. `vector.portable.vector_type.PortableVector`) carry an + `elements : Array Std.I16 16` field accessed directly via + `re.coefficients.val[i]!.elements.val[j]!`. The lift functions + therefore specialize to `PortableVector`. + - hacspec spec functions return `Result`; bit-side `bit_` are + pure; `AlgEquiv` bridges via `Spec._pure` aliases. + - The NTT family (`bit_ntt`, `bit_ntt_layer_*`, `bit_invert_ntt_*`, + `bit_butterfly`, …) ships as identity placeholders so downstream + code can reference them by name; a later pass replaces these stubs + with real bodies and proves the algebraic equivalence. +-/ +import LibcruxIotMlKem.Spec.NumericKeystones +import LibcruxIotMlKem.Spec.ModularArith +import LibcruxIotMlKem.Spec.Montgomery +import LibcruxIotMlKem.Extraction.Funs +import HacspecMlKem.Extraction.Funs +import Mathlib.Data.ZMod.Basic +import Mathlib.Tactic.Ring + +namespace libcrux_iot_ml_kem.Spec +open CoreModels Aeneas Aeneas.Std + +/-! ### `Inhabited` instances for `.val[j]!` projections. + + The `PolynomialRingElement V`-and-`PortableVector` chunk types + need an `Inhabited` instance for the `coefficients.val[i]!` / + `elements.val[j]!` indexing patterns in `to_spec_poly_*` and the + `bit_*_form_poly` predicates. Declared `local` to avoid colliding + with the identically-shaped instances in `Equivalence/L3` and + `Equivalence/L6` files; both can coexist because they're scoped + to their respective files. -/ + +local instance instInhabitedPortableVector_bitMlKem : + Inhabited libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + ⟨{ elements := Std.Array.make 16#usize (List.replicate 16 (0#i16 : Std.I16)) + (by simp) }⟩ + +local instance instInhabitedPolynomialRingElement_bitMlKem + {Vector : Type} [Inhabited Vector] : + Inhabited (libcrux_iot_ml_kem.polynomial.PolynomialRingElement Vector) := + ⟨{ coefficients := Std.Array.make 16#usize (List.replicate 16 default) (by simp) }⟩ + +/-! ## §B.2 Type skeleton -/ + +/-- The algebraic working type for Layer-M commute proofs. Each ML-KEM + ring element is a 256-coefficient vector over `ZMod 3329`, so + `ring` / `field_simp` discharge the bulk of upstream's + `lemma_mod_*_distr_*` chains directly. + + Lifts from the impl side (a 16 × 16 `Array (PortableVector)`) + project each lane via `i16_to_spec_fe_{plain,mont}` below. -/ +abbrev MontPoly : Type := Vector (ZMod 3329) 256 + +/-! ## §B.5 — `SpecPoly` + lane coercions. -/ + +/-- The hacspec interface type: a 256-coefficient vector of + `parameters.FieldElement` (which wraps a `Std.U16` carrying a + canonical-form residue mod q). M.4 AlgEquiv lemmas bridge between + `bit_` (on `MontPoly`) and `Spec.` (on `SpecPoly`). -/ +abbrev SpecPoly : Type := + Vector hacspec_ml_kem.parameters.FieldElement 256 + +/-- `parameters.FieldElement → ZMod 3329` lane coercion. -/ +def zmodOfFE (fe : hacspec_ml_kem.parameters.FieldElement) : ZMod 3329 := + (fe.val.val : ZMod 3329) + +/-- `ZMod 3329 → parameters.FieldElement` lane coercion. Takes + `z.val : Fin 3329 ⊂ Fin 65536`, lifts to a `Std.U16`, and wraps. -/ +def feOfZMod (z : ZMod 3329) : hacspec_ml_kem.parameters.FieldElement := + { val := ⟨BitVec.ofNat 16 z.val⟩ } + +/-- Round-trip identity: lifting `z : ZMod 3329` to a FieldElement and + back yields `z`. Bridges M.4's "M.1 def equals hacspec spec value" + statements through the FE lift. -/ +theorem zmodOfFE_feOfZMod (z : ZMod 3329) : zmodOfFE (feOfZMod z) = z := by + unfold zmodOfFE feOfZMod + -- z.val < 3329 ≤ 65535, so BitVec.ofNat 16 z.val .toNat = z.val. + have h_lt : z.val < 65536 := + Nat.lt_of_lt_of_le (ZMod.val_lt z) (by decide) + have h_unfold : (BitVec.ofNat 16 z.val).toNat = z.val := by + simp [BitVec.toNat_ofNat, Nat.mod_eq_of_lt h_lt] + change ((BitVec.ofNat 16 z.val).toNat : ZMod 3329) = z + rw [h_unfold]; exact ZMod.natCast_zmod_val z + +/-- `MontPoly → SpecPoly` via per-lane `feOfZMod`. -/ +def MontPoly.toSpecPoly (m : MontPoly) : SpecPoly := m.map feOfZMod + +/-- `SpecPoly → MontPoly` via per-lane `zmodOfFE`. -/ +def SpecPoly.toMontPoly (s : SpecPoly) : MontPoly := s.map zmodOfFE + +/-! ## §B.4 (part) Lane-level lifts from `Std.I16` (impl side) -/ + +/-- Plain-domain lane lift: the i16 stores an integer representative + (possibly signed) of a value mod q. Cast through `Int → ZMod 3329` + and we're done. -/ +def i16_to_spec_fe_plain (x : Std.I16) : ZMod 3329 := + (x.val : ZMod 3329) + +/-- Mont-domain lane lift: the i16 stores `a · R mod q` for some + `a : ZMod 3329`; we strip the Montgomery factor by multiplying by + `R⁻¹ = 169`. -/ +def i16_to_spec_fe_mont (x : Std.I16) : ZMod 3329 := + ((x.val : ZMod 3329)) * (169 : ZMod 3329) + +/-- Unfold rule: `i16_to_spec_fe_plain` is the cast. Re-exported so + downstream Triple bodies can rewrite without re-unfolding the + definition. -/ +theorem i16_to_spec_fe_plain_unfold (x : Std.I16) : + i16_to_spec_fe_plain x = (x.val : ZMod 3329) := rfl + +/-- Unfold rule: `i16_to_spec_fe_mont` is `x.val · 169`. -/ +theorem i16_to_spec_fe_mont_unfold (x : Std.I16) : + i16_to_spec_fe_mont x = (x.val : ZMod 3329) * 169 := rfl + +/-! ## §B.4 (part) Poly-level lifts from `PolynomialRingElement PortableVector` -/ + +/-- Plain-domain poly lift: project each of the 16 chunks × 16 lanes + into `ZMod 3329` via `i16_to_spec_fe_plain`. The result is a + 256-element `Vector (ZMod 3329) 256`. -/ +def to_spec_poly_plain + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + MontPoly := + Vector.ofFn (n := 256) fun j => + let i := j.val / 16 + let k := j.val % 16 + let chunk := re.coefficients.val[i]! + let lane := chunk.elements.val[k]! + i16_to_spec_fe_plain lane + +/-- Mont-domain poly lift: same indexing scheme, but each lane is + multiplied by `R⁻¹ = 169` to strip the Montgomery factor. -/ +def to_spec_poly_mont + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + MontPoly := + Vector.ofFn (n := 256) fun j => + let i := j.val / 16 + let k := j.val % 16 + let chunk := re.coefficients.val[i]! + let lane := chunk.elements.val[k]! + i16_to_spec_fe_mont lane + +/-! ## §B.3 — the 34 `bit_` defs. + + Convention: + - "trivial" ops (`add`, `sub`, `barrett_reduce`, `multiply_*`, + `to_unsigned_*`, `cond_subtract_3329`) ship with REAL bodies + because they are 5-line pointwise vector maps. + - "complex" ops (NTT family, INTT, multiply_ntts, compress, + byte_encode, byte_decode) ship with STUBBED identity bodies + because their real bodies are 30-60 LOC and would expand the M.1 + dispatch beyond its single-agent budget. M.4 (AlgEquiv) will + replace each stub with the real body (and prove algebraic + equivalence to the hacspec spec). Each stub is marked + . +-/ + +/-! ### #1-2 — pointwise add / sub (real bodies, 5 LOC each) -/ + +/-- `#1 bit_add`: pointwise `(p, q) ↦ p + q` in `ZMod 3329`. Used by + L6.2 (`add_*_reduce` family). -/ +def bit_add (p q : MontPoly) : MontPoly := + Vector.ofFn (n := 256) fun i => p[i] + q[i] + +/-- `#2 bit_sub`: pointwise `(p, q) ↦ p - q` in `ZMod 3329`. Used by + L6.4 `subtract_reduce`. -/ +def bit_sub (p q : MontPoly) : MontPoly := + Vector.ofFn (n := 256) fun i => p[i] - q[i] + +/-! ### #3 — Barrett reduce (identity on `ZMod 3329`) -/ + +/-- `#3 bit_barrett_reduce`: identity on `ZMod 3329`, because Barrett + reduction takes an `I16` in `[-q·t, q·t]` and returns its + canonical residue mod q, which is the same `ZMod 3329` element. -/ +def bit_barrett_reduce (p : MontPoly) : MontPoly := p + +/-! ### #4-5 — multiply by constant (real bodies, 5 LOC each) -/ + +/-- `#4 bit_montgomery_multiply_by_constant`: pointwise `(p, c) ↦ + p · c · R⁻¹` in `ZMod 3329`. The `R⁻¹` factor is already absorbed + by the calling convention — the constant `c` is passed in the + Montgomery domain (`c · R`), so `c · R · R⁻¹ = c` and the result is + `p · c`. -/ +def bit_montgomery_multiply_by_constant (p : MontPoly) (c : ZMod 3329) : + MontPoly := + Vector.ofFn (n := 256) fun i => p[i] * c + +/-- `#5 bit_multiply_by_constant`: pointwise `(p, c) ↦ p · c` in + `ZMod 3329` (no Mont stripping; both operands in plain domain). -/ +def bit_multiply_by_constant (p : MontPoly) (c : ZMod 3329) : MontPoly := + Vector.ofFn (n := 256) fun i => p[i] * c + +/-! ### #6 — to_unsigned_representative (identity on `ZMod 3329`) -/ + +/-- `#6 bit_to_unsigned_representative`: identity on `ZMod 3329`. The + underlying impl reduces a signed I16 to its `[0, q)` canonical + representative, which is the same `ZMod 3329` element. -/ +def bit_to_unsigned_representative (p : MontPoly) : MontPoly := p + +/-! ### #7 — to_standard_domain (× R, equivalently × 2285 in canonical + form / × 1353 via the mont representation) -/ + +/-- `#7 bit_to_standard_domain`: pointwise `(p) ↦ p · R` in + `ZMod 3329`. In the hacspec spec, this is `compose with + montgomery_multiply(·, 1353)`, where `1353 = R² mod q` (B.4) — + composing with one mont reduce strips one `R⁻¹` and leaves + `· R²·R⁻¹ = · R`. -/ +def bit_to_standard_domain (p : MontPoly) : MontPoly := + Vector.ofFn (n := 256) fun i => p[i] * (2285 : ZMod 3329) + +/-! ### #8-12 — NTT layer family (STUBBED, real bodies) -/ + +/-- `#8 bit_ntt_layer_1` STUB. Real body: a butterfly pass over 128 + `(a, b)` pairs spaced 128 apart, using the layer-1 zetas + `Vector (ZMod 3329) 64`. M.4 fills the real body. -/ +def bit_ntt_layer_1 (p : MontPoly) (_zetas : Vector (ZMod 3329) 64) : + MontPoly := p + -- replace with real butterfly-pass body. + +/-- `#9 bit_ntt_layer_2` STUB. Real body: butterfly pass spaced 64, + 32-zeta argument. M.4 fills the real body. -/ +def bit_ntt_layer_2 (p : MontPoly) (_zetas : Vector (ZMod 3329) 32) : + MontPoly := p + -- replace with real butterfly-pass body. + +/-- `#10 bit_ntt_layer_3` STUB. Real body: butterfly pass spaced 32, + 16-zeta argument. M.4 fills the real body. -/ +def bit_ntt_layer_3 (p : MontPoly) (_zetas : Vector (ZMod 3329) 16) : + MontPoly := p + -- replace with real butterfly-pass body. + +/-- `#11 bit_ntt_layer_4_to_7` STUB. Parametric over layer index + `4 ≤ layer ≤ 7`; each layer halves the butterfly spacing. M.4 + fills the real body. -/ +def bit_ntt_layer_4_to_7 (p : MontPoly) (_layer : Nat) : MontPoly := p + -- replace with real layer-parametric butterfly body. + +/-- `#12 bit_ntt` STUB. Real body: 7-fold composition of the layer + NTT passes. M.4 fills the real body and proves equivalence to + `Spec.ntt_pure`. -/ +def bit_ntt (p : MontPoly) : MontPoly := p + -- replace with `bit_ntt_layer_1 ∘ … ∘ bit_ntt_layer_4_to_7 7`. + +/-! ### #13-17 — INTT family (STUBBED) -/ + +/-- `#13 bit_invert_ntt_layer_1` STUB. Real body: inverse-butterfly + pass mirroring `bit_ntt_layer_1`. M.4 fills the real body. -/ +def bit_invert_ntt_layer_1 (p : MontPoly) (_zetas : Vector (ZMod 3329) 64) : + MontPoly := p + -- replace with real inverse-butterfly body. + +/-- `#14 bit_invert_ntt_layer_2` STUB. -/ +def bit_invert_ntt_layer_2 (p : MontPoly) (_zetas : Vector (ZMod 3329) 32) : + MontPoly := p + -- replace with real inverse-butterfly body. + +/-- `#15 bit_invert_ntt_layer_3` STUB. -/ +def bit_invert_ntt_layer_3 (p : MontPoly) (_zetas : Vector (ZMod 3329) 16) : + MontPoly := p + -- replace with real inverse-butterfly body. + +/-- `#16 bit_invert_ntt_layer_4_to_7` STUB. -/ +def bit_invert_ntt_layer_4_to_7 (p : MontPoly) (_layer : Nat) : MontPoly := p + -- replace with real layer-parametric inverse body. + +/-- `#17 bit_invert_ntt_montgomery` STUB. Real body: 7-fold inverse + composition WITHOUT the final `· 1441 · R⁻¹` normalization (the + "INTT-Mont" form). The `bit_intt_mont_form_lane` predicate below + is the load-bearing per-lane invariant for this output. -/ +def bit_invert_ntt_montgomery (p : MontPoly) : MontPoly := p + -- replace with real INTT-without-finalize body. + +/-! ### #18-19 — multiply (STUBBED) -/ + +/-- `#18 bit_ntt_multiply_n` STUB. Real body: per-pair base-case + multiply across 128 `(a₀, a₁, b₀, b₁)` quartets, using + 64-element zeta argument. M.4 fills the real body. -/ +def bit_ntt_multiply_n (p q : MontPoly) (_zetas : Vector (ZMod 3329) 64) : + MontPoly := p + q -- harmless placeholder; replaced + -- replace with real per-pair base-case multiply body. + +/-- `#19 bit_multiply_ntts` STUB. Real body: wrapper around + `bit_ntt_multiply_n` with the hacspec zeta table. M.4 fills it. -/ +def bit_multiply_ntts (p q : MontPoly) : MontPoly := p + q + -- replace with real body. + +/-! ### #20-23 — per-quartet base cases (STUBBED) -/ + +/-- `#20 bit_base_case_multiply_even`: per-quartet helper for + `bit_ntt_multiply_n`. Real body: `a₀·b₀ + zeta·a₁·b₁`. M.4 + fills it. -/ +def bit_base_case_multiply_even + (_a0 _a1 _b0 _b1 _zeta : ZMod 3329) : ZMod 3329 := 0 + -- replace with `a0*b0 + zeta*a1*b1`. + +/-- `#21 bit_base_case_multiply_odd`: per-quartet helper. Real body: + `a₀·b₁ + a₁·b₀`. M.4 fills it. -/ +def bit_base_case_multiply_odd + (_a0 _a1 _b0 _b1 : ZMod 3329) : ZMod 3329 := 0 + -- replace with `a0*b1 + a1*b0`. + +/-- `#22 bit_butterfly`: per-pair NTT butterfly. Real body: + `(a + zeta·b, a - zeta·b)`. M.4 fills it. -/ +def bit_butterfly (_zeta a b : ZMod 3329) : + ZMod 3329 × ZMod 3329 := (a, b) + -- replace with `(a + zeta*b, a - zeta*b)`. + +/-- `#23 bit_inv_butterfly`: per-pair inverse butterfly. Real body: + `((a+b)/2, zeta·(a-b)/2)` modulo the Mont-domain bookkeeping. + M.4 fills it. -/ +def bit_inv_butterfly (_zeta a b : ZMod 3329) : + ZMod 3329 × ZMod 3329 := (a, b) + -- replace with the real inv-butterfly body. + +/-! ### #24-25 — poly ops (REAL where trivial, STUBBED where not) -/ + +/-- `#24 bit_add_to_ring_element`: pointwise add, same as `bit_add`. + Provided as an alias for the L7.2 caller chain. -/ +def bit_add_to_ring_element (p q : MontPoly) : MontPoly := bit_add p q + +/-- `#25 bit_subtract_reduce`: pointwise `(p, q) ↦ (q - p) · (R/128)` + in `ZMod 3329` — the L6.4 "subtract and finalize INTT" operation. + The `R/128` factor is exactly `1441 · R⁻¹ = 512` (mod q) by + `mont_128_169_512`. M.4 fills the real body. -/ +def bit_subtract_reduce (p q : MontPoly) : MontPoly := + Vector.ofFn (n := 256) fun i => (q[i] - p[i]) * (512 : ZMod 3329) + +/-! ### #26-29 — compress / decompress family (STUBBED) -/ + +/-- `#26 bit_compress`: compression by `d` bits. Real body: per-lane + `(2^d · x + ⌈q/2⌉) / q mod 2^d`. M.4 fills it (the spec uses + `Result` plumbing; the bit-side version is pure). -/ +def bit_compress (_p : MontPoly) (_d : Nat) : Vector (ZMod 3329) 256 := + Vector.replicate 256 (0 : ZMod 3329) + -- replace with real per-lane compression body; return type + -- should be `Vector (Fin (2^d)) 256` in the final shape — using + -- `Vector (ZMod 3329) 256` here as a uniform placeholder. + +/-- `#27 bit_decompress`: inverse compression. Real body: per-lane + `⌈q · y / 2^d⌉` for `y ∈ [0, 2^d)`. -/ +def bit_decompress (_c : Vector (ZMod 3329) 256) (_d : Nat) : MontPoly := + Vector.replicate 256 (0 : ZMod 3329) + -- replace with real per-lane decompression body. + +/-- `#28 bit_compress_message`: compress with `d = 1`. -/ +def bit_compress_message (p : MontPoly) : Vector (ZMod 3329) 256 := + bit_compress p 1 + +/-- `#29 bit_decompress_message`: decompress with `d = 1`. -/ +def bit_decompress_message (c : Vector (ZMod 3329) 256) : MontPoly := + bit_decompress c 1 + +/-! ### #30-31 — byte encode / decode (STUBBED) -/ + +/-- `#30 bit_byte_encode`: serialize a poly with `d` bits per + coefficient. Real body: bit-packing per FIPS-203. M.4 fills it + (likely via `bv_decide`). The output size is `32 * d` bytes; the + return type is parametric in `d` which makes a real signature + awkward here, so we ship a uniform placeholder. -/ +def bit_byte_encode (_p : MontPoly) (_d : Nat) : List Std.U8 := [] + -- real signature is `Vector Std.U8 (32 * d)`; replace + -- with the bit-packing body. + +/-- `#31 bit_byte_decode`: inverse of `bit_byte_encode`. -/ +def bit_byte_decode (_bytes : List Std.U8) (_d : Nat) : MontPoly := + Vector.replicate 256 (0 : ZMod 3329) + -- real signature is `Vector Std.U8 (32 * d) → MontPoly`; + -- replace with the bit-unpacking body. + +/-! ### #32 — cond_subtract_3329 (identity in `ZMod 3329`) -/ + +/-- `#32 bit_cond_subtract_3329`: identity on `ZMod 3329`. The impl + conditionally subtracts `q = 3329` when the lane is ≥ q, which is + a no-op modulo q. -/ +def bit_cond_subtract_3329 (p : MontPoly) : MontPoly := p + +/-! ### #33 — per-step helper (STUBBED) -/ + +/-- `#33 bit_ntt_layer_int_vec_step`: per-step helper extracted from + the L3.4 inner-loop sketch. Apply a single butterfly group with + spacing `group` and zeta `zeta`. M.4 fills the real body. -/ +def bit_ntt_layer_int_vec_step + (p : MontPoly) (_group : Nat) (_zeta : ZMod 3329) : MontPoly := p + -- replace with the per-step butterfly body. + +/-! ### #34 — accumulating multiply (STUBBED) -/ + +/-- `#34 bit_accumulating_ntt_multiply`: per-vector base case used by + L2.8 (`accumulating_ntt_multiply`). Real body returns an + accumulator update of 8 `(ZMod 3329)` values per call. M.4 fills + the real body. -/ +def bit_accumulating_ntt_multiply + (_a _b : Vector (ZMod 3329) 8) (_acc : Vector (ZMod 3329) 8) + (_zeta : ZMod 3329) : Vector (ZMod 3329) 8 := + Vector.replicate 8 (0 : ZMod 3329) + -- replace with the per-pair base-case accumulator body. + +/-! ## §B.4 Opaque predicates anchoring impl ↔ MontPoly per-lane facts -/ + +/-- "Lane carries an i16 in canonical Montgomery domain w.r.t. the + spec FE `expected`": `(lane.val · 169) ≡ expected (mod 3329)`. + + Marked `@[irreducible]` so L0+ Triple body proofs don't + accidentally unfold the predicate and trigger Z3 quantifier + cascades — they reveal it explicitly via + `bit_mont_form_lane_intro` / `…_reveal` below. -/ +@[irreducible] +def bit_mont_form_lane (lane : Std.I16) (expected : ZMod 3329) : Prop := + ((lane.val : ZMod 3329) * 169) = expected + +/-- "Lane carries an i16 in post-INTT-without-finalize Mont domain": + `(lane.val · 2285) ≡ (expected · 128) (mod 3329)`, where `2285 ≡ + R (mod q)` (B.3) and `128` reflects the deferred 1/128 normalization + from INTT. + + Marked `@[irreducible]` for the same reason as + `bit_mont_form_lane`. -/ +@[irreducible] +def bit_intt_mont_form_lane (lane : Std.I16) (expected : ZMod 3329) : Prop := + ((lane.val : ZMod 3329) * 2285) = expected * 128 + +/-! ### Per-chunk / per-poly wraps -/ + +/-- "Every lane of a 16-lane PortableVector chunk is in canonical Mont + form w.r.t. the corresponding lane of the 16-element expected + vector". Used by the chunk-level commutes in M.2 Block B. -/ +def bit_mont_form_chunk + (chunk : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (expected : Vector (ZMod 3329) 16) : Prop := + ∀ i : Fin 16, bit_mont_form_lane (chunk.elements.val[i.val]!) (expected[i.val]) + +/-- "Every lane of every chunk of the polynomial is in canonical Mont + form w.r.t. the corresponding lane of the 256-element expected + MontPoly". Used by the poly-level commutes in M.2 Block C. -/ +def bit_mont_form_poly + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (expected : MontPoly) : Prop := + ∀ i : Fin 16, ∀ j : Fin 16, + bit_mont_form_lane + ((re.coefficients.val[i.val]!).elements.val[j.val]!) + (expected[16 * i.val + j.val]'(by + have hi : i.val < 16 := i.isLt + have hj : j.val < 16 := j.isLt + omega)) + +/-- Per-chunk INTT-Mont form, mirroring `bit_mont_form_chunk`. -/ +def bit_intt_mont_form_chunk + (chunk : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (expected : Vector (ZMod 3329) 16) : Prop := + ∀ i : Fin 16, bit_intt_mont_form_lane (chunk.elements.val[i.val]!) (expected[i.val]) + +/-- Per-poly INTT-Mont form, mirroring `bit_mont_form_poly`. -/ +def bit_intt_mont_form_poly + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (expected : MontPoly) : Prop := + ∀ i : Fin 16, ∀ j : Fin 16, + bit_intt_mont_form_lane + ((re.coefficients.val[i.val]!).elements.val[j.val]!) + (expected[16 * i.val + j.val]'(by + have hi : i.val < 16 := i.isLt + have hj : j.val < 16 := j.isLt + omega)) + +/-! ### Reveal / intro lemmas for the opaque predicates. + + Mirrors the SHA-3 BitKeccak idiom (§5.7 Idiom 2): predicates are + `@[irreducible]` and consumers reveal them via these named + lemmas, never via direct `unfold`. -/ + +theorem bit_mont_form_lane_intro (lane : Std.I16) (expected : ZMod 3329) + (h : ((lane.val : ZMod 3329) * 169) = expected) : + bit_mont_form_lane lane expected := by + unfold bit_mont_form_lane; exact h + +theorem bit_mont_form_lane_reveal (lane : Std.I16) (expected : ZMod 3329) + (h : bit_mont_form_lane lane expected) : + ((lane.val : ZMod 3329) * 169) = expected := by + unfold bit_mont_form_lane at h; exact h + +theorem bit_intt_mont_form_lane_intro (lane : Std.I16) (expected : ZMod 3329) + (h : ((lane.val : ZMod 3329) * 2285) = expected * 128) : + bit_intt_mont_form_lane lane expected := by + unfold bit_intt_mont_form_lane; exact h + +theorem bit_intt_mont_form_lane_reveal (lane : Std.I16) (expected : ZMod 3329) + (h : bit_intt_mont_form_lane lane expected) : + ((lane.val : ZMod 3329) * 2285) = expected * 128 := by + unfold bit_intt_mont_form_lane at h; exact h + +/-! ## §B.5 Bridge lemmas — `to_spec_poly_*` projection lemmas. + + These collapse the `Vector.ofFn` definition of the lift back to the + per-lane formula. The four lemmas cover {plain, mont} × {chunk, + poly} and are used by M.2 / M.4 to push the lift inside per-lane + arithmetic. +-/ + +/-- Plain-domain poly lift unfolds to the per-lane formula at index + `16 i + j` for `i, j < 16`. -/ +theorem lemma_to_spec_poly_plain_unfold + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (i j : Fin 16) : + (to_spec_poly_plain re)[16 * i.val + j.val]'(by + have hi : i.val < 16 := i.isLt + have hj : j.val < 16 := j.isLt + omega) + = i16_to_spec_fe_plain + ((re.coefficients.val[i.val]!).elements.val[j.val]!) := by + unfold to_spec_poly_plain + -- `Vector.getElem_ofFn` rewrites the LHS to the body of `ofFn`. + simp only [Vector.getElem_ofFn] + -- Reduce `(16*i+j)/16` to `i` and `(16*i+j)%16` to `j` via `omega`. + have hi : i.val < 16 := i.isLt + have hj : j.val < 16 := j.isLt + have hdiv : (16 * i.val + j.val) / 16 = i.val := by omega + have hmod : (16 * i.val + j.val) % 16 = j.val := by omega + rw [hdiv, hmod] + +/-- Mont-domain poly lift unfolds to the per-lane formula. -/ +theorem lemma_to_spec_poly_mont_unfold + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (i j : Fin 16) : + (to_spec_poly_mont re)[16 * i.val + j.val]'(by + have hi : i.val < 16 := i.isLt + have hj : j.val < 16 := j.isLt + omega) + = i16_to_spec_fe_mont + ((re.coefficients.val[i.val]!).elements.val[j.val]!) := by + unfold to_spec_poly_mont + simp only [Vector.getElem_ofFn] + have hi : i.val < 16 := i.isLt + have hj : j.val < 16 := j.isLt + have hdiv : (16 * i.val + j.val) / 16 = i.val := by omega + have hmod : (16 * i.val + j.val) % 16 = j.val := by omega + rw [hdiv, hmod] + +/-- Plain-domain poly lift agrees lane-by-lane with any + pointwise-defined function that matches at every `(i, j)`. Used by + M.2 chunk/poly commutes to lift a Block-A lane fact to the full + 256-element vector. -/ +theorem lemma_to_spec_poly_plain_eq_of_coeffs + (re re' : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h : ∀ i j : Fin 16, + i16_to_spec_fe_plain + ((re.coefficients.val[i.val]!).elements.val[j.val]!) + = i16_to_spec_fe_plain + ((re'.coefficients.val[i.val]!).elements.val[j.val]!)) : + to_spec_poly_plain re = to_spec_poly_plain re' := by + unfold to_spec_poly_plain + apply Vector.ext + intro k hk + simp only [Vector.getElem_ofFn] + -- Decompose k = 16 * (k/16) + (k%16) with k/16 < 16 and k%16 < 16. + have hdiv_lt : k / 16 < 16 := by omega + have hmod_lt : k % 16 < 16 := Nat.mod_lt k (by decide) + exact h ⟨k / 16, hdiv_lt⟩ ⟨k % 16, hmod_lt⟩ + +/-- Mont-domain analogue of `lemma_to_spec_poly_plain_eq_of_coeffs`. -/ +theorem lemma_to_spec_poly_mont_eq_of_coeffs + (re re' : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h : ∀ i j : Fin 16, + i16_to_spec_fe_mont + ((re.coefficients.val[i.val]!).elements.val[j.val]!) + = i16_to_spec_fe_mont + ((re'.coefficients.val[i.val]!).elements.val[j.val]!)) : + to_spec_poly_mont re = to_spec_poly_mont re' := by + unfold to_spec_poly_mont + apply Vector.ext + intro k hk + simp only [Vector.getElem_ofFn] + have hdiv_lt : k / 16 < 16 := by omega + have hmod_lt : k % 16 < 16 := Nat.mod_lt k (by decide) + exact h ⟨k / 16, hdiv_lt⟩ ⟨k % 16, hmod_lt⟩ + +end libcrux_iot_ml_kem.Spec \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/AlgEquiv.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/AlgEquiv.lean new file mode 100644 index 00000000..dda2f803 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/AlgEquiv.lean @@ -0,0 +1,177 @@ +/- + # `Spec/AlgEquiv.lean` — algebraic-equivalence lemmas for the + easy `bit_` cluster. + + Each lemma characterises `bit_` pointwise on + `MontPoly = Vector (ZMod 3329) 256` in closed form, so callers can + unfold a single named theorem instead of inlining `Vector.getElem_ofFn` + chains. + + | `bit_` | Characterisation | Shape | + | --------------------------------------- | ------------------------ | ------ | + | `bit_add` | pointwise `+` | rfl | + | `bit_sub` | pointwise `-` | rfl | + | `bit_barrett_reduce` | identity | rfl | + | `bit_montgomery_multiply_by_constant` | pointwise `· c` | rfl | + | `bit_multiply_by_constant` | pointwise `· c` | rfl | + | `bit_to_unsigned_representative` | identity | rfl | + | `bit_to_standard_domain` | pointwise `· 2285` | rfl | + | `bit_cond_subtract_3329` | identity | rfl | + | `bit_add_to_ring_element` | `bit_add` | rfl | + | `bit_subtract_reduce` | pointwise `(q-p) · 512` | rfl | + + - The pointwise lemmas are `@[simp]`-eligible: each rewrites + `(bit_ ...)[i]'h` to a closed form on the operand lanes. + - The SpecPoly-bridge variants of the form + `bit_ (SpecPoly.toMontPoly p) = SpecPoly.toMontPoly (Spec._pure p)` + are deferred to the NTT-cluster pass because they require the + pure-projection side lemmas of `parameters.FieldElement.{add,sub,mul}`. + A locked target shape for the follow-up is captured inline at the end + of this file. + - Mathlib is imported for `ring`/`field_simp` on `ZMod 3329`. +-/ +import LibcruxIotMlKem.Spec +import LibcruxIotMlKem.Spec.Pure +import Mathlib.Tactic.Ring + +namespace libcrux_iot_ml_kem.Spec.AlgEquiv +open CoreModels Aeneas Aeneas.Std +open libcrux_iot_ml_kem.Spec + +/-! ## §M.4 Easy #1 — `bit_add` pointwise. -/ + +/-- `(bit_add p q)[i] = p[i] + q[i]`. -/ +@[scoped grind =] +theorem bit_add_getElem (p q : MontPoly) (i : Nat) (h : i < 256) : + (bit_add p q)[i]'h = p[i]'h + q[i]'h := by + unfold bit_add + simp [Vector.getElem_ofFn] + +/-! ## §M.4 Easy #2 — `bit_sub` pointwise. -/ + +/-- `(bit_sub p q)[i] = p[i] - q[i]`. -/ +@[scoped grind =] +theorem bit_sub_getElem (p q : MontPoly) (i : Nat) (h : i < 256) : + (bit_sub p q)[i]'h = p[i]'h - q[i]'h := by + unfold bit_sub + simp [Vector.getElem_ofFn] + +/-! ## §M.4 Easy #3 — `bit_barrett_reduce` is identity. -/ + +/-- `bit_barrett_reduce p = p` (identity in `ZMod 3329`; the impl-side + Barrett reduction picks a canonical residue mod q, which is the + same `ZMod 3329` element). -/ +@[scoped grind =] +theorem bit_barrett_reduce_eq (p : MontPoly) : bit_barrett_reduce p = p := rfl + +/-! ## §M.4 Easy #4 — `bit_montgomery_multiply_by_constant` pointwise. -/ + +/-- `(bit_montgomery_multiply_by_constant p c)[i] = p[i] * c`. The + Mont factor is already absorbed by the calling convention — the + constant is in Mont domain (`c · R`), and Mont multiplication + `· c · R · R⁻¹ = · c`, so the result is plain multiplication by `c`. -/ +@[scoped grind =] +theorem bit_montgomery_multiply_by_constant_getElem + (p : MontPoly) (c : ZMod 3329) (i : Nat) (h : i < 256) : + (bit_montgomery_multiply_by_constant p c)[i]'h = p[i]'h * c := by + unfold bit_montgomery_multiply_by_constant + simp [Vector.getElem_ofFn] + +/-! ## §M.4 Easy #5 — `bit_multiply_by_constant` pointwise. -/ + +/-- `(bit_multiply_by_constant p c)[i] = p[i] * c` (plain-domain + multiplication). -/ +@[scoped grind =] +theorem bit_multiply_by_constant_getElem + (p : MontPoly) (c : ZMod 3329) (i : Nat) (h : i < 256) : + (bit_multiply_by_constant p c)[i]'h = p[i]'h * c := by + unfold bit_multiply_by_constant + simp [Vector.getElem_ofFn] + +/-! ## §M.4 Easy #6 — `bit_to_unsigned_representative` is identity. -/ + +/-- `bit_to_unsigned_representative p = p`. The impl-side picks the + nonneg `[0, q)` representative; that is the same `ZMod 3329` element. -/ +@[scoped grind =] +theorem bit_to_unsigned_representative_eq (p : MontPoly) : + bit_to_unsigned_representative p = p := rfl + +/-! ## §M.4 Easy #7 — `bit_to_standard_domain` pointwise. -/ + +/-- `(bit_to_standard_domain p)[i] = p[i] * 2285`. The constant `2285` + is `R mod q` (B.3 keystone): the operation multiplies by `R²·R⁻¹` + = `R` because the underlying impl composes + `montgomery_multiply_fer_by_constant` with the constant `1353 = + R² mod q`. After Mont absorption we end up with `· R`. -/ +@[scoped grind =] +theorem bit_to_standard_domain_getElem + (p : MontPoly) (i : Nat) (h : i < 256) : + (bit_to_standard_domain p)[i]'h = p[i]'h * (2285 : ZMod 3329) := by + unfold bit_to_standard_domain + simp [Vector.getElem_ofFn] + +/-! ## §M.4 Easy #8 — `bit_cond_subtract_3329` is identity. -/ + +/-- `bit_cond_subtract_3329 p = p`. The impl-side conditionally + subtracts `q = 3329` to canonicalise the lane to `[0, q)`; in + `ZMod 3329` this is a no-op modulo q. -/ +@[scoped grind =] +theorem bit_cond_subtract_3329_eq (p : MontPoly) : + bit_cond_subtract_3329 p = p := rfl + +/-! ## §M.4 Easy #9 — `bit_add_to_ring_element` is `bit_add`. -/ + +/-- `bit_add_to_ring_element p q = bit_add p q`. The poly-level + `add_to_ring_element` impl wrapper is just chunked pointwise + addition, which lifts to `bit_add` exactly. -/ +@[scoped grind =] +theorem bit_add_to_ring_element_eq (p q : MontPoly) : + bit_add_to_ring_element p q = bit_add p q := rfl + +/-- Corollary: `bit_add_to_ring_element` pointwise. -/ +@[scoped grind =] +theorem bit_add_to_ring_element_getElem (p q : MontPoly) (i : Nat) (h : i < 256) : + (bit_add_to_ring_element p q)[i]'h = p[i]'h + q[i]'h := by + rw [bit_add_to_ring_element_eq]; exact bit_add_getElem p q i h + +/-! ## §M.4 Easy #10 — `bit_subtract_reduce` pointwise. -/ + +/-- `(bit_subtract_reduce p q)[i] = (q[i] - p[i]) * 512`. The factor + `512 = R · 128⁻¹ mod q = 1441 · 169 mod q` (B.1 keystone): the + impl computes "subtract and finalize INTT" by multiplying by + `R/128 mod q`. -/ +@[scoped grind =] +theorem bit_subtract_reduce_getElem (p q : MontPoly) (i : Nat) (h : i < 256) : + (bit_subtract_reduce p q)[i]'h = (q[i]'h - p[i]'h) * (512 : ZMod 3329) := by + unfold bit_subtract_reduce + simp [Vector.getElem_ofFn] + +/-! ## §M.4 SpecPoly bridges — deferred to NTT cluster. + + The full set of bridges of the form + `bit_ (SpecPoly.toMontPoly p) = SpecPoly.toMontPoly (Spec._pure p)` + requires per-`bit_` sub-lemmas `zmodOfFE__pure` that + distribute `zmodOfFE` through the hacspec `_pure` projection. Each + sub-lemma depends on the pure-projection side lemma for the + underlying `parameters.FieldElement.` (Open Question I.8 in + arch plan §F.2), which is not in scope for this easy-cluster + dispatch. + + The target shape (locked here for the follow-up dispatch): + ``` + theorem bit_add_specpoly_alg_eq (p q : SpecPoly) : + bit_add (SpecPoly.toMontPoly p) (SpecPoly.toMontPoly q) = + SpecPoly.toMontPoly + (Vector.ofFn fun i => Spec.Pure.FieldElement.add_pure (p[i]) (q[i])) := by + apply Vector.ext; intro k hk + rw [bit_add_getElem] + unfold SpecPoly.toMontPoly + simp only [Vector.getElem_map, Vector.getElem_ofFn] + exact (zmodOfFE_add_pure (p[k]) (q[k])).symm + ``` + -- closes once the deferred sub-lemma + -- `zmodOfFE (FieldElement.add_pure a b) = zmodOfFE a + zmodOfFE b` + -- lands (depends on `FieldElement.add_eq_ok` pure-projection). +-/ + +end libcrux_iot_ml_kem.Spec.AlgEquiv \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/Commute.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/Commute.lean new file mode 100644 index 00000000..a9903f65 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/Commute.lean @@ -0,0 +1,976 @@ +/- + # `Spec/Commute.lean` — M.2 commute lemmas (Block A). + + Layer-0 scalar `Std.I16 → ZMod 3329` field-element commute lemmas. + Each lemma consumes an impl-level "value-equation" precondition + (already produced by L0/L1 Triples in the `Equivalence/` tree) and + produces the matching `ZMod 3329` algebraic equation through one of + the M.1 lane lifts `i16_to_spec_fe_{plain,mont}`. + + Port of `Hacspec_ml_kem.Commute.Chunk.fst` lines 35-680 (the + Layer-0 / "scalar" portion of Block A); blocks B, C, D are + deferred. + + ## Discipline + + Each lemma carries `@[scoped grind]` and lives inside the + `libcrux_iot_ml_kem.Spec.Commute` namespace, so consumers + enable `grind` over the commute set with + `open libcrux_iot_ml_kem.Spec.Commute` (no global pollution). + + ## File-shape notes + + - F* uses `v r % 3329 == ... % 3329` (Int arithmetic mod q). We + translate by stating the precondition directly as a `ZMod 3329` + equation — the M.1 lane lifts already give us the cast. + - F* uses `v r == v a + v b` (strict Int equality) for the strict + `_plain`/`_mont` variants. We mirror this with `r.val = a.val + b.val` + on `Std.I16.val : Int`. + - In `ZMod 3329`, `mont_i16_to_spec_fe x = x.val · 169` and + `i16_to_spec_fe_plain x = x.val`, so each F* `lemma_mod_*_distr_*` + chain collapses to a single `rw [hr]; ring`. +-/ +import LibcruxIotMlKem.Spec +import Mathlib.Data.ZMod.Basic +import Mathlib.Tactic.Ring + +namespace libcrux_iot_ml_kem.Spec.Commute +open CoreModels Aeneas Aeneas.Std +open libcrux_iot_ml_kem.Spec + +/-! ### Local `Inhabited` instances (mirror of `Spec.lean`). + + The `PolynomialRingElement V`-and-`PortableVector` chunk types + need an `Inhabited` instance for the `coefficients.val[i]!` / + `elements.val[j]!` indexing patterns used by Block-C poly lemma + statements. `Spec.lean` declares the same instances as `local`, so + they don't propagate here; we redeclare them `local` for this file. -/ + +local instance instInhabitedPortableVector_bitMlKemCommute : + Inhabited libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + ⟨{ elements := Std.Array.make 16#usize (List.replicate 16 (0#i16 : Std.I16)) + (by simp) }⟩ + +local instance instInhabitedPolynomialRingElement_bitMlKemCommute + {Vector : Type} [Inhabited Vector] : + Inhabited (libcrux_iot_ml_kem.polynomial.PolynomialRingElement Vector) := + ⟨{ coefficients := Std.Array.make 16#usize (List.replicate 16 default) (by simp) }⟩ + +/-! ## A.1 / A.2 — pointwise addition commutes (plain and Mont). -/ + +/-- A.1 `lemma_add_fe_commute_plain` (F*: Chunk.fst:35). Strict + Int-equality precondition: the impl returns the exact integer + sum (no overflow), and the plain-domain lift respects this. -/ +@[scoped grind .] +theorem lemma_add_fe_commute_plain (a b r : Std.I16) + (hr : r.val = a.val + b.val) : + i16_to_spec_fe_plain a + i16_to_spec_fe_plain b = i16_to_spec_fe_plain r := by + unfold i16_to_spec_fe_plain + rw [hr]; push_cast; ring + +/-- A.2 `lemma_add_fe_commute_mont` (F*: Chunk.fst:41). Same shape + as A.1 but lifts through `mont` (extra `· 169` factor). -/ +@[scoped grind .] +theorem lemma_add_fe_commute_mont (a b r : Std.I16) + (hr : r.val = a.val + b.val) : + i16_to_spec_fe_mont a + i16_to_spec_fe_mont b = i16_to_spec_fe_mont r := by + unfold i16_to_spec_fe_mont + rw [hr]; push_cast; ring + +/-! ## A.3 / A.4 — pointwise subtraction commutes. -/ + +/-- A.3 `lemma_sub_fe_commute_plain` (F*: Chunk.fst:48). -/ +@[scoped grind .] +theorem lemma_sub_fe_commute_plain (a b r : Std.I16) + (hr : r.val = a.val - b.val) : + i16_to_spec_fe_plain a - i16_to_spec_fe_plain b = i16_to_spec_fe_plain r := by + unfold i16_to_spec_fe_plain + rw [hr]; push_cast; ring + +/-- A.4 `lemma_sub_fe_commute_mont` (F*: Chunk.fst:54). -/ +@[scoped grind .] +theorem lemma_sub_fe_commute_mont (a b r : Std.I16) + (hr : r.val = a.val - b.val) : + i16_to_spec_fe_mont a - i16_to_spec_fe_mont b = i16_to_spec_fe_mont r := by + unfold i16_to_spec_fe_mont + rw [hr]; push_cast; ring + +/-! ## A.5 — Barrett reduction commutes (plain). -/ + +/-- A.5 `lemma_barrett_fe_commute` (F*: Chunk.fst:63). Barrett + reduction preserves residue mod q, so the plain lift is identity + on the reduced value. + + Statement uses `r = a` order (matching F*) — that is, + `i16_to_spec_fe_plain r = i16_to_spec_fe_plain a`. -/ +@[scoped grind .] +theorem lemma_barrett_fe_commute (a r : Std.I16) + (hr : (r.val : ZMod 3329) = (a.val : ZMod 3329)) : + i16_to_spec_fe_plain r = i16_to_spec_fe_plain a := by + unfold i16_to_spec_fe_plain + exact hr + +/-! ## A.6 — Mont zeta cancellation (mont ↔ plain bridge). -/ + +/-- A.6 `lemma_mont_zeta_cancel` (F*: Chunk.fst:71). The impl stores + zetas pre-multiplied by `R`, so the Mont lift recovers the plain + abstract zeta when paired with a residue-equality precondition. -/ +@[scoped grind .] +theorem lemma_mont_zeta_cancel (zeta_mont zeta_plain : Std.I16) + (hr : (zeta_mont.val : ZMod 3329) * 169 = (zeta_plain.val : ZMod 3329)) : + i16_to_spec_fe_mont zeta_mont = i16_to_spec_fe_plain zeta_plain := by + unfold i16_to_spec_fe_mont i16_to_spec_fe_plain + exact hr + +/-! ## A.7 / A.8 — mod-aware add/sub commutes (residue precondition). -/ + +/-- A.7 `lemma_add_fe_commute_mont_mod` (F*: Chunk.fst:151). The + precondition is a `ZMod 3329` equality rather than a strict Int + sum — used by butterfly outputs whose impl post is mod q. -/ +@[scoped grind .] +theorem lemma_add_fe_commute_mont_mod (a b r : Std.I16) + (hr : (r.val : ZMod 3329) = (a.val : ZMod 3329) + (b.val : ZMod 3329)) : + i16_to_spec_fe_mont a + i16_to_spec_fe_mont b = i16_to_spec_fe_mont r := by + unfold i16_to_spec_fe_mont + rw [hr]; ring + +/-- A.8 `lemma_sub_fe_commute_mont_mod` (F*: Chunk.fst:159). -/ +@[scoped grind .] +theorem lemma_sub_fe_commute_mont_mod (a b r : Std.I16) + (hr : (r.val : ZMod 3329) = (a.val : ZMod 3329) - (b.val : ZMod 3329)) : + i16_to_spec_fe_mont a - i16_to_spec_fe_mont b = i16_to_spec_fe_mont r := by + unfold i16_to_spec_fe_mont + rw [hr]; ring + +/-! ## A.9 / A.10 — butterfly commute (plus and minus halves). -/ + +/-- A.9 `lemma_butterfly_fe_commute_plus` (F*: Chunk.fst:187). The + `+` output of a `ntt_layer_*_step` butterfly: in the Mont domain, + the impl post `result_i ≡ vec_i + vec_j · zeta · 169 (mod q)` + collapses to the FE equation `mont_fe result_i = mont_fe vec_i + + mont_fe zeta · mont_fe vec_j` because the Montgomery factor + cancels exactly with the `· 169` in the residue. -/ +@[scoped grind .] +theorem lemma_butterfly_fe_commute_plus + (vec_i vec_j zeta result_i : Std.I16) + (hr : (result_i.val : ZMod 3329) = + (vec_i.val : ZMod 3329) + + (vec_j.val : ZMod 3329) * (zeta.val : ZMod 3329) * 169) : + i16_to_spec_fe_mont result_i = + i16_to_spec_fe_mont vec_i + + i16_to_spec_fe_mont zeta * i16_to_spec_fe_mont vec_j := by + unfold i16_to_spec_fe_mont + rw [hr]; ring + +/-- A.10 `lemma_butterfly_fe_commute_minus` (F*: Chunk.fst:217). -/ +@[scoped grind .] +theorem lemma_butterfly_fe_commute_minus + (vec_i vec_j zeta result_j : Std.I16) + (hr : (result_j.val : ZMod 3329) = + (vec_i.val : ZMod 3329) - + (vec_j.val : ZMod 3329) * (zeta.val : ZMod 3329) * 169) : + i16_to_spec_fe_mont result_j = + i16_to_spec_fe_mont vec_i - + i16_to_spec_fe_mont zeta * i16_to_spec_fe_mont vec_j := by + unfold i16_to_spec_fe_mont + rw [hr]; ring + +/-! ## A.11 — combined butterfly pair (both halves). -/ + +/-- A.11 `lemma_butterfly_pair_commute` (F*: Chunk.fst:234). Bundles + A.9 and A.10 into a single ∧ — one call per butterfly pair + instead of two. The F* version threads through `Seq.index`; we + stay scalar at Block A and take the four `Std.I16` lanes + directly. Block B re-introduces the array machinery. -/ +@[scoped grind .] +theorem lemma_butterfly_pair_commute + (vec_i vec_j result_i result_j zeta : Std.I16) + (hr_i : (result_i.val : ZMod 3329) = + (vec_i.val : ZMod 3329) + + (vec_j.val : ZMod 3329) * (zeta.val : ZMod 3329) * 169) + (hr_j : (result_j.val : ZMod 3329) = + (vec_i.val : ZMod 3329) - + (vec_j.val : ZMod 3329) * (zeta.val : ZMod 3329) * 169) : + i16_to_spec_fe_mont result_i = + i16_to_spec_fe_mont vec_i + + i16_to_spec_fe_mont zeta * i16_to_spec_fe_mont vec_j + ∧ + i16_to_spec_fe_mont result_j = + i16_to_spec_fe_mont vec_i - + i16_to_spec_fe_mont zeta * i16_to_spec_fe_mont vec_j := by + exact ⟨lemma_butterfly_fe_commute_plus vec_i vec_j zeta result_i hr_i, + lemma_butterfly_fe_commute_minus vec_i vec_j zeta result_j hr_j⟩ + +/-! ## A.12 — inverse butterfly multiply-diff. -/ + +/-- A.12 `lemma_inv_butterfly_fe_commute_mul_diff` (F*: Chunk.fst:279). + The `j` output of the Gentleman–Sande inverse butterfly. -/ +@[scoped grind .] +theorem lemma_inv_butterfly_fe_commute_mul_diff + (vec_i vec_j zeta result_j : Std.I16) + (hr : (result_j.val : ZMod 3329) = + ((vec_j.val : ZMod 3329) - (vec_i.val : ZMod 3329)) * + (zeta.val : ZMod 3329) * 169) : + i16_to_spec_fe_mont result_j = + i16_to_spec_fe_mont zeta * + (i16_to_spec_fe_mont vec_j - i16_to_spec_fe_mont vec_i) := by + unfold i16_to_spec_fe_mont + rw [hr]; ring + +/-! ## A.16 / A.17 — base-case multiply commutes (even and odd halves). + + These are the upstream `Z3rlimit 400` lemmas (~80 LOC F* each + with explicit `lemma_mod_*_distr_*` chains). In Lean the same + statement falls to `rw [hr]; ring` in `ZMod 3329` because the + Montgomery-factor algebra is just commutative-ring distribution. -/ + +/-- A.16 `lemma_base_case_mult_even_fe_commute` (F*: Chunk.fst:414). -/ +@[scoped grind .] +theorem lemma_base_case_mult_even_fe_commute + (a0 b0 a1 b1 zeta result : Std.I16) + (hr : (result.val : ZMod 3329) = + ((a0.val : ZMod 3329) * (b0.val : ZMod 3329) + + (a1.val : ZMod 3329) * (b1.val : ZMod 3329) * + (zeta.val : ZMod 3329) * 169) * 169) : + i16_to_spec_fe_mont result = + i16_to_spec_fe_mont a0 * i16_to_spec_fe_mont b0 + + i16_to_spec_fe_mont a1 * i16_to_spec_fe_mont b1 * + i16_to_spec_fe_mont zeta := by + unfold i16_to_spec_fe_mont + rw [hr]; ring + +/-- A.17 `lemma_base_case_mult_odd_fe_commute` (F*: Chunk.fst:523). -/ +@[scoped grind .] +theorem lemma_base_case_mult_odd_fe_commute + (a0 b1 a1 b0 result : Std.I16) + (hr : (result.val : ZMod 3329) = + ((a0.val : ZMod 3329) * (b1.val : ZMod 3329) + + (a1.val : ZMod 3329) * (b0.val : ZMod 3329)) * 169) : + i16_to_spec_fe_mont result = + i16_to_spec_fe_mont a0 * i16_to_spec_fe_mont b1 + + i16_to_spec_fe_mont a1 * i16_to_spec_fe_mont b0 := by + unfold i16_to_spec_fe_mont + rw [hr]; ring + +/-! ## A.18 — combined base-case multiply pair (both halves). -/ + +/-- A.18 `lemma_base_case_mult_pair_commute` (F*: Chunk.fst:547). + Bundles A.16 / A.17 — one call per binomial pair. -/ +@[scoped grind .] +theorem lemma_base_case_mult_pair_commute + (a0 b0 a1 b1 zeta result_even result_odd : Std.I16) + (hr_e : (result_even.val : ZMod 3329) = + ((a0.val : ZMod 3329) * (b0.val : ZMod 3329) + + (a1.val : ZMod 3329) * (b1.val : ZMod 3329) * + (zeta.val : ZMod 3329) * 169) * 169) + (hr_o : (result_odd.val : ZMod 3329) = + ((a0.val : ZMod 3329) * (b1.val : ZMod 3329) + + (a1.val : ZMod 3329) * (b0.val : ZMod 3329)) * 169) : + i16_to_spec_fe_mont result_even = + i16_to_spec_fe_mont a0 * i16_to_spec_fe_mont b0 + + i16_to_spec_fe_mont a1 * i16_to_spec_fe_mont b1 * + i16_to_spec_fe_mont zeta + ∧ + i16_to_spec_fe_mont result_odd = + i16_to_spec_fe_mont a0 * i16_to_spec_fe_mont b1 + + i16_to_spec_fe_mont a1 * i16_to_spec_fe_mont b0 := by + exact ⟨lemma_base_case_mult_even_fe_commute a0 b0 a1 b1 zeta result_even hr_e, + lemma_base_case_mult_odd_fe_commute a0 b1 a1 b0 result_odd hr_o⟩ + +/-! ## A.19 / A.20 — Montgomery multiplication commutes. -/ + +/-- A.19 `lemma_mont_mul_fe_commute_mont_mont` (F*: Chunk.fst:615). + Two Mont-domain operands: the impl's `· R⁻¹` cancels the Mont + lift's extra factor, yielding plain FE multiplication. -/ +@[scoped grind .] +theorem lemma_mont_mul_fe_commute_mont_mont (a b r : Std.I16) + (hr : (r.val : ZMod 3329) = + (a.val : ZMod 3329) * (b.val : ZMod 3329) * 169) : + i16_to_spec_fe_mont a * i16_to_spec_fe_mont b = i16_to_spec_fe_mont r := by + unfold i16_to_spec_fe_mont + rw [hr]; ring + +/-- A.20 `lemma_mont_mul_fe_commute_mont_plain` (F*: Chunk.fst:624). + Mixed mode: `a` Mont, `b` plain, result plain. -/ +@[scoped grind .] +theorem lemma_mont_mul_fe_commute_mont_plain (a b r : Std.I16) + (hr : (r.val : ZMod 3329) = + (a.val : ZMod 3329) * (b.val : ZMod 3329) * 169) : + i16_to_spec_fe_mont a * i16_to_spec_fe_plain b = i16_to_spec_fe_plain r := by + unfold i16_to_spec_fe_mont i16_to_spec_fe_plain + rw [hr]; ring + +/-! ## A.21 — plain multiplication by a constant. -/ + +/-- A.21 `lemma_mul_const_fe_commute_plain` (F*: Chunk.fst:633). + Strict Int-product precondition (no overflow), plain-domain + lift on both sides. -/ +@[scoped grind .] +theorem lemma_mul_const_fe_commute_plain (a c r : Std.I16) + (hr : r.val = a.val * c.val) : + i16_to_spec_fe_plain a * i16_to_spec_fe_plain c = i16_to_spec_fe_plain r := by + unfold i16_to_spec_fe_plain + rw [hr]; push_cast; ring + +/-! ## A.22 — combined inverse-butterfly pair. -/ + +/-- A.22 `lemma_inv_butterfly_pair_commute` (F*: Chunk.fst:588). + Bundles the `add_mont_mod` (lane `i`) and `mul_diff` (lane `j`) + sides of one Gentleman–Sande inverse butterfly. -/ +@[scoped grind .] +theorem lemma_inv_butterfly_pair_commute + (vec_i vec_j result_i result_j zeta : Std.I16) + (hr_i : (result_i.val : ZMod 3329) = + (vec_j.val : ZMod 3329) + (vec_i.val : ZMod 3329)) + (hr_j : (result_j.val : ZMod 3329) = + ((vec_j.val : ZMod 3329) - (vec_i.val : ZMod 3329)) * + (zeta.val : ZMod 3329) * 169) : + i16_to_spec_fe_mont result_i = + i16_to_spec_fe_mont vec_i + i16_to_spec_fe_mont vec_j + ∧ + i16_to_spec_fe_mont result_j = + i16_to_spec_fe_mont zeta * + (i16_to_spec_fe_mont vec_j - i16_to_spec_fe_mont vec_i) := by + refine ⟨?_, ?_⟩ + · -- Reuse A.7 with operands swapped via `add_comm`; A.7's ensures is + -- `mont_fe a + mont_fe b = mont_fe r`, so the goal direction needs `.symm`. + have hr_i' : (result_i.val : ZMod 3329) = + (vec_i.val : ZMod 3329) + (vec_j.val : ZMod 3329) := by + rw [hr_i]; ring + exact (lemma_add_fe_commute_mont_mod vec_i vec_j result_i hr_i').symm + · exact lemma_inv_butterfly_fe_commute_mul_diff vec_i vec_j zeta result_j hr_j + +/-! ## Block B — chunk-level commutes. + + Port of `Hacspec_ml_kem.Commute.Chunk.fst` lines 671–950. Each + Block-B lemma takes the impl post as an explicit per-lane + hypothesis `hr : ∀ j : Fin 16, …` (in lieu of the F* `Operations` + trait `T.f_repr`/`T.f_*` machinery — see M.1 Spec.lean note I.2) + and lifts the corresponding Block-A scalar commute to the + `Vector.ofFn (n := 16) (fun j => …)` shape used by M.4's poly-level + aggregation. + + The shape is uniformly: + Vector.ofFn (lift ∘ getLane r) = Vector.ofFn (combine ∘ lift ∘ getLane lhs ∘ …) + closed by `Vector.ext` + `Vector.getElem_ofFn` + one Block-A apply. + + ### Why no `@[scoped grind]` on Block B + + The Block-A scalar lemmas accept `@[scoped grind]` because + `grind` can pattern on `i16_to_spec_fe_X _` directly in the + conclusion. The Block-B conclusions wrap the lifts inside a + `Vector.ofFn (n := 16) (fun j => i16_to_spec_fe_X ...)`, which puts + the only candidate pattern under a binder; `grind` rejects this + with "failed to find an usable pattern using different modifiers" + regardless of `=`/`←`/`→` modifier or `grind_pattern` (the binders + leave `lhs`/`rhs`/`r` un-instantiable). Block-B lemmas are therefore + consumed explicitly by Block-C / M.4 poly aggregation via + `exact`/`apply` rather than via `grind`. + + B.11–B.14 (compress/decompress chunks) deferred; see comment block + at end of file. +-/ + +/-! ### B.1 / B.2 — pointwise addition (plain and Mont). -/ + +/-- B.1 `lemma_add_chunk_commutes_plain` (F*: Chunk.fst:671). -/ +theorem lemma_add_chunk_commutes_plain + (lhs rhs r : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hr : ∀ j : Fin 16, + (r.elements.val[j.val]!).val = + (lhs.elements.val[j.val]!).val + (rhs.elements.val[j.val]!).val) : + Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_plain (r.elements.val[j.val]!)) + = Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_plain (lhs.elements.val[j.val]!) + + i16_to_spec_fe_plain (rhs.elements.val[j.val]!)) := by + apply Vector.ext + intro j hj + simp only [Vector.getElem_ofFn] + exact (lemma_add_fe_commute_plain _ _ _ (hr ⟨j, hj⟩)).symm + +/-- B.2 `lemma_add_chunk_commutes_mont` (F*: Chunk.fst:700). -/ +theorem lemma_add_chunk_commutes_mont + (lhs rhs r : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hr : ∀ j : Fin 16, + (r.elements.val[j.val]!).val = + (lhs.elements.val[j.val]!).val + (rhs.elements.val[j.val]!).val) : + Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_mont (r.elements.val[j.val]!)) + = Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_mont (lhs.elements.val[j.val]!) + + i16_to_spec_fe_mont (rhs.elements.val[j.val]!)) := by + apply Vector.ext + intro j hj + simp only [Vector.getElem_ofFn] + exact (lemma_add_fe_commute_mont _ _ _ (hr ⟨j, hj⟩)).symm + +/-! ### B.3 / B.4 — pointwise subtraction (plain and Mont). -/ + +/-- B.3 `lemma_sub_chunk_commutes_plain` (F*: Chunk.fst:729). -/ +theorem lemma_sub_chunk_commutes_plain + (lhs rhs r : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hr : ∀ j : Fin 16, + (r.elements.val[j.val]!).val = + (lhs.elements.val[j.val]!).val - (rhs.elements.val[j.val]!).val) : + Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_plain (r.elements.val[j.val]!)) + = Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_plain (lhs.elements.val[j.val]!) - + i16_to_spec_fe_plain (rhs.elements.val[j.val]!)) := by + apply Vector.ext + intro j hj + simp only [Vector.getElem_ofFn] + exact (lemma_sub_fe_commute_plain _ _ _ (hr ⟨j, hj⟩)).symm + +/-- B.4 `lemma_sub_chunk_commutes_mont` (F*: Chunk.fst:758). -/ +theorem lemma_sub_chunk_commutes_mont + (lhs rhs r : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hr : ∀ j : Fin 16, + (r.elements.val[j.val]!).val = + (lhs.elements.val[j.val]!).val - (rhs.elements.val[j.val]!).val) : + Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_mont (r.elements.val[j.val]!)) + = Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_mont (lhs.elements.val[j.val]!) - + i16_to_spec_fe_mont (rhs.elements.val[j.val]!)) := by + apply Vector.ext + intro j hj + simp only [Vector.getElem_ofFn] + exact (lemma_sub_fe_commute_mont _ _ _ (hr ⟨j, hj⟩)).symm + +/-! ### B.5 — multiply-by-constant (plain × plain). -/ + +/-- B.5 `lemma_multiply_by_constant_chunk_commutes` (F*: Chunk.fst:790). -/ +theorem lemma_multiply_by_constant_chunk_commutes + (vec r : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (c : Std.I16) + (hr : ∀ j : Fin 16, + (r.elements.val[j.val]!).val = + (vec.elements.val[j.val]!).val * c.val) : + Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_plain (r.elements.val[j.val]!)) + = Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_plain (vec.elements.val[j.val]!) * + i16_to_spec_fe_plain c) := by + apply Vector.ext + intro j hj + simp only [Vector.getElem_ofFn] + exact (lemma_mul_const_fe_commute_plain _ _ _ (hr ⟨j, hj⟩)).symm + +/-! ### B.6 / B.7 — Montgomery multiply-by-constant. -/ + +/-- B.6 `lemma_montgomery_multiply_by_constant_chunk_commutes_mont_mont` + (F*: Chunk.fst:818). Both operands lifted Mont; result lifted Mont. -/ +theorem lemma_montgomery_multiply_by_constant_chunk_commutes_mont_mont + (vec r : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (c : Std.I16) + (hr : ∀ j : Fin 16, + ((r.elements.val[j.val]!).val : ZMod 3329) = + ((vec.elements.val[j.val]!).val : ZMod 3329) * + (c.val : ZMod 3329) * 169) : + Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_mont (r.elements.val[j.val]!)) + = Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_mont (vec.elements.val[j.val]!) * + i16_to_spec_fe_mont c) := by + apply Vector.ext + intro j hj + simp only [Vector.getElem_ofFn] + exact (lemma_mont_mul_fe_commute_mont_mont _ _ _ (hr ⟨j, hj⟩)).symm + +/-- B.7 `lemma_montgomery_multiply_by_constant_chunk_commutes_mont_plain` + (F*: Chunk.fst:847). `vec` lifted Mont, `c` lifted plain, result + lifted plain. -/ +theorem lemma_montgomery_multiply_by_constant_chunk_commutes_mont_plain + (vec r : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (c : Std.I16) + (hr : ∀ j : Fin 16, + ((r.elements.val[j.val]!).val : ZMod 3329) = + ((vec.elements.val[j.val]!).val : ZMod 3329) * + (c.val : ZMod 3329) * 169) : + Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_plain (r.elements.val[j.val]!)) + = Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_mont (vec.elements.val[j.val]!) * + i16_to_spec_fe_plain c) := by + apply Vector.ext + intro j hj + simp only [Vector.getElem_ofFn] + exact (lemma_mont_mul_fe_commute_mont_plain _ _ _ (hr ⟨j, hj⟩)).symm + +/-! ### B.8 / B.9 / B.10 — identity-on-plain-lift ops. + + Barrett reduce, conditional `q`-subtract, and "to unsigned + representative" all preserve the residue class mod q. Their chunk + commutes have a simpler shape: both sides of the equation are the + same `Vector.ofFn (i16_to_spec_fe_plain ∘ getLane _)` modulo a + `(r.val : ZMod 3329) = (vec.val : ZMod 3329)` per-lane precond. -/ + +/-- B.8 `lemma_barrett_reduce_chunk_commutes` (F*: Chunk.fst:878). -/ +theorem lemma_barrett_reduce_chunk_commutes + (vec r : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hr : ∀ j : Fin 16, + ((r.elements.val[j.val]!).val : ZMod 3329) = + ((vec.elements.val[j.val]!).val : ZMod 3329)) : + Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_plain (r.elements.val[j.val]!)) + = Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_plain (vec.elements.val[j.val]!)) := by + apply Vector.ext + intro j hj + simp only [Vector.getElem_ofFn] + exact lemma_barrett_fe_commute _ _ (hr ⟨j, hj⟩) + +/-- B.9 `lemma_cond_subtract_3329_chunk_commutes` (F*: Chunk.fst:902). + Same shape as B.8 — the impl conditionally subtracts q, which is a + no-op mod q. -/ +theorem lemma_cond_subtract_3329_chunk_commutes + (vec r : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hr : ∀ j : Fin 16, + ((r.elements.val[j.val]!).val : ZMod 3329) = + ((vec.elements.val[j.val]!).val : ZMod 3329)) : + Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_plain (r.elements.val[j.val]!)) + = Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_plain (vec.elements.val[j.val]!)) := by + apply Vector.ext + intro j hj + simp only [Vector.getElem_ofFn] + exact lemma_barrett_fe_commute _ _ (hr ⟨j, hj⟩) + +/-- B.10 `lemma_to_unsigned_representative_chunk_commutes` + (F*: Chunk.fst:925). The impl projects to canonical `[0, q)` + representative — identity mod q. -/ +theorem lemma_to_unsigned_representative_chunk_commutes + (vec r : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hr : ∀ j : Fin 16, + ((r.elements.val[j.val]!).val : ZMod 3329) = + ((vec.elements.val[j.val]!).val : ZMod 3329)) : + Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_plain (r.elements.val[j.val]!)) + = Vector.ofFn (n := 16) (fun (j : Fin 16) => + i16_to_spec_fe_plain (vec.elements.val[j.val]!)) := by + apply Vector.ext + intro j hj + simp only [Vector.getElem_ofFn] + exact lemma_barrett_fe_commute _ _ (hr ⟨j, hj⟩) + +/-! ### B.11–B.14 deferred per arch plan §C.2 / Open Question I.4. + + The compress / decompress chunk commutes + (`lemma_compress_chunk_commutes`, `lemma_decompress_chunk_commutes`, + `lemma_compress_message_chunk_commutes`, + `lemma_decompress_message_chunk_commutes`) are blocked by Open + Question I.4: `HacspecMlKem.compress.compress_d` is + `Result`-monadic, and the lift design (pure-vs-Result return type, + `Vector (Fin (2^d)) 256` vs `Vector (ZMod 3329) 256` shape) is not + pinned down. M.1's `bit_compress` / `bit_decompress` are + placeholder stubs, so any chunk commute stated against them would + be vacuous. They land in a follow-up dispatch once I.4 is resolved. +-/ + +/-! ## Block C — poly-level commutes. + + Port of `Hacspec_ml_kem.Commute.Chunk.fst` lines 1376-2583. Each + Block-C lemma takes the impl post as an explicit per-lane + hypothesis `hr : ∀ i j : Fin 16, …` and conclusion is stated in + BIT-SIDE terms (`bit_` from M.1), not in `HP.` terms (those + are `Result`-monadic in the hacspec spec; a later pass will bridge + `bit_*` ↔ `HP.*`). + + ### `@[scoped grind]` policy (matches Block B). + + Block-C conclusions wrap the lifts inside `to_spec_poly_*` (which + is itself a `Vector.ofFn (n := 256)`); the only candidate pattern + is under a binder, which `grind` rejects. We therefore omit + `@[scoped grind]` and consume these lemmas via explicit + `exact`/`apply` from M.4 poly aggregation. + + The main hammer is `lemma_to_spec_poly_*_eq_of_coeffs` (M.1 spec). + Each Block-C statement reduces to "per-lane Block-A/B lemma gives + the same value on both sides". +-/ + +/-! ### C.1 — Barrett reduce is identity at the poly level. -/ + +/-- C.1 `lemma_poly_barrett_reduce_id` (F*: Chunk.fst:1376). Since + `bit_barrett_reduce p = p` definitionally in M.1, this is `rfl`. -/ +theorem lemma_poly_barrett_reduce_id (p : MontPoly) : + bit_barrett_reduce p = p := rfl + +/-! ### C.2 — Barrett reduce poly commute (per-lane residue ↦ plain + lift identity). -/ + +/-- C.2 `lemma_poly_barrett_reduce_commute` (F*: Chunk.fst:1401). The + per-lane residue equality lifts to the plain-domain poly equality + via `lemma_to_spec_poly_plain_eq_of_coeffs` + per-lane A.5 + (`lemma_barrett_fe_commute`). Combined with C.1 the conclusion can + equivalently be stated as + `to_spec_poly_plain result = to_spec_poly_plain myself`. -/ +theorem lemma_poly_barrett_reduce_commute + (myself result : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hr : ∀ i j : Fin 16, + ((result.coefficients.val[i.val]!).elements.val[j.val]!).val + = ((myself.coefficients.val[i.val]!).elements.val[j.val]!).val + ∨ + (((result.coefficients.val[i.val]!).elements.val[j.val]!).val + : ZMod 3329) + = (((myself.coefficients.val[i.val]!).elements.val[j.val]!).val + : ZMod 3329)) : + to_spec_poly_plain result + = bit_barrett_reduce (to_spec_poly_plain myself) := by + rw [lemma_poly_barrett_reduce_id] + apply lemma_to_spec_poly_plain_eq_of_coeffs + intro i j + rcases hr i j with h | h + · exact lemma_barrett_fe_commute _ _ (by rw [h]) + · exact lemma_barrett_fe_commute _ _ h + +/-! ### C.3 — pointwise addition at the poly level (plain domain). -/ + +/-- C.3 `lemma_add_to_ring_element_commute` (F*: Chunk.fst:1447). Per-lane + strict-add hypothesis lifts to the plain-domain poly equality + `to_spec_poly_plain result = bit_add (to_spec_poly_plain myself) + (to_spec_poly_plain rhs)` via `Vector.ext` + per-lane A.1. + + `maxRecDepth 2000` is required because the per-lane unifier + threads through three nested `Vector.ofFn` bodies (LHS + `to_spec_poly_plain` + two RHS `to_spec_poly_plain` inside + `bit_add`'s `Vector.ofFn`). -/ +theorem lemma_add_to_ring_element_commute + (myself rhs result : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hr : ∀ i j : Fin 16, + ((result.coefficients.val[i.val]!).elements.val[j.val]!).val + = ((myself.coefficients.val[i.val]!).elements.val[j.val]!).val + + ((rhs.coefficients.val[i.val]!).elements.val[j.val]!).val) : + to_spec_poly_plain result + = bit_add (to_spec_poly_plain myself) (to_spec_poly_plain rhs) := by + unfold to_spec_poly_plain bit_add + apply Vector.ext + intro k hk + -- After unfolding both sides we have nested `Vector.ofFn`s. The outer + -- `Vector.ofFn` of `bit_add` indexes at `[k]'hk`; reducing it via + -- `Vector.getElem_ofFn` substitutes `⟨k, hk⟩` into the body, which then + -- contains `(Vector.ofFn _)[⟨k, hk⟩]` for the two `to_spec_poly_plain` + -- arguments. Those Fin-indexed accesses are definitionally `[k]'hk` + -- accesses, so rewrite via `rfl` then fire the simp lemma again. + simp only [Vector.getElem_ofFn] + -- Introduce the Fin-form lemma `(Vector.ofFn f)[⟨k, hk⟩] = (Vector.ofFn f)[k]'hk` + -- as a local hypothesis via `rfl`, since the two forms are definitionally equal. + have my_eq : + (Vector.ofFn fun (j : Fin 256) => + i16_to_spec_fe_plain + ((myself.coefficients.val[j.val / 16]!).elements.val[j.val % 16]!))[(⟨k, hk⟩ : Fin 256)] + = (Vector.ofFn fun (j : Fin 256) => + i16_to_spec_fe_plain + ((myself.coefficients.val[j.val / 16]!).elements.val[j.val % 16]!))[k]'hk := rfl + have rhs_eq : + (Vector.ofFn fun (j : Fin 256) => + i16_to_spec_fe_plain + ((rhs.coefficients.val[j.val / 16]!).elements.val[j.val % 16]!))[(⟨k, hk⟩ : Fin 256)] + = (Vector.ofFn fun (j : Fin 256) => + i16_to_spec_fe_plain + ((rhs.coefficients.val[j.val / 16]!).elements.val[j.val % 16]!))[k]'hk := rfl + rw [my_eq, rhs_eq] + simp only [Vector.getElem_ofFn] + have hdiv_lt : k / 16 < 16 := by omega + have hmod_lt : k % 16 < 16 := Nat.mod_lt k (by decide) + exact (lemma_add_fe_commute_plain _ _ _ + (hr ⟨k / 16, hdiv_lt⟩ ⟨k % 16, hmod_lt⟩)).symm + +/-! ### C.4 — INTT-Mont finalize core (KEYSTONE). -/ + +/-- C.4 `lemma_intt_mont_form_post` (F*: Chunk.fst:1540). KEYSTONE. The + per-lane INTT-Mont finalize identity: given the INTT-Mont form + precondition `(b.val : ZMod 3329) * 2285 = b_real_val * 128` + (i.e., `b` represents `b_real_val * 128 * R⁻¹` post-INTT) and the + `mont_mul(b, 1441)` post `(r.val : ZMod 3329) = (b.val : ZMod 3329) + * 1441 * 169`, conclude `(r.val : ZMod 3329) = b_real_val`. + + Proof via three keystones (all `by decide`): + - `(1441 * 169 : ZMod 3329) = 512` + - `(2285 * 169 : ZMod 3329) = 1` + - `(128 * 169 * 512 : ZMod 3329) = 1` + plus `ring` glue. -/ +theorem lemma_intt_mont_form_post + (b r : Std.I16) (b_real_val : ZMod 3329) + (hb : (b.val : ZMod 3329) * 2285 = b_real_val * 128) + (hr : (r.val : ZMod 3329) = (b.val : ZMod 3329) * 1441 * 169) : + (r.val : ZMod 3329) = b_real_val := by + have k1 : (1441 * 169 : ZMod 3329) = 512 := by decide + have k2 : (2285 * 169 : ZMod 3329) = 1 := by decide + have k3 : (128 * 169 * 512 : ZMod 3329) = 1 := by decide + -- From hb: multiply both sides by 169. + -- (b.val * 2285) * 169 = (b_real_val * 128) * 169 + -- ⇒ b.val * (2285 * 169) = b_real_val * (128 * 169) + -- ⇒ b.val = b_real_val * 128 * 169 (since 2285·169=1) + have hb2 : (b.val : ZMod 3329) = b_real_val * 128 * 169 := by + have := congrArg (· * (169 : ZMod 3329)) hb + simp only at this + -- this : (b.val * 2285) * 169 = (b_real_val * 128) * 169 + have h1 : (b.val : ZMod 3329) * 2285 * 169 + = (b.val : ZMod 3329) * (2285 * 169) := by ring + rw [h1, k2, mul_one] at this + exact this + -- Now substitute into hr and reduce via k1 and k3. + rw [hr, hb2] + -- Goal: b_real_val * 128 * 169 * 1441 * 169 = b_real_val + have h2 : b_real_val * 128 * 169 * 1441 * 169 + = b_real_val * (128 * 169 * (1441 * 169)) := by ring + rw [h2, k1] + -- Goal: b_real_val * (128 * 169 * 512) = b_real_val + rw [k3, mul_one] + +/-! ### C.5 — Per-lane INTT-Mont finalize wrapper. -/ + +/-- C.5 `lemma_intt_mont_finalize_fe` (F*: Chunk.fst:1666). Per-lane + wrap of C.4: given the same hypotheses, the plain-domain lift + `i16_to_spec_fe_plain r` equals the `b_real_val`. -/ +theorem lemma_intt_mont_finalize_fe + (b r : Std.I16) (b_real_val : ZMod 3329) + (hb : (b.val : ZMod 3329) * 2285 = b_real_val * 128) + (hr : (r.val : ZMod 3329) = (b.val : ZMod 3329) * 1441 * 169) : + i16_to_spec_fe_plain r = b_real_val := by + unfold i16_to_spec_fe_plain + exact lemma_intt_mont_form_post b r b_real_val hb hr + +/-! ### C.7 — to_standard_domain finalize at the FE level. -/ + +/-- C.7 `lemma_to_standard_domain_finalize_fe` (F*: Chunk.fst:2019). + Given the standard-domain form `(myself.val : ZMod 3329) * 2285 + = plain_real_val` (i.e., `myself` represents `α · R⁻¹`) and the + `mont_mul(myself, 1353)` post `(r.val : ZMod 3329) = (myself.val + : ZMod 3329) * 1353 * 169`, conclude `i16_to_spec_fe_mont r + = plain_real_val * 2285` (the "Mont-lift of `r` recovers `α · R`"). + + Note: we state the conclusion via `i16_to_spec_fe_mont` (×169) on + the Mont domain lift. The keystone `(1353 * 169 : ZMod 3329) = 2285` + (R² · R⁻¹ = R) combined with the precondition gives the result. -/ +theorem lemma_to_standard_domain_finalize_fe + (myself r : Std.I16) (plain_real_val : ZMod 3329) + (hm : (myself.val : ZMod 3329) * 2285 = plain_real_val) + (hr : (r.val : ZMod 3329) + = (myself.val : ZMod 3329) * 1353 * 169) : + i16_to_spec_fe_plain r = plain_real_val := by + have k1 : (1353 * 169 : ZMod 3329) = 2285 := by decide + unfold i16_to_spec_fe_plain + rw [hr] + -- Goal: myself.val * 1353 * 169 = plain_real_val + have h1 : (myself.val : ZMod 3329) * 1353 * 169 + = (myself.val : ZMod 3329) * (1353 * 169) := by ring + rw [h1, k1, hm] + +/-! ### C.8 — Mont form post (standard-domain analogue of C.4). -/ + +/-- C.8 `lemma_mont_form_post` (F*: Chunk.fst:1943). Analogous to C.4 + but for the standard-domain (matrix-mul track) form. Given + `(myself.val : ZMod 3329) * 2285 = plain_real_val` (standard-domain + form) and `(r.val : ZMod 3329) = (myself.val : ZMod 3329) * 1353 + * 169` (mont_mul-by-1353), conclude `(r.val : ZMod 3329) + = plain_real_val`. + + Keystone: `(1353 * 169 : ZMod 3329) = 2285`. -/ +theorem lemma_mont_form_post + (myself r : Std.I16) (plain_real_val : ZMod 3329) + (hm : (myself.val : ZMod 3329) * 2285 = plain_real_val) + (hr : (r.val : ZMod 3329) + = (myself.val : ZMod 3329) * 1353 * 169) : + (r.val : ZMod 3329) = plain_real_val := by + have k1 : (1353 * 169 : ZMod 3329) = 2285 := by decide + rw [hr] + have h1 : (myself.val : ZMod 3329) * 1353 * 169 + = (myself.val : ZMod 3329) * (1353 * 169) := by ring + rw [h1, k1, hm] + +/-! ## Block C Tier 2 — poly-level lemmas with composite ops. + + These three lemmas extend Block C with the "composite" poly-level + commutes that thread through more than one impl op (subtract + + mont_mul-by-1441, add + mont_mul-by-1353 + barrett, and the + `createi`-equality bridge that pairs with C.6). + + ### Framing decision: post-keystone form on the impl precondition. + + For C.6, the impl chain is `result_lane = mont_mul(rhs - myself, 1441)`, + which in `ZMod 3329` collapses to `result_lane.val = (rhs.val - + myself.val) · 1441 · 169 = (rhs.val - myself.val) · 512` (via the + C.4 keystone `1441 · 169 = 512`). We state the per-lane impl + precondition in the **post-keystone form** `result.val = + (rhs.val - myself.val) · 512` so the conclusion against M.1's + `bit_subtract_reduce` (whose body is `(q[i] - p[i]) · 512` on + `MontPoly`) reduces by pure `ring`. Callers chain through C.4 first. + + For C.9, the impl chain is `result_lane = barrett (myself + error)`, + with the per-lane post being a `ZMod 3329` residue equality. We + state it directly as `result.val ≡ myself.val + error.val (mod q)` + and conclude `to_spec_poly_plain result = bit_add_to_ring_element + (to_spec_poly_plain myself) (to_spec_poly_plain error)`. + + For C.10, this is the `Vector.ofFn`-equality bridge — given two + impl polys with equal coefficients, the scaled-by-1441 createis + coincide. In Lean this is essentially `congrArg` on top of + `lemma_to_spec_poly_mont_eq_of_coeffs`. -/ + +/-! ### C.6 — subtract-reduce poly commute (post-keystone form). -/ + +/-- C.6 `lemma_subtract_reduce_commute` (F*: Chunk.fst:1852). Poly-level + commute for the subtract-then-finalize-INTT chain. + + Per-lane impl precondition is stated in **post-keystone form**: the + impl's `mont_mul(rhs - myself, 1441)` already collapses to + `(rhs.val - myself.val) · 512` in `ZMod 3329` via the C.4 keystone + `(1441 · 169 : ZMod 3329) = 512`. Callers apply C.4 once per + lane and feed the post-keystone equality here, avoiding redoing the + 1441-keystone chain inside this lemma. + + Conclusion is stated against M.1's `bit_subtract_reduce` (which is + itself `(q[i] - p[i]) · 512` on `MontPoly`) lifted through + `to_spec_poly_mont` on both sides. -/ +theorem lemma_subtract_reduce_commute + (myself rhs result : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hr : ∀ i j : Fin 16, + ((result.coefficients.val[i.val]!).elements.val[j.val]!).val + = (((rhs.coefficients.val[i.val]!).elements.val[j.val]!).val + - ((myself.coefficients.val[i.val]!).elements.val[j.val]!).val) + * 512) : + to_spec_poly_mont result + = bit_subtract_reduce (to_spec_poly_mont myself) (to_spec_poly_mont rhs) := by + unfold to_spec_poly_mont bit_subtract_reduce + apply Vector.ext + intro k hk + -- Strip the outer `Vector.ofFn` on both sides; LHS body is + -- `i16_to_spec_fe_mont (result.coef[k/16].elt[k%16])` + -- and RHS body is + -- `((to_spec_poly_mont rhs)[k] - (to_spec_poly_mont myself)[k]) * 512`. + simp only [Vector.getElem_ofFn] + -- The two `(Vector.ofFn _)[⟨k, hk⟩]` accesses inside the RHS body are + -- definitionally `[k]'hk` accesses; identify the two `Fin`-form + -- accesses to the plain `[k]'hk` form (same trick as C.3). + have my_eq : + (Vector.ofFn fun (j : Fin 256) => + i16_to_spec_fe_mont + ((myself.coefficients.val[j.val / 16]!).elements.val[j.val % 16]!))[(⟨k, hk⟩ : Fin 256)] + = (Vector.ofFn fun (j : Fin 256) => + i16_to_spec_fe_mont + ((myself.coefficients.val[j.val / 16]!).elements.val[j.val % 16]!))[k]'hk := rfl + have rhs_eq : + (Vector.ofFn fun (j : Fin 256) => + i16_to_spec_fe_mont + ((rhs.coefficients.val[j.val / 16]!).elements.val[j.val % 16]!))[(⟨k, hk⟩ : Fin 256)] + = (Vector.ofFn fun (j : Fin 256) => + i16_to_spec_fe_mont + ((rhs.coefficients.val[j.val / 16]!).elements.val[j.val % 16]!))[k]'hk := rfl + rw [my_eq, rhs_eq] + simp only [Vector.getElem_ofFn] + -- Goal: i16_to_spec_fe_mont result.coef[k/16].elt[k%16] + -- = (i16_to_spec_fe_mont rhs.coef[k/16].elt[k%16] + -- - i16_to_spec_fe_mont myself.coef[k/16].elt[k%16]) * 512 + -- Unfold the Mont lift to expose the `· 169` factor on each lane, + -- then substitute `hr` on the LHS and close with `ring`. + unfold i16_to_spec_fe_mont + have hdiv_lt : k / 16 < 16 := by omega + have hmod_lt : k % 16 < 16 := Nat.mod_lt k (by decide) + have h := hr ⟨k / 16, hdiv_lt⟩ ⟨k % 16, hmod_lt⟩ + -- `h : result.val = (rhs.val - myself.val) * 512` as `Int`. + -- Cast to `ZMod 3329` and combine with `ring`. + have hz : + ((((result.coefficients.val[k / 16]!).elements.val[k % 16]!).val : ZMod 3329)) + = ((((rhs.coefficients.val[k / 16]!).elements.val[k % 16]!).val : ZMod 3329) + - (((myself.coefficients.val[k / 16]!).elements.val[k % 16]!).val : ZMod 3329)) + * 512 := by + have := congrArg (Int.cast (R := ZMod 3329)) h + push_cast at this + exact this + rw [hz]; ring + +/-! ### C.9 — add-standard-error-reduce poly commute. -/ + +/-- C.9 `lemma_add_standard_error_reduce_commute` (F*: Chunk.fst:2135). + Poly-level commute for the `add + barrett` chain. Per-lane impl + precondition is the residue equality + `result.val ≡ myself.val + error.val (mod q)` (the impl's barrett + reduction collapses to identity at the `ZMod` level — A.7's shape). + + Conclusion: `to_spec_poly_plain result = bit_add_to_ring_element + (to_spec_poly_plain myself) (to_spec_poly_plain error)`, where + `bit_add_to_ring_element = bit_add` is the M.1 pointwise add on + `MontPoly`. -/ +theorem lemma_add_standard_error_reduce_commute + (myself error result : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hr : ∀ i j : Fin 16, + (((result.coefficients.val[i.val]!).elements.val[j.val]!).val + : ZMod 3329) + = (((myself.coefficients.val[i.val]!).elements.val[j.val]!).val + : ZMod 3329) + + (((error.coefficients.val[i.val]!).elements.val[j.val]!).val + : ZMod 3329)) : + to_spec_poly_plain result + = bit_add_to_ring_element + (to_spec_poly_plain myself) (to_spec_poly_plain error) := by + unfold bit_add_to_ring_element bit_add to_spec_poly_plain + apply Vector.ext + intro k hk + simp only [Vector.getElem_ofFn] + have my_eq : + (Vector.ofFn fun (j : Fin 256) => + i16_to_spec_fe_plain + ((myself.coefficients.val[j.val / 16]!).elements.val[j.val % 16]!))[(⟨k, hk⟩ : Fin 256)] + = (Vector.ofFn fun (j : Fin 256) => + i16_to_spec_fe_plain + ((myself.coefficients.val[j.val / 16]!).elements.val[j.val % 16]!))[k]'hk := rfl + have err_eq : + (Vector.ofFn fun (j : Fin 256) => + i16_to_spec_fe_plain + ((error.coefficients.val[j.val / 16]!).elements.val[j.val % 16]!))[(⟨k, hk⟩ : Fin 256)] + = (Vector.ofFn fun (j : Fin 256) => + i16_to_spec_fe_plain + ((error.coefficients.val[j.val / 16]!).elements.val[j.val % 16]!))[k]'hk := rfl + rw [my_eq, err_eq] + simp only [Vector.getElem_ofFn] + -- Goal: i16_to_spec_fe_plain result.coef[k/16].elt[k%16] + -- = i16_to_spec_fe_plain myself.coef[k/16].elt[k%16] + -- + i16_to_spec_fe_plain error.coef[k/16].elt[k%16] + unfold i16_to_spec_fe_plain + have hdiv_lt : k / 16 < 16 := by omega + have hmod_lt : k % 16 < 16 := Nat.mod_lt k (by decide) + exact hr ⟨k / 16, hdiv_lt⟩ ⟨k % 16, hmod_lt⟩ + +/-! ### C.10 — `Vector.ofFn`-equality bridge for the C.6 conclusion. -/ + +/-- C.10 `lemma_subtract_reduce_scaled_eq` (F*: Chunk.fst:2533). The F* + version exists to paper over Z3 not auto-deriving equality of two + `createi`s from equality of their per-lane bodies. In Lean, + `Vector.ofFn` already enjoys congruence under `funext`, so once we + know the inner `to_spec_poly_mont` lifts coincide (by + `lemma_to_spec_poly_mont_eq_of_coeffs`), the outer scaled-by-1441 + `Vector.ofFn`s coincide by `congrArg`. + + Stated against the M.1 idiom (pointwise multiply via `*` in + `ZMod 3329`, not the F* `impl_FieldElement__mul`). -/ +theorem lemma_subtract_reduce_scaled_eq + (p q : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h : ∀ i j : Fin 16, + i16_to_spec_fe_mont + ((p.coefficients.val[i.val]!).elements.val[j.val]!) + = i16_to_spec_fe_mont + ((q.coefficients.val[i.val]!).elements.val[j.val]!)) : + (Vector.ofFn (n := 256) fun (j : Fin 256) => + (to_spec_poly_mont p)[j.val]'j.isLt * (1441 : ZMod 3329)) + = (Vector.ofFn (n := 256) fun (j : Fin 256) => + (to_spec_poly_mont q)[j.val]'j.isLt * (1441 : ZMod 3329)) := by + have hpq : to_spec_poly_mont p = to_spec_poly_mont q := + lemma_to_spec_poly_mont_eq_of_coeffs p q h + rw [hpq] + +end libcrux_iot_ml_kem.Spec.Commute \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/Lift.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/Lift.lean new file mode 100644 index 00000000..eaa4df4a --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/Lift.lean @@ -0,0 +1,1233 @@ +/- + # `Spec/Lift.lean` — extracted from `FCTargets.lean` §lift. +-/ +import LibcruxIotMlKem.Spec +import LibcruxIotMlKem.Spec.Pure +import LibcruxIotMlKem.Spec.AlgEquiv +import LibcruxIotMlKem.Spec.ModularArith +import LibcruxIotMlKem.Extraction.Funs +import HacspecMlKem.Extraction.Funs + +set_option mvcgen.warning false +set_option linter.unusedVariables false +set_option linter.unusedSectionVars false + +namespace libcrux_iot_ml_kem.Spec.Lift +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec + +/-! ## §0 Lift tower + + Each `lift_*` projects an impl-side carrier to the corresponding + hacspec carrier. Type signatures are load-bearing — they are what + the FC equation reads on both sides. Bodies use existing M.1 + pieces (`i16_to_spec_fe_mont`, `feOfZMod`, `to_spec_poly_mont`) + where convenient. -/ + +/-- Default `FieldElement` used by `[i]!` projections inside the + lift bodies below. The canonical residue 0 mod q. -/ +noncomputable def defaultFE : + hacspec_ml_kem.parameters.FieldElement := + feOfZMod (0 : ZMod 3329) + +private noncomputable instance : Inhabited hacspec_ml_kem.parameters.FieldElement := + ⟨defaultFE⟩ + +/-- Local `Inhabited` instance for `PortableVector` used by `[i]!` + indexing in `lift_chunk` / `lift_poly`. Mirrors the `local instance` + in `Spec.lean` (which is file-scoped). -/ +private instance instInhabitedPortableVector_fcTargets : + Inhabited libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + ⟨{ elements := Std.Array.make 16#usize (List.replicate 16 (0#i16 : Std.I16)) + (by simp) }⟩ + +/-- Local `Inhabited` instance for `PolynomialRingElement PortableVector` + used by `[i]!` indexing in `lift_poly` / `lift_vec_slice`. -/ +private instance instInhabitedPolynomialRingElement_fcTargets + {Vector : Type} [Inhabited Vector] : + Inhabited (libcrux_iot_ml_kem.polynomial.PolynomialRingElement Vector) := + ⟨{ coefficients := + Std.Array.make 16#usize (List.replicate 16 default) (by simp) }⟩ + +/-- Plain-domain lane lift from `Int` to a hacspec `FieldElement`. + Used by `barrett_reduce_element_fc` (the impl carries the value + in plain domain). -/ +noncomputable def lift_fe_int (x : Int) : hacspec_ml_kem.parameters.FieldElement := + feOfZMod (x : ZMod 3329) + +/-- Plain-domain lane lift from `Std.I16` to a hacspec `FieldElement`. + Composes `i16_to_spec_fe_plain` with `feOfZMod`. -/ +noncomputable def lift_fe (lane : Std.I16) : hacspec_ml_kem.parameters.FieldElement := + feOfZMod (i16_to_spec_fe_plain lane) + +/-- Mont-domain lane lift from `Std.I16` to a hacspec `FieldElement`. + Used for outputs of impl ops that produce Mont-form lanes + (`montgomery_multiply_*`, `montgomery_reduce_element`). -/ +noncomputable def lift_fe_mont (lane : Std.I16) : hacspec_ml_kem.parameters.FieldElement := + feOfZMod (i16_to_spec_fe_mont lane) + +/-- Plain-domain poly lift `PortableVector chunk → 16 FE-array`. + Maps each of the 16 lanes through `lift_fe`. -/ +noncomputable def lift_chunk + (chunk : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize (chunk.elements.val.map lift_fe) (by + simp []) + +/-- Mont-domain poly lift `PortableVector chunk → 16 FE-array`. + Maps each of the 16 lanes through `lift_fe_mont`. -/ +noncomputable def lift_chunk_mont + (chunk : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize (chunk.elements.val.map lift_fe_mont) (by + simp []) + +/-- Plain-domain poly lift: `PolynomialRingElement PortableVector → + Array FE 256`. The result is the hacspec "ring element" type. + Flattens 16 chunks × 16 lanes via the standard + `i = j / 16`, `k = j % 16` decomposition. -/ +noncomputable def lift_poly + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Std.Array.make 256#usize + ((List.range 256).map (fun j => + lift_fe (re.coefficients.val[j / 16]!).elements.val[j % 16]!)) + (by simp) + +/-- Mont-domain poly lift. Same shape as `lift_poly` but strips one + `R` factor per lane via `i16_to_spec_fe_mont`. -/ +noncomputable def lift_poly_mont + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Std.Array.make 256#usize + ((List.range 256).map (fun j => + lift_fe_mont (re.coefficients.val[j / 16]!).elements.val[j % 16]!)) + (by simp) + +/-- Vector lift: `Array (PolynomialRingElement) K → Array (Array FE 256) K`. -/ +noncomputable def lift_vec {K : Std.Usize} + (v : Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) : + Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K := + Std.Array.make K (v.val.map lift_poly) (by + simp []) + +/-- Vector-slice variant for `Slice`-typed impl args + (e.g. `compute_ring_element_v` takes `r_as_ntt : Slice ...`). + The FC theorems that consume this expect `v.length = K.val` as a + precondition; out-of-range indices default to the unit chunk. -/ +noncomputable def lift_vec_slice + (v : Slice + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (K : Std.Usize) : + Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K := + Std.Array.make K + ((List.range K.val).map (fun i => lift_poly v.val[i]!)) + (by simp) + +/-- Plain-domain lift from a 256-lane `Std.I32` accumulator to a + `FieldElement` poly. Each lane goes through `lift_fe_int` on its + `.val` (Int). Used by the L6c NTT-multiply family FC equations to + relate the impl-side I32 accumulator to a `FieldElement 256`-array. + Matches the `Spec.poly_reducing_from_i32_array_pure` lane shape — composes cleanly with L6.7. -/ +noncomputable def lift_accumulator_i32 + (acc : Std.Array Std.I32 256#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Std.Array.make 256#usize + ((List.range 256).map (fun i => lift_fe_int (acc.val[i]!).val)) + (by simp) + +/-- Matrix lift: `Array (Array (PolynomialRingElement) K) K → Array (Array (Array FE 256) K) K`. -/ +noncomputable def lift_matrix {K : Std.Usize} + (m : Std.Array + (Std.Array + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) K) K) : + Std.Array + (Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K) K := + Std.Array.make K (m.val.map lift_vec) (by + simp []) + +/-- Pure projection of `matrix.sample_matrix_A` from the public-key seed. + Forward-declared here (rather than in §0.5 below) so + `lift_matrix_from_seed` can reference it. + + Pending pure-projection side lemma: + `hacspec_ml_kem.matrix.sample_matrix_A seed K + = .ok (Spec.sample_matrix_A_pure seed K)`. -/ +noncomputable opaque Spec.sample_matrix_A_pure + (seed : Slice Std.U8) (K : Std.Usize) : + Std.Array + (Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K) K + +/-- Matrix-from-seed lift: the impl `matrix.compute_vector_u` reconstructs + the matrix in-place via `sample_matrix_entry`; the hacspec spec calls + `matrix.sample_matrix_A` on the seed once at the top. Defers to + `Spec.sample_matrix_A_pure` above for the deterministic projection. -/ +noncomputable def lift_matrix_from_seed + (seed : Slice Std.U8) (K : Std.Usize) : + Std.Array + (Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K) K := + Spec.sample_matrix_A_pure seed K + +/-- Matrix-from-flat-slice lift: the impl `matrix.compute_As_plus_e` takes + `matrix_A : Slice (PolynomialRingElement)` as a flat K·K slice in + row-major order (impl convention: `matrix_A[i*K+j]` is the + (row `i`, column `j`) entry). We reshape it into a 2D K×K matrix using + FIPS 203's column-major convention — "a matrix is a set of column + vectors" (`specs/ml-kem/src/matrix.rs:8-9`) — so the outer index is + the column and the inner index is the row: + `(lift_matrix_from_slice slice K).val[j]!.val[i]! + = lift_poly slice.val[i * K.val + j]!`. + This matches how hacspec's `multiply_matrix_by_column_at` accesses + `m[j][i]` (column-major). Used by L7.1's locked POST. Requires the + caller's `matrix_A.length = K.val * K.val` precondition for the + indexing to be in-range (out-of-range indices default to the unit poly + via the `Inhabited` instance). -/ +noncomputable def lift_matrix_from_slice + (slice : Slice + (libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + (K : Std.Usize) : + Std.Array + (Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K) K := + Std.Array.make K + ((List.range K.val).map (fun j => + Std.Array.make K + ((List.range K.val).map (fun i => + lift_poly slice.val[i * K.val + j]!)) + (by simp))) + (by simp) + +/-- Pure projection of the public-key deserialization producing + `t_as_ntt : Array (Array FE 256) K`. The impl `matrix.compute_ring_element_v` + consumes `public_key : Slice U8` via `chunks_exact public_key + BYTES_PER_RING_ELEMENT`, deserializing each chunk into a ring element. + Used by L7.3's locked post. Declared `opaque` here; the explicit + deserialization spec is a pending obligation paralleling + `Spec.sample_matrix_A_pure`. -/ +noncomputable opaque Spec.t_as_ntt_from_public_key_pure + (public_key : Slice Std.U8) (K : Std.Usize) : + Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K + +/-- Public-key-bytes lift wrapping `Spec.t_as_ntt_from_public_key_pure`. + The impl `matrix.compute_ring_element_v` deserializes `public_key` into + a vector of ring elements; the hacspec spec receives this vector + pre-deserialized as its first argument. -/ +noncomputable def lift_t_as_ntt_from_public_key + (public_key : Slice Std.U8) (K : Std.Usize) : + Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K := + Spec.t_as_ntt_from_public_key_pure public_key K + +/-! ## §0.5 Spec `_pure` aliases needed beyond `Spec.Pure.lean`. + + `Spec.Pure.lean` already provides: + - `FieldElement.{add,sub,mul,neg}_pure` + - `polynomial.{add_to_ring_element,poly_barrett_reduce,subtract_reduce}_pure` + + We add here the missing `_pure` aliases referenced by FC equations + below. Each is the `Result`-stripped pure projection of a + `Result`-monadic hacspec op; bodies use the standard + `match | .ok r => r | _ => default` + pattern (see `Spec.Pure.lean`). Bodies left `sorry` here for brevity + — types are load-bearing. -/ + +/-- Pure projection of `parameters.FieldElement.new (x.val % q)` — + the canonical-residue constructor, here re-expressed as the round-trip + `feOfZMod ∘ zmodOfFE`. The two forms are equivalent: both produce + `{ val := ⟨BitVec.ofNat 16 (x.val.val % 3329)⟩ }` since `zmodOfFE x` + is `(x.val.val : ZMod 3329)` (whose underlying Nat is `x.val.val % 3329`) + and `parameters.FieldElement.new` always returns `.ok ⟨_⟩` unconditionally. + The round-trip form composes with the existing `zmodOfFE_feOfZMod` + identity in M.1, making the FC equation reduce to "lift_fe r = lift_fe value + given r ≡ value mod q". -/ +noncomputable def Spec.barrett_pure (x : hacspec_ml_kem.parameters.FieldElement) : + hacspec_ml_kem.parameters.FieldElement := + feOfZMod (zmodOfFE x) + +/-- Pure projection of Montgomery reduction at the FE level. The impl + `montgomery_reduce_element` takes an `Std.I32` and returns an `Std.I16` + in Mont domain (encoding `a · R`). The hacspec spec has no direct + counterpart at the FE level. The FC equation + `lift_fe_mont r = Spec.mont_reduce_pure (lift_fe_int value.val)` + combines two factors of R⁻¹: + (i) the impl's invariant `r ≡ value · R⁻¹ (mod q)`, and + (ii) `lift_fe_mont`'s own R-stripping (it returns `(r.val : ZMod 3329) · 169`). + The TOTAL effect is `value.val · R⁻² mod q`. Since `R⁻¹ = 169 mod q`, + `R⁻² = 169² mod q`. So `Spec.mont_reduce_pure` multiplies its + ZMod-projected argument by `169 · 169`. -/ +noncomputable def Spec.mont_reduce_pure (x : hacspec_ml_kem.parameters.FieldElement) : + hacspec_ml_kem.parameters.FieldElement := + feOfZMod (zmodOfFE x * 169 * 169) + +/-- Pure projection of Montgomery `fe · fer / R`: given two FEs, returns + `fe · fer · R⁻¹` in canonical domain (i.e., `zmodOfFE fe · zmodOfFE fer · 169` + in ZMod 3329). The factor `169 = R⁻¹ mod q` comes from the impl's + Montgomery reduction step (the L0.3 calculation gave 169² because the + INPUT was plain-via-`lift_fe_int`, whereas here `fer` is already + interpreted in Mont domain through `lift_fe_mont`, so only ONE R⁻¹ + factor is needed). The math intent of the impl: given fe (plain, math + value = fe) and fer (Mont, math value = fer · R⁻¹), output Mont-encoded + fe · (fer · R⁻¹) = fe · fer · R⁻¹ in Mont. The Mont encoding is then + stripped by `lift_fe_mont`, giving the canonical math value + fe · fer · R⁻¹. -/ +noncomputable def Spec.montgomery_multiply_fe_by_fer_pure + (fe fer : hacspec_ml_kem.parameters.FieldElement) : + hacspec_ml_kem.parameters.FieldElement := + feOfZMod (zmodOfFE fe * zmodOfFE fer * 169) + +/-- Pure projection of `get_n_least_significant_bits` — pure modular + truncation on `Std.U32`. -/ +def Spec.get_n_least_significant_bits_pure (n : Std.U8) (value : Std.U32) : Std.U32 := + ⟨value.bv &&& ((1#32 <<< n.val) - 1#32)⟩ + +/-- Pure pointwise add at the FE-array level (16-lane chunk). + Lifts `FieldElement.add_pure` across the 16 lanes via `List.range 16`. -/ +noncomputable def Spec.chunk_add_pure + (a b : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize + ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (a.val[i]!) (b.val[i]!))) + (by simp) + +/-- Pure pointwise sub at the FE-array level (16-lane chunk). + Lifts `FieldElement.sub_pure` across the 16 lanes. -/ +noncomputable def Spec.chunk_sub_pure + (a b : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize + ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (a.val[i]!) (b.val[i]!))) + (by simp) + +/-- Pure pointwise neg at the FE-array level (16-lane chunk). + Lifts `FieldElement.neg_pure` across the 16 lanes. -/ +noncomputable def Spec.chunk_neg_pure + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize + ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure + (a.val[i]!))) + (by simp) + +/-- Pure pointwise barrett-reduce at the FE-array level. + Lifts `Spec.barrett_pure` (the canonical round-trip) across 16 lanes. -/ +noncomputable def Spec.chunk_barrett_reduce_pure + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize + ((List.range 16).map (fun i => + Spec.barrett_pure (a.val[i]!))) + (by simp) + +/-- Pure pointwise `montgomery_multiply_by_constant` at the chunk level + (each lane: `fe · c / R`). Lifts `Spec.montgomery_multiply_fe_by_fer_pure` + across 16 lanes, with the second arg threaded as the constant `c`. -/ +noncomputable def Spec.chunk_montgomery_multiply_by_constant_pure + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (c : hacspec_ml_kem.parameters.FieldElement) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize + ((List.range 16).map (fun i => + Spec.montgomery_multiply_fe_by_fer_pure (a.val[i]!) c)) + (by simp) + +/-- Pure pointwise plain `multiply_by_constant` at the chunk level. + Lifts `FieldElement.mul_pure` across 16 lanes with the constant `c`. -/ +noncomputable def Spec.chunk_multiply_by_constant_pure + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (c : hacspec_ml_kem.parameters.FieldElement) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize + ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (a.val[i]!) c)) + (by simp) + +/-- Pure pointwise `bitwise_and_with_constant` at the chunk level. + NO HACSPEC EQUIVALENT — this is a bit-level mask used only in + serialize/compress paths. The body applies BV-and on each FE's + underlying `U16` BV. + + WARNING (FC obstruction): the FC equation for `bitwise_and_with_constant_fc` + against `lift_chunk`-style inputs is NOT provable in general because + `lift_chunk` discards the bit pattern (keeping only mod-3329 residue), + while bit-level AND depends on the raw `I16` bit pattern. The body here + is the canonical FE-side BV operation; the FC proof will STOP and report + when attempted. Not on the L7 critical path (used only in compress/ + serialize, which lives outside the 4 matrix-level targets). -/ +noncomputable def Spec.chunk_bitwise_and_with_constant_pure + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (c : hacspec_ml_kem.parameters.FieldElement) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize + ((List.range 16).map (fun i => + let ai_bv : BitVec Aeneas.Std.UScalarTy.U16.numBits := (a.val[i]!).val.bv + let c_bv : BitVec Aeneas.Std.UScalarTy.U16.numBits := c.val.bv + ({ val := { bv := ai_bv &&& c_bv } } : hacspec_ml_kem.parameters.FieldElement))) + (by simp) + +/-- Pure pointwise `shift_right` at the chunk level. + NO HACSPEC EQUIVALENT at the FE level. The body applies a logical + right shift on each FE's underlying `U16` BV by `SHIFT_BY.val.toNat`. + + WARNING (FC obstruction): same as `chunk_bitwise_and_with_constant_pure` + — the FC equation is not provable through `lift_chunk` because the + underlying `I16` sshiftRight depends on raw bit pattern. The body here + serves as a placeholder; the FC proof will STOP and report. Not on + the L7 critical path. -/ +noncomputable def Spec.chunk_shift_right_pure + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (SHIFT_BY : Std.I32) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize + ((List.range 16).map (fun i => + let ai_bv : BitVec Aeneas.Std.UScalarTy.U16.numBits := (a.val[i]!).val.bv + let shift : Nat := SHIFT_BY.val.toNat + ({ val := { bv := ai_bv >>> shift } } : hacspec_ml_kem.parameters.FieldElement))) + (by simp) + +/-- Pure `reducing_from_i32_array` at the chunk level. Lifts `Spec.mont_reduce_pure` + over 16 lanes of the input `i32` slice. Each lane: take `array[i]`, + project through `lift_fe_int`, apply Montgomery reduction. -/ +noncomputable def Spec.chunk_reducing_from_i32_array_pure + (array : Slice Std.I32) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize + ((List.range 16).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (array.val[i]!).val))) + (by simp) + +/-! ### §M.1 — Per-lane unfolds for `Spec.chunk_*_pure`. + + Direct lane projections for the chunk-level pointwise operations. + Save ~30-50 LOC per proof that needs to extract a specific lane + from a chunk-pure result (e.g. L6.3 step lemma, L7 row composition, + L6.3c cache-variant wrap). Each lemma collapses the + `Std.Array.make 16#usize ((List.range 16).map ...)` + `[k]!` + `List.getElem_map` + + `List.getElem_range` cascade into a single rewrite. -/ + +/-- Lane projection of `Spec.chunk_add_pure`. -/ +theorem Spec.chunk_add_pure_lane_eq + (a b : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (k : Nat) (hk : k < 16) : + (Spec.chunk_add_pure a b).val[k]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (a.val[k]!) (b.val[k]!) := by + unfold Spec.chunk_add_pure + show ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (a.val[i]!) (b.val[i]!)))[k]! = _ + have h_l : ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (a.val[i]!) (b.val[i]!))).length = 16 := by simp + rw [getElem!_pos _ k (by rw [h_l]; exact hk)] + rw [List.getElem_map, List.getElem_range] + +/-- Lane projection of `Spec.chunk_sub_pure`. -/ +theorem Spec.chunk_sub_pure_lane_eq + (a b : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (k : Nat) (hk : k < 16) : + (Spec.chunk_sub_pure a b).val[k]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (a.val[k]!) (b.val[k]!) := by + unfold Spec.chunk_sub_pure + show ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (a.val[i]!) (b.val[i]!)))[k]! = _ + have h_l : ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (a.val[i]!) (b.val[i]!))).length = 16 := by simp + rw [getElem!_pos _ k (by rw [h_l]; exact hk)] + rw [List.getElem_map, List.getElem_range] + +/-- Lane projection of `Spec.chunk_neg_pure`. -/ +theorem Spec.chunk_neg_pure_lane_eq + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (k : Nat) (hk : k < 16) : + (Spec.chunk_neg_pure a).val[k]! + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure + (a.val[k]!) := by + unfold Spec.chunk_neg_pure + show ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure + (a.val[i]!)))[k]! = _ + have h_l : ((List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure + (a.val[i]!))).length = 16 := by simp + rw [getElem!_pos _ k (by rw [h_l]; exact hk)] + rw [List.getElem_map, List.getElem_range] + +/-- Lane projection of `Spec.chunk_reducing_from_i32_array_pure`. -/ +theorem Spec.chunk_reducing_from_i32_array_pure_lane_eq + (array : Slice Std.I32) (k : Nat) (hk : k < 16) : + (Spec.chunk_reducing_from_i32_array_pure array).val[k]! + = Spec.mont_reduce_pure (lift_fe_int (array.val[k]!).val) := by + unfold Spec.chunk_reducing_from_i32_array_pure + show ((List.range 16).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (array.val[i]!).val)))[k]! = _ + have h_l : ((List.range 16).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (array.val[i]!).val))).length = 16 := by simp + rw [getElem!_pos _ k (by rw [h_l]; exact hk)] + rw [List.getElem_map, List.getElem_range] + +/-- Pure NTT butterfly step at the chunk level: applies `ntt.butterfly` + pointwise to the lane pair `(i, j)` with `zeta`. Mirrors the impl's + write order (`a[j] := a-t`, then `a[i] := a+t`) so that when `i = j` + the second write wins (matching impl semantics). When `i ≠ j` the + `(i, j)` lanes become `(add_pure a[i] (mul_pure a[j] zeta), + sub_pure a[i] (mul_pure a[j] zeta))` respectively. -/ +noncomputable def Spec.chunk_ntt_step_pure + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (zeta : hacspec_ml_kem.parameters.FieldElement) (i j : Std.Usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + let t_fe := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure (a.val[j.val]!) zeta + let a_minus_t := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure (a.val[i.val]!) t_fe + let a_plus_t := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure (a.val[i.val]!) t_fe + (a.set j a_minus_t).set i a_plus_t + +/-- Pure NTT-layer-1 step at the chunk level. Mirrors the impl's + 8 sequential `ntt_step` calls at pairs (0,2)(1,3)(4,6)(5,7) + (8,10)(9,11)(12,14)(13,15) with zetas z0,z0,z1,z1,z2,z2,z3,z3. -/ +noncomputable def Spec.chunk_ntt_layer_1_step_pure + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z0 z1 z2 z3 : hacspec_ml_kem.parameters.FieldElement) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + let a1 := Spec.chunk_ntt_step_pure a z0 0#usize 2#usize + let a2 := Spec.chunk_ntt_step_pure a1 z0 1#usize 3#usize + let a3 := Spec.chunk_ntt_step_pure a2 z1 4#usize 6#usize + let a4 := Spec.chunk_ntt_step_pure a3 z1 5#usize 7#usize + let a5 := Spec.chunk_ntt_step_pure a4 z2 8#usize 10#usize + let a6 := Spec.chunk_ntt_step_pure a5 z2 9#usize 11#usize + let a7 := Spec.chunk_ntt_step_pure a6 z3 12#usize 14#usize + Spec.chunk_ntt_step_pure a7 z3 13#usize 15#usize + +/-- Pure NTT-layer-2 step at the chunk level. Mirrors the impl's + 8 sequential `ntt_step` calls at pairs (0,4)(1,5)(2,6)(3,7) + (8,12)(9,13)(10,14)(11,15) with zetas z0,z0,z0,z0,z1,z1,z1,z1. -/ +noncomputable def Spec.chunk_ntt_layer_2_step_pure + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z0 z1 : hacspec_ml_kem.parameters.FieldElement) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + let a1 := Spec.chunk_ntt_step_pure a z0 0#usize 4#usize + let a2 := Spec.chunk_ntt_step_pure a1 z0 1#usize 5#usize + let a3 := Spec.chunk_ntt_step_pure a2 z0 2#usize 6#usize + let a4 := Spec.chunk_ntt_step_pure a3 z0 3#usize 7#usize + let a5 := Spec.chunk_ntt_step_pure a4 z1 8#usize 12#usize + let a6 := Spec.chunk_ntt_step_pure a5 z1 9#usize 13#usize + let a7 := Spec.chunk_ntt_step_pure a6 z1 10#usize 14#usize + Spec.chunk_ntt_step_pure a7 z1 11#usize 15#usize + +/-- Pure NTT-layer-3 step at the chunk level. Mirrors the impl's + 8 sequential `ntt_step` calls at pairs (0,8)(1,9)(2,10)(3,11) + (4,12)(5,13)(6,14)(7,15) all with the same zeta. -/ +noncomputable def Spec.chunk_ntt_layer_3_step_pure + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z : hacspec_ml_kem.parameters.FieldElement) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + let a1 := Spec.chunk_ntt_step_pure a z 0#usize 8#usize + let a2 := Spec.chunk_ntt_step_pure a1 z 1#usize 9#usize + let a3 := Spec.chunk_ntt_step_pure a2 z 2#usize 10#usize + let a4 := Spec.chunk_ntt_step_pure a3 z 3#usize 11#usize + let a5 := Spec.chunk_ntt_step_pure a4 z 4#usize 12#usize + let a6 := Spec.chunk_ntt_step_pure a5 z 5#usize 13#usize + let a7 := Spec.chunk_ntt_step_pure a6 z 6#usize 14#usize + Spec.chunk_ntt_step_pure a7 z 7#usize 15#usize + +/-- Pure inverse-NTT step at the chunk level. Mirrors the impl's + write order (`a[i] := add_pure a[j] a[i]`, then + `a[j] := mul_pure (sub_pure a[j] a[i_original]) zeta`). + Because the impl reads `vec[j]` (`= i1`) and `vec[i]` (`= i2`) + BEFORE writing, the `(i, j)` lanes become: + - new `a[i] = add_pure a[j] a[i]` (barrett collapses to canonical sum) + - new `a[j] = mul_pure (sub_pure a[j] a[i]) zeta` (Mont-mul with zeta) + where the reads on the RHS are at the ORIGINAL `a`. When `i = j` the + second write wins (matching impl semantics). -/ +noncomputable def Spec.chunk_inv_ntt_step_pure + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (zeta : hacspec_ml_kem.parameters.FieldElement) (i j : Std.Usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + let a_i := a.val[i.val]! + let a_j := a.val[j.val]! + let new_i := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure a_j a_i + let diff := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure a_j a_i + let new_j := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure diff zeta + (a.set i new_i).set j new_j + +/-- Pure projection of `vector.portable.ntt.inv_ntt_layer_1_step`: + 8 sequential `Spec.chunk_inv_ntt_step_pure` calls at disjoint lane pairs + `(0,2)(1,3)(4,6)(5,7)(8,10)(9,11)(12,14)(13,15)` with zetas + `z0,z0,z1,z1,z2,z2,z3,z3`. Mirrors `Spec.chunk_ntt_layer_1_step_pure` on the same lane-pair sequence but with the inverse + butterfly direction (`chunk_inv_ntt_step_pure` vs `chunk_ntt_step_pure`). -/ +noncomputable def Spec.chunk_inv_ntt_layer_1_step_pure + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z0 z1 z2 z3 : hacspec_ml_kem.parameters.FieldElement) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + let a1 := Spec.chunk_inv_ntt_step_pure a z0 0#usize 2#usize + let a2 := Spec.chunk_inv_ntt_step_pure a1 z0 1#usize 3#usize + let a3 := Spec.chunk_inv_ntt_step_pure a2 z1 4#usize 6#usize + let a4 := Spec.chunk_inv_ntt_step_pure a3 z1 5#usize 7#usize + let a5 := Spec.chunk_inv_ntt_step_pure a4 z2 8#usize 10#usize + let a6 := Spec.chunk_inv_ntt_step_pure a5 z2 9#usize 11#usize + let a7 := Spec.chunk_inv_ntt_step_pure a6 z3 12#usize 14#usize + Spec.chunk_inv_ntt_step_pure a7 z3 13#usize 15#usize + +/-- Pure projection of `vector.portable.ntt.inv_ntt_layer_2_step`: + 8 sequential `Spec.chunk_inv_ntt_step_pure` calls at disjoint lane pairs + `(0,4)(1,5)(2,6)(3,7)(8,12)(9,13)(10,14)(11,15)` with zetas + `z0,z0,z0,z0,z1,z1,z1,z1`. Mirror of `Spec.chunk_ntt_layer_2_step_pure`. -/ +noncomputable def Spec.chunk_inv_ntt_layer_2_step_pure + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z0 z1 : hacspec_ml_kem.parameters.FieldElement) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + let a1 := Spec.chunk_inv_ntt_step_pure a z0 0#usize 4#usize + let a2 := Spec.chunk_inv_ntt_step_pure a1 z0 1#usize 5#usize + let a3 := Spec.chunk_inv_ntt_step_pure a2 z0 2#usize 6#usize + let a4 := Spec.chunk_inv_ntt_step_pure a3 z0 3#usize 7#usize + let a5 := Spec.chunk_inv_ntt_step_pure a4 z1 8#usize 12#usize + let a6 := Spec.chunk_inv_ntt_step_pure a5 z1 9#usize 13#usize + let a7 := Spec.chunk_inv_ntt_step_pure a6 z1 10#usize 14#usize + Spec.chunk_inv_ntt_step_pure a7 z1 11#usize 15#usize + +/-- Pure projection of `vector.portable.ntt.inv_ntt_layer_3_step`: + 8 sequential `Spec.chunk_inv_ntt_step_pure` calls at disjoint lane pairs + `(0,8)(1,9)(2,10)(3,11)(4,12)(5,13)(6,14)(7,15)` with a single zeta `z`. + Mirror of `Spec.chunk_ntt_layer_3_step_pure`. -/ +noncomputable def Spec.chunk_inv_ntt_layer_3_step_pure + (a : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z : hacspec_ml_kem.parameters.FieldElement) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + let a1 := Spec.chunk_inv_ntt_step_pure a z 0#usize 8#usize + let a2 := Spec.chunk_inv_ntt_step_pure a1 z 1#usize 9#usize + let a3 := Spec.chunk_inv_ntt_step_pure a2 z 2#usize 10#usize + let a4 := Spec.chunk_inv_ntt_step_pure a3 z 3#usize 11#usize + let a5 := Spec.chunk_inv_ntt_step_pure a4 z 4#usize 12#usize + let a6 := Spec.chunk_inv_ntt_step_pure a5 z 5#usize 13#usize + let a7 := Spec.chunk_inv_ntt_step_pure a6 z 6#usize 14#usize + Spec.chunk_inv_ntt_step_pure a7 z 7#usize 15#usize + +/-- Pure accumulating NTT-multiply at the chunk level. Mirrors the impl + `vector.portable.ntt.accumulating_ntt_multiply`, + which fans out 8 calls of `accumulating_ntt_multiply_binomials` with + alternating ±zeta: + pair i ∈ {0..7}, zeta_i = [z0, -z0, z1, -z1, z2, -z2, z3, -z3][i] + For lane pair (2i, 2i+1): + - acc[2i] := acc[2i] + a[2i]·b[2i] + a[2i+1]·b[2i+1]·zeta_i + - acc[2i+1] := acc[2i+1] + a[2i]·b[2i+1] + a[2i+1]·b[2i] + All arithmetic in canonical `FieldElement` domain (the impl's + Montgomery `bj·ζ_mont → mont_reduce → bj·ζ_canonical` collapses + under `lift_fe_int`). -/ +noncomputable def Spec.chunk_accumulating_ntt_multiply_pure + (a b acc : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z0 z1 z2 z3 : hacspec_ml_kem.parameters.FieldElement) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + let zeta_for_pair (i : Nat) : hacspec_ml_kem.parameters.FieldElement := + if i = 0 then z0 + else if i = 1 then libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure z0 + else if i = 2 then z1 + else if i = 3 then libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure z1 + else if i = 4 then z2 + else if i = 5 then libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure z2 + else if i = 6 then z3 + else if i = 7 then libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure z3 + else defaultFE + Std.Array.make 16#usize + ((List.range 16).map (fun ℓ => + let i := ℓ / 2 + let ζ := zeta_for_pair i + if ℓ % 2 = 0 then + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure (acc.val[ℓ]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (a.val[ℓ]!) (b.val[ℓ]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (a.val[ℓ + 1]!) (b.val[ℓ + 1]!)) + ζ)) + else + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure (acc.val[ℓ]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (a.val[ℓ - 1]!) (b.val[ℓ]!)) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (a.val[ℓ]!) (b.val[ℓ - 1]!))))) + (by simp) + +/-- The PortableVector `Operations` instance used by Triples that + target the impl monomorphised at `PortableVector`. The concrete + instance is `vector.portable.vector_type.PortableVector.Insts. + Libcrux_iot_ml_kemVectorTraitsOperations` in `Extraction/Funs.lean`; + this alias decouples the FC statements from the precise extraction + identifier in case aeneas re-mangles the name later. -/ +@[reducible] def portable_ops_inst : + libcrux_iot_ml_kem.vector.traits.Operations + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector.Insts.Libcrux_iot_ml_kemVectorTraitsOperations + +/-- Local `Inhabited` for 16-element FE arrays, used by `[!]` indexing + inside `Spec.flatten_chunks`. -/ +private noncomputable instance instInhabitedFEChunk_fcTargets : + Inhabited (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) := + ⟨Std.Array.make 16#usize (List.replicate 16 defaultFE) (by simp)⟩ + +/-- Local `Inhabited` for the 256-FE poly-ring array, used by `[!]` indexing + inside `lift_matrix_from_slice`'s outer projection and the L6c + accumulator-lift family. -/ +private noncomputable instance instInhabitedFEPoly_fcTargets : + Inhabited (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) := + ⟨Std.Array.make 256#usize (List.replicate 256 defaultFE) List.length_replicate⟩ + +/-- Local `Inhabited` for the K-shape array-of-polys, used by `[!]` indexing + inside `lift_matrix_from_slice`'s outer projection and `lift_vec`. -/ +private noncomputable instance instInhabitedFEPolyVec_fcTargets + {K : Std.Usize} : + Inhabited (Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) K) := + ⟨Std.Array.make K (List.replicate K.val default) List.length_replicate⟩ + +/-- Per-index zeta lookup: project lane `i` of + `polynomial.ZETAS_TIMES_MONTGOMERY_R` into a canonical-domain FE. + The Mont-domain table holds `Std.I16` values; `lift_fe_mont` strips + one factor of R (yielding the canonical zeta). Out-of-range lookups + default to `lift_fe_mont 0 = 0` via `[!]`. -/ +noncomputable def Spec.zeta_at (i : Nat) : hacspec_ml_kem.parameters.FieldElement := + lift_fe_mont (libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[i]!) + +/-- Chunk projection: extract the `k`-th 16-element chunk of a 256-array. + Used to address the impl's `re.coefficients[k]` chunk slot at the + spec level. -/ +noncomputable def Spec.chunk_at + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) (k : Nat) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun j => p.val[16 * k + j]!)) + (by simp) + +/-- Flatten 16 chunks of 16 FEs into a 256-array. Inverse of + `Spec.chunk_at` under the `lift_poly` decomposition. -/ +noncomputable def Spec.flatten_chunks + (chunks : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + 16#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Std.Array.make 256#usize ((List.range 256).map (fun j => + (chunks.val[j / 16]!).val[j % 16]!)) (by simp) + +/-- Pure projection of `ntt_at_layer_1` driver: 16 chunks, each chunk + transformed by `chunk_ntt_layer_1_step_pure` with 4 zetas drawn + from positions `zeta_i + 4k + {1..4}` in the global ZETAS table. -/ +noncomputable def Spec.ntt_layer_1_pure + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i : Std.Usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Spec.flatten_chunks + (Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_ntt_layer_1_step_pure (Spec.chunk_at p k) + (Spec.zeta_at (zeta_i.val + 4 * k + 1)) + (Spec.zeta_at (zeta_i.val + 4 * k + 2)) + (Spec.zeta_at (zeta_i.val + 4 * k + 3)) + (Spec.zeta_at (zeta_i.val + 4 * k + 4)))) + (by simp)) + +/-- Pure projection of `ntt_at_layer_2` driver: 16 chunks, each chunk + transformed by `chunk_ntt_layer_2_step_pure` with 2 zetas at + positions `zeta_i + 2k + {1, 2}`. -/ +noncomputable def Spec.ntt_layer_2_pure + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i : Std.Usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Spec.flatten_chunks + (Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_ntt_layer_2_step_pure (Spec.chunk_at p k) + (Spec.zeta_at (zeta_i.val + 2 * k + 1)) + (Spec.zeta_at (zeta_i.val + 2 * k + 2)))) + (by simp)) + +/-- Pure projection of `ntt_at_layer_3` driver: 16 chunks, each chunk + transformed by `chunk_ntt_layer_3_step_pure` with 1 zeta at + position `zeta_i + k + 1`. -/ +noncomputable def Spec.ntt_layer_3_pure + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i : Std.Usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Spec.flatten_chunks + (Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_ntt_layer_3_step_pure (Spec.chunk_at p k) + (Spec.zeta_at (zeta_i.val + k + 1)))) + (by simp)) + +/-- Pure projection of `invert_ntt.invert_ntt_at_layer_1` driver loop. 16 chunks; for chunk `k ∈ {0..15}` reads 4 zetas at + Mont-table indices `[zeta_i - 4k - 1, zeta_i - 4k - 2, zeta_i - 4k - 3, + zeta_i - 4k - 4]` (decreasing — opposite direction from the forward + layer-1 driver) and applies `chunk_inv_ntt_layer_1_step_pure`. The + impl initialises `zeta_i = 64` and decrements 4 per chunk, so the + indices read across all 16 chunks span `[zeta_i - 64 .. zeta_i - 1]`. + For the natural composer (top-level invert_ntt_montgomery) `zeta_i = + 64`, giving indices `[0..63]`. -/ +noncomputable def Spec.invert_ntt_layer_1_pure + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i : Std.Usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Spec.flatten_chunks + (Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_inv_ntt_layer_1_step_pure (Spec.chunk_at p k) + (Spec.zeta_at (zeta_i.val - 4 * k - 1)) + (Spec.zeta_at (zeta_i.val - 4 * k - 2)) + (Spec.zeta_at (zeta_i.val - 4 * k - 3)) + (Spec.zeta_at (zeta_i.val - 4 * k - 4)))) + (by simp)) + +/-- Pure projection of `invert_ntt.invert_ntt_at_layer_2` driver loop. 16 chunks; for chunk `k ∈ {0..15}` reads 2 zetas at + Mont-table indices `[zeta_i - 2k - 1, zeta_i - 2k - 2]` (decreasing) + and applies `chunk_inv_ntt_layer_2_step_pure`. The impl decrements + `zeta_i` by 2 per chunk, so indices span `[zeta_i - 32 .. zeta_i - 1]`. + Natural composer entry: `zeta_i = 32`, giving indices `[0..31]`. -/ +noncomputable def Spec.invert_ntt_layer_2_pure + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i : Std.Usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Spec.flatten_chunks + (Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_inv_ntt_layer_2_step_pure (Spec.chunk_at p k) + (Spec.zeta_at (zeta_i.val - 2 * k - 1)) + (Spec.zeta_at (zeta_i.val - 2 * k - 2)))) + (by simp)) + +/-- Pure projection of `invert_ntt.invert_ntt_at_layer_3` driver loop. 16 chunks; for chunk `k ∈ {0..15}` reads 1 zeta at + Mont-table index `zeta_i - k - 1` (decreasing) and applies + `chunk_inv_ntt_layer_3_step_pure`. The impl decrements `zeta_i` by 1 + per chunk, so indices span `[zeta_i - 16 .. zeta_i - 1]`. Natural + composer entry: `zeta_i = 16`, giving indices `[0..15]`. -/ +noncomputable def Spec.invert_ntt_layer_3_pure + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i : Std.Usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Spec.flatten_chunks + (Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_inv_ntt_layer_3_step_pure (Spec.chunk_at p k) + (Spec.zeta_at (zeta_i.val - k - 1)))) + (by simp)) + +/-- Pure INVERSE NTT (Gentleman-Sande) butterfly between TWO chunks, a-side. + Mirrors the impl `invert_ntt.inv_ntt_layer_int_vec_step_reduce` + on the a-side write: `new_a[ℓ] := barrett_reduce(a[ℓ] + b[ℓ])`, which under + `lift_fe_mont`'s canonical lift is simply `a[ℓ] + b[ℓ]` (no zeta on a-side + for the inverse direction). -/ +noncomputable def Spec.chunk_inv_pair_butterfly_a_pure + (chunk_a chunk_b : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (chunk_a.val[ℓ]!) (chunk_b.val[ℓ]!))) + (by simp) + +/-- Pure INVERSE NTT (Gentleman-Sande) butterfly between TWO chunks, b-side. + Mirrors the impl b-side write: `new_b[ℓ] := mont_mul (2·b[ℓ] − barrett(a+b)) zeta_r`, + which under `lift_fe_mont`'s canonical lift collapses to + `(b[ℓ] − a[ℓ]) * z` (canonical, with `z = lift_fe_mont zeta_r` consuming + the Mont-domain `R⁻¹` of the impl's `mont_mul`). -/ +noncomputable def Spec.chunk_inv_pair_butterfly_b_pure + (chunk_a chunk_b : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z : hacspec_ml_kem.parameters.FieldElement) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (chunk_b.val[ℓ]!) (chunk_a.val[ℓ]!)) + z)) + (by simp) + +/-- Per-chunk output for the INVERSE layer-4+ driver, parameterized by zeta + source. Mirror of `Spec.chunk_at_layer_4_plus_pure` but + using the inverse butterflies (`chunk_inv_pair_butterfly_{a,b}_pure`). + Chunk position `c ∈ 0..16`; step_vec/group/offset/partner relations same + as forward. -/ +noncomputable def Spec.chunk_inv_at_layer_4_plus_pure + (chunks : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize) + (layer : Std.Usize) (zeta_fn : Nat → hacspec_ml_kem.parameters.FieldElement) + (c : Nat) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + let step_vec := (1 <<< layer.val) / 16 + let group := c / (2 * step_vec) + let offset := c % (2 * step_vec) + let z := zeta_fn group + if offset < step_vec then + Spec.chunk_inv_pair_butterfly_a_pure + (chunks.val[c]!) (chunks.val[c + step_vec]!) + else + Spec.chunk_inv_pair_butterfly_b_pure + (chunks.val[c - step_vec]!) (chunks.val[c]!) z + +/-- Pure projection of `invert_ntt.invert_ntt_at_layer_4_plus` for layers 4-7. + Iterates `128 >>> layer` outer rounds, each round processing `step_vec` + chunk-pairs at `(round*2*step_vec + j, round*2*step_vec + step_vec + j)` + for `j ∈ 0..step_vec`. zeta_i decrements by 1 per outer round, with the + constant zeta `polynomial.zeta (zeta_i_initial − 1 − round)` used across + each round's inner loop. + + Note: unlike the forward layer-4+ which uses `zeta_i + group + 1`, + inverse uses `zeta_i - 1 - group` (zeta_i decrements per outer iter). -/ +noncomputable def Spec.invert_ntt_layer_4_plus_pure + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i : Std.Usize) (layer : Std.Usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + let chunks0 : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at p)) (by simp) + let zeta_fn : Nat → hacspec_ml_kem.parameters.FieldElement := + fun group => Spec.zeta_at (zeta_i.val - 1 - group) + Spec.flatten_chunks + (Std.Array.make 16#usize ((List.range 16).map (fun c => + Spec.chunk_inv_at_layer_4_plus_pure chunks0 layer zeta_fn c)) + (by simp)) + +/-- Pure projection of `invert_ntt.invert_ntt_montgomery` top-level composer. Initial `zeta_i = 128` (= `COEFFICIENTS_IN_RING_ELEMENT / 2`). + Composes seven layers in inverse order: layer 1, 2, 3, 4_plus(4), + 4_plus(5), 4_plus(6), 4_plus(7). zeta_i thread: + `128 → 64 → 32 → 16 → 8 → 4 → 2 → 1` (final, discarded). -/ +noncomputable def Spec.invert_ntt_montgomery_pure + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + let p1 := Spec.invert_ntt_layer_1_pure p 128#usize + let p2 := Spec.invert_ntt_layer_2_pure p1 64#usize + let p3 := Spec.invert_ntt_layer_3_pure p2 32#usize + let p4 := Spec.invert_ntt_layer_4_plus_pure p3 16#usize 4#usize + let p5 := Spec.invert_ntt_layer_4_plus_pure p4 8#usize 5#usize + let p6 := Spec.invert_ntt_layer_4_plus_pure p5 4#usize 6#usize + Spec.invert_ntt_layer_4_plus_pure p6 2#usize 7#usize + +/-- Pure projection of `polynomial.PolynomialRingElement.accumulating_ntt_multiply`: + 16 chunks of accumulating NTT-multiplication. For chunk k ∈ {0..15}, + applies `chunk_accumulating_ntt_multiply_pure` with the 4 canonical-domain + zetas at `Spec.zeta_at (64 + 4*k + m)` for `m ∈ {0..3}` (matching the + impl's `polynomial.zeta` lookups at `64 + 4*k + m` per chunk —). -/ +noncomputable def Spec.accumulating_ntt_multiply_pure + (a b acc : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Spec.flatten_chunks + (Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_accumulating_ntt_multiply_pure + (Spec.chunk_at a k) (Spec.chunk_at b k) (Spec.chunk_at acc k) + (Spec.zeta_at (64 + 4 * k)) + (Spec.zeta_at (64 + 4 * k + 1)) + (Spec.zeta_at (64 + 4 * k + 2)) + (Spec.zeta_at (64 + 4 * k + 3)))) + (by simp)) + +/-! ### Spec helpers for layer 4+ (cross-chunk butterflies). -/ + +/-- Pure NTT butterfly between TWO chunks, applied to all 16 lanes + simultaneously. Mirrors the impl's `ntt_layer_int_vec_step`: + lane ℓ in chunk_a becomes `chunk_a[ℓ] + chunk_b[ℓ] * z` (plain ZMod + via Montgomery cancellation in `lift_fe_mont`); lane ℓ in chunk_b + becomes `chunk_a[ℓ] - chunk_b[ℓ] * z`. -/ +noncomputable def Spec.chunk_pair_butterfly_a_pure + (chunk_a chunk_b : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z : hacspec_ml_kem.parameters.FieldElement) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure (chunk_a.val[ℓ]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (chunk_b.val[ℓ]!) z))) + (by simp) + +noncomputable def Spec.chunk_pair_butterfly_b_pure + (chunk_a chunk_b : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + (z : hacspec_ml_kem.parameters.FieldElement) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure (chunk_a.val[ℓ]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (chunk_b.val[ℓ]!) z))) + (by simp) + +/-- Per-chunk output for the layer-4+ driver, parameterized by zeta source. + For chunk position `c ∈ 0..16`: + - `step_vec := (1 <<< layer) / 16` (= 1, 2, 4, 8 for layers 4..7). + - `group := c / (2 * step_vec)`, `offset := c % (2 * step_vec)`. + - If `offset < step_vec`: c is the a-side; partner is `c + step_vec`. + New chunk = chunk_a + chunk_partner * zeta_fn group. + - Else: c is the b-side; partner is `c - step_vec`. + New chunk = chunk_partner - chunk_c * zeta_fn group. + The `zeta_fn : Nat → FE` lets layer-4-6 use the zeta table and + layer-7 use the constant `lift_fe_mont (-1600)`. -/ +noncomputable def Spec.chunk_at_layer_4_plus_pure + (chunks : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize) + (layer : Std.Usize) (zeta_fn : Nat → hacspec_ml_kem.parameters.FieldElement) + (c : Nat) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + let step_vec := (1 <<< layer.val) / 16 + let group := c / (2 * step_vec) + let offset := c % (2 * step_vec) + let z := zeta_fn group + if offset < step_vec then + Spec.chunk_pair_butterfly_a_pure + (chunks.val[c]!) (chunks.val[c + step_vec]!) z + else + Spec.chunk_pair_butterfly_b_pure + (chunks.val[c - step_vec]!) (chunks.val[c]!) z + +/-- Pure projection of `ntt_at_layer_4_plus` driver for layers 4, 5, 6. + Iterates `2 * (128 >>> layer)` chunk-pair butterflies (= 16 chunks + touched once each), with zeta_offset incrementing every `step_vec` + inner butterflies (8 distinct zetas across the layer for layers 4-6). -/ +noncomputable def Spec.ntt_at_layer_4_plus_pure + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (zeta_i : Std.Usize) (layer : Std.Usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + let chunks0 : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at p)) (by simp) + let zeta_fn : Nat → hacspec_ml_kem.parameters.FieldElement := + fun group => Spec.zeta_at (zeta_i.val + group + 1) + Spec.flatten_chunks + (Std.Array.make 16#usize ((List.range 16).map (fun c => + Spec.chunk_at_layer_4_plus_pure chunks0 layer zeta_fn c)) + (by simp)) + +/-- The constant zeta used by `ntt_at_layer_7`. Impl uses + `multiply_by_constant scratch1 ((-1600)#i16)` (PLAIN multiplication, + not Mont — `multiply_by_constant_fc` lifts via `lift_fe`, not + `lift_fe_mont`). Lifted value is `lift_fe ((-1600)#i16)`, a fixed + element of the field. -/ +noncomputable def Spec.zeta_layer_7 : + hacspec_ml_kem.parameters.FieldElement := + lift_fe ((-1600)#i16) + +/-- Pure projection of `ntt_at_layer_7` driver. Single layer of 8 + chunk-pair butterflies between chunks `(j, j+8)` for j ∈ 0..8, all + with the constant zeta `Spec.zeta_layer_7`. -/ +noncomputable def Spec.ntt_at_layer_7_pure + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + let chunks0 : Std.Array + (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) 16#usize := + Std.Array.make 16#usize ((List.range 16).map (Spec.chunk_at p)) (by simp) + let zeta_fn : Nat → hacspec_ml_kem.parameters.FieldElement := + fun _ => Spec.zeta_layer_7 + Spec.flatten_chunks + (Std.Array.make 16#usize ((List.range 16).map (fun c => + Spec.chunk_at_layer_4_plus_pure chunks0 7#usize zeta_fn c)) + (by simp)) + +/-- Pure projection of the full hacspec `ntt.ntt`. Composes layer-7, + three layer-4_plus calls (layers 6, 5, 4), layer-3, layer-2, layer-1 + + final barrett, mirroring the impl `ntt_binomially_sampled_ring_element` + shape with cumulative zeta_i offsets: + - layer 7: zeta_i unchanged (constant zeta, no table use). + - layer 6: zeta_i starts at 1, advances by `128 >>> 6 = 2` to 3. + - layer 5: starts at 3, advances by `128 >>> 5 = 4` to 7. + - layer 4: starts at 7, advances by `128 >>> 4 = 8` to 15. + - layer 3: starts at 15, advances by 16 to 31. + - layer 2: starts at 31, advances by 32 to 63. + - layer 1: starts at 63, advances by 64 to 127. + Total zetas: 0 + 2 + 4 + 8 + 16 + 32 + 64 = 126 (indices 1..126 used). -/ +noncomputable def Spec.ntt_pure + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + let p7 := Spec.ntt_at_layer_7_pure p + let p6 := Spec.ntt_at_layer_4_plus_pure p7 1#usize 6#usize + let p5 := Spec.ntt_at_layer_4_plus_pure p6 3#usize 5#usize + let p4 := Spec.ntt_at_layer_4_plus_pure p5 7#usize 4#usize + let p3 := Spec.ntt_layer_3_pure p4 15#usize + let p2 := Spec.ntt_layer_2_pure p3 31#usize + let p1 := Spec.ntt_layer_1_pure p2 63#usize + Spec.Pure.polynomial.poly_barrett_reduce_pure p1 + +/-- Pure projection of `ntt_vector_u`'s full NTT chain. Mirrors `Spec.ntt_pure` + but uses `Spec.ntt_at_layer_4_plus_pure p 0 7` for the first step instead + of `Spec.ntt_at_layer_7_pure p`. The two specs are mathematically + equivalent in `ZMod 3329` (see `Spec.zeta_at_one_eq_layer_7` below: the + Mont-multiply layer-7 step via `ZETAS_TIMES_MONTGOMERY_R[1] = -758` and + the plain-multiply layer-7 step with constant `-1600` produce the same + field element). They differ structurally because `ntt_vector_u`'s impl + uses the Mont path while `ntt_binomially_sampled_ring_element` uses the + plain path; we target each spec at the impl actually used. -/ +noncomputable def Spec.ntt_pure_vec_u + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + let p7 := Spec.ntt_at_layer_4_plus_pure p 0#usize 7#usize + let p6 := Spec.ntt_at_layer_4_plus_pure p7 1#usize 6#usize + let p5 := Spec.ntt_at_layer_4_plus_pure p6 3#usize 5#usize + let p4 := Spec.ntt_at_layer_4_plus_pure p5 7#usize 4#usize + let p3 := Spec.ntt_layer_3_pure p4 15#usize + let p2 := Spec.ntt_layer_2_pure p3 31#usize + let p1 := Spec.ntt_layer_1_pure p2 63#usize + Spec.Pure.polynomial.poly_barrett_reduce_pure p1 + +/-- `ZETAS_TIMES_MONTGOMERY_R[1]! = -758#i16`. -/ +theorem Spec.ZETAS_TIMES_MONTGOMERY_R_get_one : + libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R.val[1]! + = ((-758)#i16 : Std.I16) := by + unfold libcrux_iot_ml_kem.polynomial.ZETAS_TIMES_MONTGOMERY_R + decide + +/-- Spec-level zeta equivalence between L3.7 (plain `multiply_by_constant`) + and L3.4_plus at layer=7 (Mont multiply through `ZETAS_TIMES_MONTGOMERY_R[1]`). + + In `ZMod 3329`: `Spec.zeta_at 1 = lift_fe_mont (-758) = lift_fe ((-758) * 169) + = lift_fe (-1600) = Spec.zeta_layer_7` (since `-758 * 169 ≡ -1600 mod 3329`). + Both equal the canonical field element 1729. -/ +theorem Spec.zeta_at_one_eq_layer_7 : + Spec.zeta_at 1 = Spec.zeta_layer_7 := by + unfold Spec.zeta_at Spec.zeta_layer_7 + rw [Spec.ZETAS_TIMES_MONTGOMERY_R_get_one] + unfold lift_fe_mont lift_fe + libcrux_iot_ml_kem.Spec.i16_to_spec_fe_mont + libcrux_iot_ml_kem.Spec.i16_to_spec_fe_plain + congr 1 + +/-- Per-chunk pure projection of `polynomial.add_error_reduce`: for the + `ℓ`-th lane of a 16-lane chunk, + `out[ℓ] := self_chunk[ℓ] · lift_fe_mont(1441#i16) + error_chunk[ℓ]`. -/ +noncomputable def Spec.chunk_add_error_reduce_pure + (self_chunk error_chunk : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (self_chunk.val[ℓ]!) (lift_fe_mont (1441#i16))) + (error_chunk.val[ℓ]!))) + (by simp) + +/-- Per-chunk pure projection of `polynomial.add_standard_error_reduce`: + for the `ℓ`-th lane, + `out[ℓ] := self_chunk[ℓ] · lift_fe_mont(1353#i16) + error_chunk[ℓ]`, + where `1353 ≡ R² (mod q)` (cf. `libcrux_iot_ml_kem.Spec.NumericKeystones.mont_1353_eq_RR_mod_q`). -/ +noncomputable def Spec.chunk_add_standard_error_reduce_pure + (self_chunk error_chunk : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (self_chunk.val[ℓ]!) (lift_fe_mont (1353#i16))) + (error_chunk.val[ℓ]!))) + (by simp) + +/-- Per-chunk pure projection of `polynomial.add_message_error_reduce`: + for the `ℓ`-th lane, + `out[ℓ] := result_chunk[ℓ] · lift_fe_mont(1441#i16) + + (self_chunk[ℓ] + message_chunk[ℓ])`. + The impl barrett-reduces this sum, but `barrett_pure` is identity + after `lift_fe`. -/ +noncomputable def Spec.chunk_add_message_error_reduce_pure + (self_chunk message_chunk result_chunk : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (result_chunk.val[ℓ]!) (lift_fe_mont (1441#i16))) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (self_chunk.val[ℓ]!) (message_chunk.val[ℓ]!)))) + (by simp) + +/-- Pure projection of `polynomial.add_error_reduce`. The hacspec spec + does not expose a dedicated `add_error_reduce` at the poly level — + the impl's behaviour is "multiply by R/128 then add error then + barrett". chunk `k ∈ 0..16` and lane `ℓ`: + `out_chunk[k][ℓ] := self[k][ℓ] · lift_fe_mont(1441#i16) + error[k][ℓ]`, + flattened to a 256-array via `Spec.flatten_chunks`. -/ +noncomputable def Spec.add_error_reduce_pure + (self error : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Spec.flatten_chunks + (Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_add_error_reduce_pure + (Spec.chunk_at self k) (Spec.chunk_at error k))) + (by simp)) + +/-- Pure projection of `polynomial.add_standard_error_reduce`. chunk + `k` and lane `ℓ`: + `out[k][ℓ] := self[k][ℓ] · lift_fe_mont(1353#i16) + error[k][ℓ]` + (1353 ≡ R² mod q lifts to `× R` in canonical domain). -/ +noncomputable def Spec.add_standard_error_reduce_pure + (self error : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Spec.flatten_chunks + (Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_add_standard_error_reduce_pure + (Spec.chunk_at self k) (Spec.chunk_at error k))) + (by simp)) + +/-- Pure projection of `polynomial.add_message_error_reduce`. chunk + `k` and lane `ℓ`: + `out[k][ℓ] := result[k][ℓ] · lift_fe_mont(1441#i16) + + (self[k][ℓ] + message[k][ℓ])`. -/ +noncomputable def Spec.add_message_error_reduce_pure + (self message : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (result : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Spec.flatten_chunks + (Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_add_message_error_reduce_pure + (Spec.chunk_at self k) (Spec.chunk_at message k) (Spec.chunk_at result k))) + (by simp)) + +/-- Pure projection of poly-level `reducing_from_i32_array`. Direct 256-lane + construction: for `i ∈ 0..256`, + `out[i] := Spec.mont_reduce_pure (lift_fe_int array.val[i].val)`. + Mirrors `Spec.chunk_reducing_from_i32_array_pure` per chunk-of-16. -/ +noncomputable def Spec.poly_reducing_from_i32_array_pure + (array : Slice Std.I32) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Std.Array.make 256#usize + ((List.range 256).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (array.val[i]!).val))) + (by simp) + +/-- Per-chunk pure projection of `polynomial.subtract_reduce`: for the + `ℓ`-th lane of a 16-lane chunk, + `out[ℓ] := self_chunk[ℓ] - b_chunk[ℓ] * lift_fe_mont (1441#i16)`. + + This is the chunk-level building block used by `Spec.subtract_reduce_pure` + (which flattens 16 chunks via `Spec.flatten_chunks`). -/ +noncomputable def Spec.chunk_subtract_reduce_pure + (self_chunk b_chunk : Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize := + Std.Array.make 16#usize ((List.range 16).map (fun ℓ => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure (self_chunk.val[ℓ]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (b_chunk.val[ℓ]!) (lift_fe_mont (1441#i16))))) + (by simp) + +/-- Pure projection of `polynomial.subtract_reduce`. The hacspec spec + computes `self - b`, but the impl fuses a Mont-multiply on b by the + constant `1441#i16` BEFORE the subtract. the C.4 commute + `1441 · R⁻¹ ≡ 1441 · 169 ≡ 512 (mod q)`, this is equivalent in + ZMod q to computing `self - 512 · b` pointwise, NOT `self - b`. + + Hence we model the impl directly: per chunk `k ∈ 0..16` and lane + `ℓ ∈ 0..16`, + `out_chunk[k][ℓ] := self_chunk[k][ℓ] - b_chunk[k][ℓ] * lift_fe_mont (1441#i16)`, + then flatten 16 chunks to a 256-array via `Spec.flatten_chunks`. The + chunk-level form mirrors the impl's chunk-loop structure and pairs + with `flatten_chunks_eq_lift_poly_fc` in the FC closure proof. -/ +noncomputable def Spec.subtract_reduce_pure + (self b : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize := + Spec.flatten_chunks + (Std.Array.make 16#usize ((List.range 16).map (fun k => + Spec.chunk_subtract_reduce_pure + (Spec.chunk_at self k) (Spec.chunk_at b k))) + (by simp)) + +-- `Spec.sample_matrix_A_pure` is declared above (with `lift_matrix_from_seed`). + + +end libcrux_iot_ml_kem.Spec.Lift \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/ModularArith.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/ModularArith.lean new file mode 100644 index 00000000..a387863a --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/ModularArith.lean @@ -0,0 +1,102 @@ +/- + # `Util.ModularArith` — canonical modular-equality predicate for ML-KEM + + Defines `modq_eq a b q := (a - b) % q = 0` and its standard algebraic + lemma surface (reflexivity, symmetry, transitivity, additive / + subtractive / constant-multiplicative compatibility, and a bridge + to `ZMod q`). + + Every L0–L9 Triple postcondition that asserts `a ≡ b (mod q)` uses + this single named predicate. The subtraction-mod spelling is preferred + over `a % q = b % q` because it composes additively without side + conditions. +-/ +-- Mathlib footprint here is BARRIER-LAYER ONLY. Consumers of `modq_eq` +-- above the abstraction barrier MUST NOT import Mathlib themselves; +-- they use only the lemmas exported by this module. +import Mathlib.Data.ZMod.Basic +import Mathlib.Tactic.Ring + +namespace libcrux_iot_ml_kem.Spec.ModularArith +/-- Canonical modular-equality predicate: `a ≡ b (mod q)` in the + subtraction-mod spelling. + + Equivalent to `Int.ModEq q a b` (= `a % q = b % q`) but easier to + discharge directly via `decide` / `scalar_tac` on concrete values. -/ +def modq_eq (a b : Int) (q : Int) : Prop := (a - b) % q = 0 + +variable {a b c k : Int} {q : Int} + +@[simp] theorem modq_eq_refl : modq_eq a a q := by + unfold modq_eq + simp + +theorem modq_eq_symm : modq_eq a b q → modq_eq b a q := by + unfold modq_eq + intro h + have hdvd : q ∣ (a - b) := Int.dvd_of_emod_eq_zero h + have hdvd' : q ∣ (b - a) := by + have : (b - a) = -(a - b) := by ring + rw [this]; exact dvd_neg.mpr hdvd + exact Int.emod_eq_zero_of_dvd hdvd' + +theorem modq_eq_trans : modq_eq a b q → modq_eq b c q → modq_eq a c q := by + unfold modq_eq + intro h1 h2 + have hdvd1 : q ∣ (a - b) := Int.dvd_of_emod_eq_zero h1 + have hdvd2 : q ∣ (b - c) := Int.dvd_of_emod_eq_zero h2 + have hdvd : q ∣ (a - c) := by + have : (a - c) = (a - b) + (b - c) := by ring + rw [this]; exact dvd_add hdvd1 hdvd2 + exact Int.emod_eq_zero_of_dvd hdvd + +theorem modq_eq_add : + modq_eq a b q → modq_eq c d q → modq_eq (a + c) (b + d) q := by + unfold modq_eq + intro h1 h2 + have hdvd1 : q ∣ (a - b) := Int.dvd_of_emod_eq_zero h1 + have hdvd2 : q ∣ (c - d) := Int.dvd_of_emod_eq_zero h2 + have : ((a + c) - (b + d)) = (a - b) + (c - d) := by ring + rw [this] + exact Int.emod_eq_zero_of_dvd (dvd_add hdvd1 hdvd2) + +theorem modq_eq_sub : + modq_eq a b q → modq_eq c d q → modq_eq (a - c) (b - d) q := by + unfold modq_eq + intro h1 h2 + have hdvd1 : q ∣ (a - b) := Int.dvd_of_emod_eq_zero h1 + have hdvd2 : q ∣ (c - d) := Int.dvd_of_emod_eq_zero h2 + have : ((a - c) - (b - d)) = (a - b) - (c - d) := by ring + rw [this] + exact Int.emod_eq_zero_of_dvd (dvd_sub hdvd1 hdvd2) + +theorem modq_eq_const_mul : + modq_eq a b q → modq_eq (k * a) (k * b) q := by + unfold modq_eq + intro h + have hdvd : q ∣ (a - b) := Int.dvd_of_emod_eq_zero h + have : (k * a - k * b) = k * (a - b) := by ring + rw [this] + exact Int.emod_eq_zero_of_dvd (Dvd.dvd.mul_left hdvd k) + +/-- Bridge to mathlib's `ZMod` view: `modq_eq a b q` is equivalent to + `(a : ZMod q) = (b : ZMod q)`. Stated with `q : ℕ` plus an + explicit cast so we can reuse mathlib's `ZMod.intCast_eq_intCast_iff` + machinery; the `[NeZero q]` instance keeps the bridge usable on + every nonzero modulus (in particular `q = 3329`). -/ +theorem modq_eq_iff_zmod {q : ℕ} [NeZero q] (a b : Int) : + modq_eq a b q ↔ (a : ZMod q) = (b : ZMod q) := by + unfold modq_eq + rw [ZMod.intCast_eq_intCast_iff_dvd_sub] + refine ⟨fun h => ?_, fun h => ?_⟩ + · -- (a - b) % q = 0 → (q : Int) ∣ (b - a) + have hdvd : (q : Int) ∣ (a - b) := Int.dvd_of_emod_eq_zero h + have : (b - a) = -(a - b) := by ring + rw [this]; exact dvd_neg.mpr hdvd + · -- (q : Int) ∣ (b - a) → (a - b) % q = 0 + have hdvd : (q : Int) ∣ (a - b) := by + have : (a - b) = -(b - a) := by ring + rw [this]; exact dvd_neg.mpr h + exact Int.emod_eq_zero_of_dvd hdvd + +end libcrux_iot_ml_kem.Spec.ModularArith \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/Montgomery.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/Montgomery.lean new file mode 100644 index 00000000..b4ea9ced --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/Montgomery.lean @@ -0,0 +1,131 @@ +/- + # `Util.Montgomery` — Montgomery-form algebraic bridges + + Pure `Int`-arithmetic helpers anchoring the Montgomery-reduction + family of Triples (L0.3, L0.4, L2.x, L6.x). + + This is **the L0.3 keystone module**. It contains the integer-level + "given that `(value + k·q) % R = 0`, the Montgomery identity + holds" lemma referenced from + Tier 1. Closing this file unblocks `montgomery_reduce_element_spec` + and (transitively) every downstream Triple that reasons about + values in the Montgomery domain. + + Conventions: `q = 3329`, `R = 2^16`. The lemmas are stated on + `Int` (not `Nat`), since signed Montgomery reduction in libcrux-iot + ML-KEM operates on `Std.I32` values whose `.val : Int` can be + negative. + + The downstream user uses `mont_reduce_int_form` immediately after + reducing the impl Triple to an `Int`-level chain via the BV-to-Int + bridge for `IScalar.cast`, `wrapping_mul`, and `>>>` (arithmetic + shift right). +-/ +-- Mathlib footprint here is BARRIER-LAYER ONLY. Consumers above the +-- abstraction barrier use only the `modq_eq`-shaped lemmas exported +-- from this file; they MUST NOT import Mathlib themselves. +-- `Linarith` dropped; both prior `linarith` calls migrated to `omega` +-- (core Lean), which handles linear int arithmetic with the same or +-- greater capability for the goals this file proves. +import LibcruxIotMlKem.Spec.ModularArith +import LibcruxIotMlKem.Spec.NumericKeystones +import Mathlib.Tactic.Ring + +namespace libcrux_iot_ml_kem.Spec.Montgomery +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.NumericKeystones +/-- **`mont_reduce_int_form`** — the L0.3 algebraic keystone. + + Given that `value + k · q` is divisible by `R = 2^16` (as an + integer), the Montgomery quotient `(value + k · q) / R` satisfies + `quotient · R ≡ value (mod q)`. + + This is the integer-arithmetic content of the signed-Montgomery- + reduction correctness argument. The `+` (vs `-`) here matches the + impl's two's-complement choice of `k`: the impl computes `k` so + that `value + k · q` cancels in the low 16 bits (the + `(62209 · 3329) % 2^16 = 1` keystone makes the choice + `k := value · 62209 (mod 2^16)` give that cancellation, with the + sign absorbed into the wrap-around). + + § "Reusable infrastructure" Tier 1. -/ +theorem mont_reduce_int_form + (value k : Int) + (h_div : (value + k * 3329) % (2^16 : Int) = 0) : + modq_eq ((value + k * 3329) / (2^16 : Int) * (2^16 : Int)) value 3329 := by + -- Strategy: from `h_div`, get `(value + k*3329) = 2^16 * m` for some `m`. + -- Then `(value + k*3329) / 2^16 = m`, so + -- `m * 2^16 - value = k * 3329`, which is divisible by 3329. + unfold modq_eq + have h_dvd : (2^16 : Int) ∣ (value + k * 3329) := + Int.dvd_of_emod_eq_zero h_div + obtain ⟨m, hm⟩ := h_dvd + have h_R_ne_zero : (2^16 : Int) ≠ 0 := by decide + have h_div_eq : (value + k * 3329) / (2^16 : Int) = m := by + rw [hm] + exact Int.mul_ediv_cancel_left m h_R_ne_zero + rw [h_div_eq] + -- Goal: (m * 2^16 - value) % 3329 = 0. + -- From `hm : value + k * 3329 = 2^16 * m`, we get + -- `m * 2^16 - value = k * 3329`. + have h_eq : m * (2^16 : Int) - value = k * 3329 := by omega + rw [h_eq] + exact Int.mul_emod_left _ _ + +/-- **`sub_div_of_emod_eq_zero`** — auxiliary used by L0.3. + + When `(a - b) % R = 0` (i.e. `a ≡ b (mod R)`), the difference of + floored quotients `a / R - b / R` equals the exact quotient + `(a - b) / R`. This is what bridges the impl's two separate + arithmetic-shift-right operations (`value >> 16` and + `(k · q) >> 16`) to the single mathematical quotient + `(value - k · q) / 2^16` that the keystone uses. + + Used by `montgomery_reduce_element_spec` (L0.3) immediately after + the BV-level reduction places the goal in `Int` form. -/ +theorem sub_div_of_emod_eq_zero + (a b R : Int) (hRne : R ≠ 0) (h_dvd : (a - b) % R = 0) : + a / R - b / R = (a - b) / R := by + have hd : R ∣ (a - b) := Int.dvd_of_emod_eq_zero h_dvd + obtain ⟨q, hq⟩ := hd + rw [hq, Int.mul_ediv_cancel_left q hRne] + have h_a : a = b + R * q := by omega + rw [h_a, Int.add_mul_ediv_left b q hRne] + ring + +/-- **Bridge: old → new Montgomery modq form.** + + Given the old-form modular equation `r * 2^16 ≡ v (mod 3329)`, + derive the new-form `r ≡ v * 169 (mod 3329)` via the + `mont_R_inv_q` keystone `(2^16 * 169) % 3329 = 1`. + + The keystone implies `r * (2^16 * 169) ≡ r (mod 3329)`, hence + multiplying both sides of the old form by 169 yields the new form. + + Used by L0.4 `montgomery_multiply_fe_by_fer_spec` and any + downstream Triple that needs to convert from the + `r·R ≡ v (mod q)` shape (impl-native) to the + `r ≡ v·R⁻¹ (mod q)` shape (F*-native, with `R⁻¹ = 169`). -/ +theorem modq_R_to_169 + (r v : Int) (h : modq_eq (r * (2^16 : Int)) v 3329) : + modq_eq r (v * 169) 3329 := by + unfold modq_eq at h ⊢ + have h_dvd_diff : (3329 : Int) ∣ (r * (2^16 : Int) - v) := Int.dvd_of_emod_eq_zero h + have h_keystone : ((2^16 : Int) * 169) % 3329 = 1 := by decide + have h_dvd_keystone : (3329 : Int) ∣ ((2^16 : Int) * 169 - 1) := by + have : ((2^16 : Int) * 169 - 1) % 3329 = 0 := by + rw [Int.sub_emod, h_keystone]; decide + exact Int.dvd_of_emod_eq_zero this + have h_dvd_r : (3329 : Int) ∣ (r * ((2^16 : Int) * 169) - r) := by + have h_eq : r * ((2^16 : Int) * 169) - r = r * ((2^16 : Int) * 169 - 1) := by ring + rw [h_eq] + exact Dvd.dvd.mul_left h_dvd_keystone r + have h_dvd_169 : (3329 : Int) ∣ ((r * (2^16 : Int) - v) * 169) := + Dvd.dvd.mul_right h_dvd_diff 169 + have h_dvd_final : (3329 : Int) ∣ (r - v * 169) := by + have h_eq : (r - v * 169) + = (r * (2^16 : Int) - v) * 169 - (r * ((2^16 : Int) * 169) - r) := by ring + rw [h_eq] + exact dvd_sub h_dvd_169 h_dvd_r + exact Int.emod_eq_zero_of_dvd h_dvd_final + +end libcrux_iot_ml_kem.Spec.Montgomery \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/NumericKeystones.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/NumericKeystones.lean new file mode 100644 index 00000000..6532097a --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/NumericKeystones.lean @@ -0,0 +1,76 @@ +/- + # `Util.NumericKeystones` — Montgomery / NTT keystone identities + + Concrete `Nat`-arithmetic identities (`(_ * _) % q = _`) that anchor + the ML-KEM Montgomery and inverse-NTT correctness arguments. Every + identity here closes by plain `decide` over small Nat arithmetic and + is referenced by name from the Triple-level proofs that need it. + + Conventions: `R = 2^16`, `q = 3329`. `R⁻¹ ≡ 169 (mod q)`. + + All proofs are `by decide` — never `native_decide` — so the kernel + proof term is small and the file's `#print axioms` reports only the + base `propext` / `Classical.choice` / `Quot.sound` triple (no + `Lean.ofReduceBool` / `Lean.trustCompiler`). +-/ + +namespace libcrux_iot_ml_kem.Spec.NumericKeystones +/-! ## Numeric keystones -/ + +/-- **B.1 `mont_R_inv_q`** — `R · 169 ≡ 1 (mod q)`. Used by every + Layer 0/2/3/6 lemma that converts between Montgomery and standard + domains (`montgomery_reduce_element_spec`, `montgomery_multiply_*`). + The load-bearing identity behind L0.3. -/ +theorem mont_R_inv_q : ((2^16 : Nat) * 169) % 3329 = 1 := by decide + +/-- **B.2 `mont_1441_eq_inv128`** — `1441 · 128 ≡ R² (mod q)`. Combined + with one Montgomery reduce (× R⁻¹), the net factor on the value + after `montgomery_multiply(b, 1441)` is `R / 128 mod q`. This is + exactly the "Montgomery-scale-by-1/128" used in `add_error_reduce`, + `subtract_reduce`, etc. to absorb the deferred 1/N normalization + of inverse NTT (L6.x). -/ +theorem mont_1441_eq_inv128 : + (1441 * 128) % 3329 = (2^16 * 2^16) % 3329 := by decide + +/-- **B.3 `mont_2285_eq_R_mod_q`** — `2285 ≡ 2^16 (mod q)`. Used in + `to_unsigned_field_modulus` to convert Montgomery-encoded → canonical + representative before serialization (L5.x). -/ +theorem mont_2285_eq_R_mod_q : 2285 = (2^16 : Nat) % 3329 := by decide + +/-- **B.4 `mont_1353_eq_RR_mod_q`** — `1353 ≡ R² (mod q)`. The Rust + function `to_standard_domain` is `montgomery_multiply(c, 1353)`; + it lifts `x` to `R · x mod q` (since `x · R² · R⁻¹ = R · x`). + Used by Layer 3 (NTT pre-domain) and Layer 6 (post-NTT lift). -/ +theorem mont_1353_eq_RR_mod_q : 1353 = (2^16 * 2^16) % 3329 := by decide + +/-! ## Additional keystones -/ + +/-- **`mont_qinv_R`** — `Q⁻¹_R · q ≡ 1 (mod R)`, the dual of + `mont_R_inv_q`. With `Q⁻¹_R = 62209` (the precomputed Montgomery + constant for `q = 3329, R = 2^16`) and `R = 2^16`, this is + `(62209 · 3329) % 2^16 = 1`. The load-bearing identity for + `montgomery_reduce_element_spec` (L0.3): it is what makes the + integer formula `(value - k*q) / R` produce an exact integer + rather than a quotient with leftover bits. + + Note: while `mont_R_inv_q` lives mod q (B.1), this lemma lives + mod R — together they pin down both halves of the Montgomery + reciprocal pair. -/ +theorem mont_qinv_R : (62209 * 3329) % (2^16 : Nat) = 1 := by decide + +/-- **`mont_128_169_512`** — INTT finalize keystone: + `1441 · 169 ≡ 512 (mod q)`. + + Semantically this is "after multiplying by `1441` (= `R / 128 mod q`, + B.2) and one Montgomery reduce (× R⁻¹ = 169, B.1), the leftover + factor is `512 = 2^9 = R / 128`" — i.e. the deferred 1/128 from + INTT comes out as the canonical small constant `512` in the + standard domain. The literal `128` in the symbol name refers to + the 1/128 normalization factor that this identity finalizes. + + Used by L6.4 (`subtract_reduce`) and the assembly bridges that + funnel the post-INTT Montgomery state down to canonical + representatives. -/ +theorem mont_128_169_512 : (1441 * 169) % 3329 = 512 := by decide + +end libcrux_iot_ml_kem.Spec.NumericKeystones \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/Pure.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/Pure.lean new file mode 100644 index 00000000..a74a13d6 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/Pure.lean @@ -0,0 +1,1056 @@ +/- + # `Spec/Pure.lean` — Open Question I.7 resolution. + + The hacspec ML-KEM extraction (`HacspecMlKem.Extraction.Funs`) wraps + every spec function in the Aeneas `Result` monad, even when the + body is mathematically pure (no panic / divergence). The bit-side + intermediate spec (`BitMlKem.Spec`) operates on `MontPoly = + Vector (ZMod 3329) 256`, which is genuinely pure. + + Layer M.4 alg-equiv lemmas state equations of the form + `bit_ (lift hacspec_input) = lift (Spec._pure hacspec_input)` + where `Spec._pure` is the `Result`-stripped pure projection of + the hacspec spec. This file defines those `_pure` aliases. + + arch plan §F.2 option (b): each alias is defined by pattern + match on the `Result`. The companion **pure-projection side lemmas** + of the form `Spec. args = .ok (Spec._pure args)` pin the + impl's `.ok` value to the projected `_pure` value. They are the + equational input to `libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq` (the index-wise spec + lifting `from_fn`) and to downstream M.2 commute lemmas — NOT + standalone "panic-freedom" facts. + + ## Scope + + - Scalar `parameters.FieldElement.{add,sub,mul,neg}` get `_pure` + aliases — they are used pointwise inside every `polynomial.*` + closure in the hacspec extraction. + - The three poly-level hacspec wrappers used by M.4 easy cluster + (`polynomial.add_to_ring_element`, `polynomial.poly_barrett_reduce`, + `polynomial.subtract_reduce`) get `_pure` aliases. + + ## Discipline + + - All `_pure` defs are `noncomputable` because the match-on-Result + extraction does not reduce by `decide` for arbitrary inputs; + callers reason about them via the side lemmas (TODO) or by + direct `simp only [_pure]` rewriting through M.4 proofs. + - No Mathlib imports needed; this file lives above the M.1 barrier + inheriting `ZMod 3329` transitively via `BitMlKem.Spec`. +-/ +import LibcruxIotMlKem.Spec +import LibcruxIotMlKem.Util.CreateI + +namespace libcrux_iot_ml_kem.Spec.Pure +open CoreModels Aeneas Aeneas.Std +open hacspec_ml_kem + +/-! ## Default `FieldElement` (used as fall-through in `_pure` match). -/ + +/-- Canonical default for `parameters.FieldElement` — the zero element. + Used as the fall-through branch in the `match` definitions below; + chosen to be in canonical range so the lift through `zmodOfFE` + gives `0 : ZMod 3329`. -/ +def defaultFE : parameters.FieldElement := { val := (0#u16 : Std.U16) } + +instance : Inhabited parameters.FieldElement := ⟨defaultFE⟩ + +/-! ## Scalar `_pure` aliases. -/ + +/-- Pure projection of `parameters.FieldElement.add`. The hacspec body + computes `(self.val + other.val) % q` via U32 lifts; this `_pure` + extracts the `.ok` value (see `FieldElement.add_eq_ok` below for + the pure-projection side lemma pinning the result). -/ +noncomputable def FieldElement.add_pure + (self other : parameters.FieldElement) : parameters.FieldElement := + match parameters.FieldElement.add self other with + | .ok r => r + | _ => defaultFE + +/-- Pure projection of `parameters.FieldElement.sub`. Mirrors + `add_pure`; hacspec body computes `(self.val + q - other.val) % q`. -/ +noncomputable def FieldElement.sub_pure + (self other : parameters.FieldElement) : parameters.FieldElement := + match parameters.FieldElement.sub self other with + | .ok r => r + | _ => defaultFE + +/-- Pure projection of `parameters.FieldElement.mul`. -/ +noncomputable def FieldElement.mul_pure + (self other : parameters.FieldElement) : parameters.FieldElement := + match parameters.FieldElement.mul self other with + | .ok r => r + | _ => defaultFE + +/-- Pure projection of `parameters.FieldElement.neg`. -/ +noncomputable def FieldElement.neg_pure + (self : parameters.FieldElement) : parameters.FieldElement := + match parameters.FieldElement.neg self with + | .ok r => r + | _ => defaultFE + +/-! ## Poly-level `_pure` aliases. + + These wrap the 256-coefficient Aeneas-`Array` results into a pure + `SpecPoly = Vector parameters.FieldElement 256` via per-lane + extraction. Downstream lifting through `libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq` + pins each lane to its per-element `_pure` value; see the + `polynomial._eq_ok` pure-projection side lemmas at the end of + this file. -/ + +/-- Pure projection of `polynomial.add_to_ring_element`. -/ +noncomputable def polynomial.add_to_ring_element_pure + (lhs rhs : Std.Array parameters.FieldElement 256#usize) : + Std.Array parameters.FieldElement 256#usize := + match hacspec_ml_kem.polynomial.add_to_ring_element lhs rhs with + | .ok r => r + | _ => lhs + +/-- Pure projection of `polynomial.poly_barrett_reduce`. -/ +noncomputable def polynomial.poly_barrett_reduce_pure + (p : Std.Array parameters.FieldElement 256#usize) : + Std.Array parameters.FieldElement 256#usize := + match hacspec_ml_kem.polynomial.poly_barrett_reduce p with + | .ok r => r + | _ => p + +/-- Pure projection of `polynomial.subtract_reduce`. -/ +noncomputable def polynomial.subtract_reduce_pure + (a b : Std.Array parameters.FieldElement 256#usize) : + Std.Array parameters.FieldElement 256#usize := + match hacspec_ml_kem.polynomial.subtract_reduce a b with + | .ok r => r + | _ => a + +/-! ## FE-primitive pure-projection side lemmas. + + `add_eq_ok` and `mul_eq_ok` are UNCONDITIONALLY true because U32 + widening prevents overflow in both cases (`a.val + b.val < 2^17`, + `a.val * b.val ≤ (2^16 - 1)² < 2^32`). + + `sub_eq_ok` and `neg_eq_ok` require a CANONICITY PRECONDITION + `Canonical a := a.val.val < q`. The Aeneas extraction dropped the + Rust source's `#[refine(val < FIELD_MODULUS)]` refinement on + `parameters.FieldElement::new` (see `specs/ml-kem/src/parameters.rs:319`), + leaving any `U16` admissible. Counterexamples for arbitrary U16 + (verified via `#eval`): + * `sub ⟨0#u16⟩ ⟨65535#u16⟩ = .fail` (U32 underflow after `+ q`). + * `neg ⟨65535#u16⟩ = .fail` (U16 underflow `q - val`, no widening). + + Audit of the 3 call-sites of `sub`/`neg` in the + hacspec extraction confirms ALL callers feed canonical inputs: + * `ntt.butterfly` and `invert_ntt.inv_butterfly` — inputs `a, b` + assumed canonical by caller; `mul`'s output is always canonical + by construction. + * `ntt.base_case_multiply_n` — `neg fe` where `fe` is from the + compile-time zetas table (all canonical). + + Canonicity is also the natural invariant of the bit-side lifts: + `feOfZMod : ZMod 3329 → FieldElement` produces canonical FEs by + construction (the lift packs `z.val < 3329 < 2^16` into a U16). + + Canonicity preservation lemmas (`Canonical_add_pure`, + `Canonical_mul_pure`, etc.) are deferred to .2 when the + poly-wrapper side lemmas need them; the U16-cast simp residue + requires non-trivial unfolding that's better tackled with + downstream context. -/ + +/-- A `parameters.FieldElement` is canonical iff its underlying U16 + holds a value strictly below the field modulus `q = 3329`. This + is the invariant the Rust source maintains via + `#[refine(val < FIELD_MODULUS)]` on `FieldElement::new` — the + Aeneas extraction drops the refinement, so we carry canonicity + as an explicit predicate. + + The bit-side lift `feOfZMod` produces canonical FEs by + construction; downstream wrappers (`butterfly`, etc.) take + canonical inputs and produce canonical outputs. -/ +def Canonical (fe : parameters.FieldElement) : Prop := + fe.val.val < parameters.FIELD_MODULUS.val + +private theorem uscalar_rem_ok_U32 (z m : Std.U32) (hm : m.val ≠ 0) : + ∃ w : Std.U32, (z % m : Result Std.U32) = .ok w ∧ w.val = z.val % m.val := by + have heq : (z % m : Result Std.U32) = Std.UScalar.rem z m := rfl + unfold Std.UScalar.rem at heq + simp [hm] at heq + refine ⟨_, heq, ?_⟩ + show (BitVec.umod z.bv m.bv).toNat = z.val % m.val + unfold BitVec.umod + simp only [BitVec.toNat_ofNatLT] + rfl + +/-- U16 variant of `uscalar_rem_ok_U32` — used by `neg_eq_ok` whose + `% q` step is at U16 width (no widening). -/ +private theorem uscalar_rem_ok_U16 (z m : Std.U16) (hm : m.val ≠ 0) : + ∃ w : Std.U16, (z % m : Result Std.U16) = .ok w ∧ w.val = z.val % m.val := by + have heq : (z % m : Result Std.U16) = Std.UScalar.rem z m := rfl + unfold Std.UScalar.rem at heq + simp [hm] at heq + refine ⟨_, heq, ?_⟩ + show (BitVec.umod z.bv m.bv).toNat = z.val % m.val + unfold BitVec.umod + simp only [BitVec.toNat_ofNatLT] + rfl + +/-- Pure-projection side lemma for `parameters.FieldElement.add` — + unconditional over ALL `FieldElement` inputs (no canonicity + precondition). The U32 widening bounds the sum strictly below + `2^32` so every intermediate step is `.ok`. -/ +theorem FieldElement.add_eq_ok (a b : parameters.FieldElement) : + parameters.FieldElement.add a b = .ok (FieldElement.add_pure a b) := by + unfold FieldElement.add_pure + suffices h : ∃ r, parameters.FieldElement.add a b = .ok r by + obtain ⟨r, hr⟩ := h; rw [hr] + unfold parameters.FieldElement.add + simp only [lift, bind_tc_ok] + have hA := a.val.hBounds; have hB := b.val.hBounds + simp [Std.UScalarTy.numBits] at hA hB + set x : Std.U32 := Std.UScalar.cast .U32 a.val + set y : Std.U32 := Std.UScalar.cast .U32 b.val + have hxval : x.val = a.val.val := Std.U16.cast_U32_val_eq a.val + have hyval : y.val = b.val.val := Std.U16.cast_U32_val_eq b.val + have hae := Std.UScalar.add_equiv x y + cases hxy : (x + y) with + | ok z => + rw [hxy] at hae; simp at hae + obtain ⟨_, _, _⟩ := hae + have hmod_val : + (Std.UScalar.cast .U32 parameters.FIELD_MODULUS).val = 3329 := by + unfold parameters.FIELD_MODULUS; simp + have hmod_ne : + (Std.UScalar.cast .U32 parameters.FIELD_MODULUS).val ≠ 0 := by + rw [hmod_val]; decide + set m : Std.U32 := Std.UScalar.cast .U32 parameters.FIELD_MODULUS + obtain ⟨w, hw_eq, _⟩ := uscalar_rem_ok_U32 z m hmod_ne + simp only [bind_tc_ok, hw_eq] + exact ⟨_, rfl⟩ + | fail e => + rw [hxy] at hae; simp [Std.UScalar.inBounds] at hae + rw [hxval, hyval] at hae; omega + | div => rw [hxy] at hae; exact hae.elim + +/-- Pure-projection side lemma for `parameters.FieldElement.mul` — + unconditional over ALL `FieldElement` inputs. The U32 widening + bounds the product strictly below `2^32`. -/ +theorem FieldElement.mul_eq_ok (a b : parameters.FieldElement) : + parameters.FieldElement.mul a b = .ok (FieldElement.mul_pure a b) := by + unfold FieldElement.mul_pure + suffices h : ∃ r, parameters.FieldElement.mul a b = .ok r by + obtain ⟨r, hr⟩ := h; rw [hr] + unfold parameters.FieldElement.mul + simp only [lift, bind_tc_ok] + have hA := a.val.hBounds; have hB := b.val.hBounds + simp [Std.UScalarTy.numBits] at hA hB + set x : Std.U32 := Std.UScalar.cast .U32 a.val + set y : Std.U32 := Std.UScalar.cast .U32 b.val + have hxval : x.val = a.val.val := Std.U16.cast_U32_val_eq a.val + have hyval : y.val = b.val.val := Std.U16.cast_U32_val_eq b.val + have hae := Std.UScalar.mul_equiv x y + have heqmul : (x * y : Result Std.U32) = Std.UScalar.mul x y := rfl + cases hxy : (x * y : Result Std.U32) with + | ok z => + rw [heqmul] at hxy; rw [hxy] at hae; simp at hae + obtain ⟨_, _, _⟩ := hae + have hmod_val : + (Std.UScalar.cast .U32 parameters.FIELD_MODULUS).val = 3329 := by + unfold parameters.FIELD_MODULUS; simp + have hmod_ne : + (Std.UScalar.cast .U32 parameters.FIELD_MODULUS).val ≠ 0 := by + rw [hmod_val]; decide + set m : Std.U32 := Std.UScalar.cast .U32 parameters.FIELD_MODULUS + obtain ⟨w, hw_eq, _⟩ := uscalar_rem_ok_U32 z m hmod_ne + simp only [bind_tc_ok, hw_eq] + exact ⟨_, rfl⟩ + | fail e => + rw [heqmul] at hxy; rw [hxy] at hae + simp only [Std.UScalar.max, Std.UScalarTy.numBits] at hae + rw [hxval, hyval] at hae + have : a.val.val * b.val.val < 2^32 := by + have h1 : a.val.val * b.val.val ≤ (2^16 - 1) * (2^16 - 1) := by + apply Nat.mul_le_mul <;> omega + have heq : (2^16 - 1) * (2^16 - 1) = 2^32 - 2*2^16 + 1 := by decide + omega + omega + | div => rw [heqmul] at hxy; rw [hxy] at hae; exact hae.elim + +/-- Pure-projection side lemma for `parameters.FieldElement.sub` — + valid only for CANONICAL inputs. The U32 sum `a.val + q` reaches + `q + (q-1) < 2·q` without overflow (q = 3329), and the subsequent + `% q` is well-defined since q ≠ 0. + + For non-canonical inputs the impl can `.fail` — see the doc block + above and `parameters.FieldElement.sub ⟨0⟩ ⟨65535⟩` counterexample. -/ +theorem FieldElement.sub_eq_ok (a b : parameters.FieldElement) + (ha : Canonical a) (hb : Canonical b) : + parameters.FieldElement.sub a b = .ok (FieldElement.sub_pure a b) := by + unfold Canonical at ha hb + unfold parameters.FIELD_MODULUS at ha hb + simp at ha hb + unfold FieldElement.sub_pure + suffices h : ∃ r, parameters.FieldElement.sub a b = .ok r by + obtain ⟨r, hr⟩ := h; rw [hr] + unfold parameters.FieldElement.sub + simp only [lift, bind_tc_ok] + have hA := a.val.hBounds; have hB := b.val.hBounds + simp [Std.UScalarTy.numBits] at hA hB + set x : Std.U32 := Std.UScalar.cast .U32 a.val + set y : Std.U32 := Std.UScalar.cast .U32 b.val + set q : Std.U32 := Std.UScalar.cast .U32 parameters.FIELD_MODULUS + have hxval : x.val = a.val.val := Std.U16.cast_U32_val_eq a.val + have hyval : y.val = b.val.val := Std.U16.cast_U32_val_eq b.val + have hqval : q.val = 3329 := by + show (Std.UScalar.cast .U32 parameters.FIELD_MODULUS).val = 3329 + unfold parameters.FIELD_MODULUS; simp + have hae := Std.UScalar.add_equiv x q + cases hxq : (x + q : Result Std.U32) with + | ok s => + rw [hxq] at hae; simp at hae + obtain ⟨_, hsval, _⟩ := hae + simp only [bind_tc_ok] + have hae2 := Std.UScalar.sub_equiv s y + cases hsy : (s - y : Result Std.U32) with + | ok u => + rw [hsy] at hae2; simp at hae2 + simp only [bind_tc_ok] + have hq_ne : q.val ≠ 0 := by rw [hqval]; decide + obtain ⟨w, hw_eq, _⟩ := uscalar_rem_ok_U32 u q hq_ne + rw [hw_eq]; simp only [bind_tc_ok] + exact ⟨_, rfl⟩ + | fail e => + rw [hsy] at hae2; simp [] at hae2 + rw [hsval, hxval, hqval, hyval] at hae2 + omega + | div => rw [hsy] at hae2; exact hae2.elim + | fail e => + rw [hxq] at hae; simp [Std.UScalar.inBounds] at hae + rw [hxval, hqval] at hae + omega + | div => rw [hxq] at hae; exact hae.elim + +/-- Pure-projection side lemma for `parameters.FieldElement.neg` — + valid only for CANONICAL input. The impl computes `q - self.val` + in U16 (NO widening), which only avoids underflow when + `self.val ≤ q`. The subsequent `% q` is well-defined since q ≠ 0. + + For non-canonical inputs the impl can `.fail` — see the doc block + above and `parameters.FieldElement.neg ⟨65535⟩` counterexample. -/ +theorem FieldElement.neg_eq_ok (a : parameters.FieldElement) + (ha : Canonical a) : + parameters.FieldElement.neg a = .ok (FieldElement.neg_pure a) := by + unfold Canonical at ha + unfold parameters.FIELD_MODULUS at ha + simp at ha + unfold FieldElement.neg_pure + suffices h : ∃ r, parameters.FieldElement.neg a = .ok r by + obtain ⟨r, hr⟩ := h; rw [hr] + unfold parameters.FieldElement.neg + have hA := a.val.hBounds + simp [Std.UScalarTy.numBits] at hA + have hqval : (parameters.FIELD_MODULUS : Std.U16).val = 3329 := by + unfold parameters.FIELD_MODULUS; simp + have hae := Std.UScalar.sub_equiv (parameters.FIELD_MODULUS : Std.U16) a.val + cases hqa : ((parameters.FIELD_MODULUS : Std.U16) - a.val : Result Std.U16) with + | ok i => + rw [hqa] at hae; simp at hae + obtain ⟨_, _, _⟩ := hae + simp only [bind_tc_ok] + have hq_ne : (parameters.FIELD_MODULUS : Std.U16).val ≠ 0 := by + rw [hqval]; decide + obtain ⟨w, hw_eq, _⟩ := uscalar_rem_ok_U16 i parameters.FIELD_MODULUS hq_ne + rw [hw_eq]; simp only [bind_tc_ok] + exact ⟨_, rfl⟩ + | fail e => + rw [hqa] at hae; simp [] at hae + rw [hqval] at hae + omega + | div => rw [hqa] at hae; exact hae.elim + +/-! ## Canonicity preservation lemmas. + + Each `Canonical__pure` shows the `_pure` projection produces a + canonical FE. Proof shape (mirrors the corresponding `_eq_ok` side + lemma): chain through the do-block via `_equiv`, extract `w` + from the final `% q` via `uscalar_rem_ok_U{32,16}`, then bound + `(UScalar.cast .U16 w).val = w.val < 3329` via + `Std.UScalar.cast_val_mod_pow_of_inBounds_eq`. -/ + +/-- Canonicity preservation for `FieldElement.add_pure`. Unconditional: + the result is the modular reduction `(a.val + b.val) % q < q`. + + Proof strategy: rewrite `parameters.FieldElement.add` to `.ok` + via `add_eq_ok`, unfold the do-block, case on the U32 `+`/`%` + branches (the `.ok` branch fires by assumption; `.fail`/`.div` + discharge via `add_equiv` + `omega`/`hae.elim`), and bound the + final U16 cast through `cast_val_mod_pow_of_inBounds_eq`. -/ +theorem Canonical_add_pure (a b : parameters.FieldElement) : + Canonical (FieldElement.add_pure a b) := by + have hadd : parameters.FieldElement.add a b = .ok (FieldElement.add_pure a b) := + FieldElement.add_eq_ok a b + unfold parameters.FieldElement.add at hadd + simp only [lift, bind_tc_ok] at hadd + have hA := a.val.hBounds; have hB := b.val.hBounds + simp [Std.UScalarTy.numBits] at hA hB + set x : Std.U32 := Std.UScalar.cast .U32 a.val + set y : Std.U32 := Std.UScalar.cast .U32 b.val + have hxval : x.val = a.val.val := Std.U16.cast_U32_val_eq a.val + have hyval : y.val = b.val.val := Std.U16.cast_U32_val_eq b.val + have hae := Std.UScalar.add_equiv x y + cases hxy : (x + y) with + | ok z => + rw [hxy] at hae hadd; simp at hae + obtain ⟨_, _, _⟩ := hae + simp only [bind_tc_ok] at hadd + have hmod_val : + (Std.UScalar.cast .U32 parameters.FIELD_MODULUS).val = 3329 := by + unfold parameters.FIELD_MODULUS; simp + have hmod_ne : + (Std.UScalar.cast .U32 parameters.FIELD_MODULUS).val ≠ 0 := by + rw [hmod_val]; decide + set m : Std.U32 := Std.UScalar.cast .U32 parameters.FIELD_MODULUS + obtain ⟨w, hw_eq, hwval⟩ := uscalar_rem_ok_U32 z m hmod_ne + rw [hw_eq] at hadd; simp only [bind_tc_ok] at hadd + unfold parameters.FieldElement.new at hadd + simp at hadd + have hwbnd : w.val < 3329 := by + rw [hwval, hmod_val]; exact Nat.mod_lt _ (by decide) + have hwcast : (Std.UScalar.cast .U16 w).val = w.val := by + apply Std.UScalar.cast_val_mod_pow_of_inBounds_eq + simp [Std.UScalarTy.numBits]; omega + unfold Canonical + rw [← hadd] + show (Std.UScalar.cast .U16 w).val < parameters.FIELD_MODULUS.val + unfold parameters.FIELD_MODULUS + simp + rw [hwcast]; exact hwbnd + | fail e => + rw [hxy] at hae; simp [Std.UScalar.inBounds] at hae + rw [hxval, hyval] at hae; omega + | div => rw [hxy] at hae; exact hae.elim + +/-- Canonicity preservation for `FieldElement.mul_pure`. Unconditional. + + Proof strategy: same shape as `Canonical_add_pure` but using + `mul_equiv` and the `mul`-impl's U32 product `% q`. The `.fail` + branch's panic-impossibility follows from + `a.val * b.val ≤ (2^16-1)^2 < 2^32`. -/ +theorem Canonical_mul_pure (a b : parameters.FieldElement) : + Canonical (FieldElement.mul_pure a b) := by + have hmul : parameters.FieldElement.mul a b = .ok (FieldElement.mul_pure a b) := + FieldElement.mul_eq_ok a b + unfold parameters.FieldElement.mul at hmul + simp only [lift, bind_tc_ok] at hmul + have hA := a.val.hBounds; have hB := b.val.hBounds + simp [Std.UScalarTy.numBits] at hA hB + set x : Std.U32 := Std.UScalar.cast .U32 a.val + set y : Std.U32 := Std.UScalar.cast .U32 b.val + have hxval : x.val = a.val.val := Std.U16.cast_U32_val_eq a.val + have hyval : y.val = b.val.val := Std.U16.cast_U32_val_eq b.val + have hae := Std.UScalar.mul_equiv x y + have heqmul : (x * y : Result Std.U32) = Std.UScalar.mul x y := rfl + cases hxy : (x * y : Result Std.U32) with + | ok z => + rw [hxy] at hmul + rw [heqmul] at hxy + rw [hxy] at hae; simp at hae + obtain ⟨_, _, _⟩ := hae + simp only [bind_tc_ok] at hmul + have hmod_val : + (Std.UScalar.cast .U32 parameters.FIELD_MODULUS).val = 3329 := by + unfold parameters.FIELD_MODULUS; simp + have hmod_ne : + (Std.UScalar.cast .U32 parameters.FIELD_MODULUS).val ≠ 0 := by + rw [hmod_val]; decide + set m : Std.U32 := Std.UScalar.cast .U32 parameters.FIELD_MODULUS + obtain ⟨w, hw_eq, hwval⟩ := uscalar_rem_ok_U32 z m hmod_ne + rw [hw_eq] at hmul; simp only [bind_tc_ok] at hmul + unfold parameters.FieldElement.new at hmul + simp at hmul + have hwbnd : w.val < 3329 := by + rw [hwval, hmod_val]; exact Nat.mod_lt _ (by decide) + have hwcast : (Std.UScalar.cast .U16 w).val = w.val := by + apply Std.UScalar.cast_val_mod_pow_of_inBounds_eq + simp [Std.UScalarTy.numBits]; omega + unfold Canonical + rw [← hmul] + show (Std.UScalar.cast .U16 w).val < parameters.FIELD_MODULUS.val + unfold parameters.FIELD_MODULUS + simp + rw [hwcast]; exact hwbnd + | fail e => + rw [heqmul] at hxy; rw [hxy] at hae + simp only [Std.UScalar.max, Std.UScalarTy.numBits] at hae + rw [hxval, hyval] at hae + have : a.val.val * b.val.val < 2^32 := by + have h1 : a.val.val * b.val.val ≤ (2^16 - 1) * (2^16 - 1) := by + apply Nat.mul_le_mul <;> omega + have heq : (2^16 - 1) * (2^16 - 1) = 2^32 - 2*2^16 + 1 := by decide + omega + omega + | div => rw [heqmul] at hxy; rw [hxy] at hae; exact hae.elim + +/-- Canonicity preservation for `FieldElement.sub_pure`. Requires both + inputs canonical (the impl panics on non-canonical inputs — see + `sub_eq_ok`'s precondition). + + Proof strategy: two-step do-block — first `x + q` (panic-impossible + by `hb` canonical, since `x.val + q.val ≤ q.val + (q.val-1) < 2^32`), + then `s - y` (panic-impossible by `hb` canonical, since + `s.val = x.val + q.val ≥ q.val > y.val`), then `% q` and U16 cast + as in `Canonical_add_pure`. -/ +theorem Canonical_sub_pure (a b : parameters.FieldElement) + (ha : Canonical a) (hb : Canonical b) : + Canonical (FieldElement.sub_pure a b) := by + have hsub : parameters.FieldElement.sub a b = .ok (FieldElement.sub_pure a b) := + FieldElement.sub_eq_ok a b ha hb + unfold Canonical at ha hb + unfold parameters.FIELD_MODULUS at ha hb + simp at ha hb + unfold parameters.FieldElement.sub at hsub + simp only [lift, bind_tc_ok] at hsub + have hA := a.val.hBounds; have hB := b.val.hBounds + simp [Std.UScalarTy.numBits] at hA hB + set x : Std.U32 := Std.UScalar.cast .U32 a.val + set y : Std.U32 := Std.UScalar.cast .U32 b.val + set q : Std.U32 := Std.UScalar.cast .U32 parameters.FIELD_MODULUS + have hxval : x.val = a.val.val := Std.U16.cast_U32_val_eq a.val + have hyval : y.val = b.val.val := Std.U16.cast_U32_val_eq b.val + have hqval : q.val = 3329 := by + show (Std.UScalar.cast .U32 parameters.FIELD_MODULUS).val = 3329 + unfold parameters.FIELD_MODULUS; simp + have hae := Std.UScalar.add_equiv x q + cases hxq : (x + q : Result Std.U32) with + | ok s => + rw [hxq] at hae hsub; simp at hae + obtain ⟨_, hsval, _⟩ := hae + simp only [bind_tc_ok] at hsub + have hae2 := Std.UScalar.sub_equiv s y + cases hsy : (s - y : Result Std.U32) with + | ok u => + rw [hsy] at hae2 hsub; simp at hae2 + obtain ⟨_, _, _⟩ := hae2 + simp only [bind_tc_ok] at hsub + have hq_ne : q.val ≠ 0 := by rw [hqval]; decide + obtain ⟨w, hw_eq, hwval⟩ := uscalar_rem_ok_U32 u q hq_ne + rw [hw_eq] at hsub; simp only [bind_tc_ok] at hsub + unfold parameters.FieldElement.new at hsub + simp at hsub + have hwbnd : w.val < 3329 := by + rw [hwval, hqval]; exact Nat.mod_lt _ (by decide) + have hwcast : (Std.UScalar.cast .U16 w).val = w.val := by + apply Std.UScalar.cast_val_mod_pow_of_inBounds_eq + simp [Std.UScalarTy.numBits]; omega + unfold Canonical + rw [← hsub] + show (Std.UScalar.cast .U16 w).val < parameters.FIELD_MODULUS.val + unfold parameters.FIELD_MODULUS + simp + rw [hwcast]; exact hwbnd + | fail e => + rw [hsy] at hae2; simp at hae2 + rw [hsval, hxval, hqval, hyval] at hae2 + omega + | div => rw [hsy] at hae2; exact hae2.elim + | fail e => + rw [hxq] at hae; simp [Std.UScalar.inBounds] at hae + rw [hxval, hqval] at hae + omega + | div => rw [hxq] at hae; exact hae.elim + +/-- Canonicity preservation for `FieldElement.neg_pure`. Requires + the input canonical. + + Proof strategy: the impl does NOT widen to U32 — it computes + `q - self.val` directly in U16 (panic-impossible by `ha` + canonical, since `q.val > a.val.val`), then `% q` at U16 width + via `uscalar_rem_ok_U16`. The output IS the U16 (no narrowing + cast), so the canonical bound is direct from `hwbnd`. -/ +theorem Canonical_neg_pure (a : parameters.FieldElement) + (ha : Canonical a) : + Canonical (FieldElement.neg_pure a) := by + have hneg : parameters.FieldElement.neg a = .ok (FieldElement.neg_pure a) := + FieldElement.neg_eq_ok a ha + unfold Canonical at ha + unfold parameters.FIELD_MODULUS at ha + simp at ha + unfold parameters.FieldElement.neg at hneg + have hA := a.val.hBounds + simp [Std.UScalarTy.numBits] at hA + have hqval : (parameters.FIELD_MODULUS : Std.U16).val = 3329 := by + unfold parameters.FIELD_MODULUS; simp + have hae := Std.UScalar.sub_equiv (parameters.FIELD_MODULUS : Std.U16) a.val + cases hqa : ((parameters.FIELD_MODULUS : Std.U16) - a.val : Result Std.U16) with + | ok i => + rw [hqa] at hae hneg; simp at hae + obtain ⟨_, _, _⟩ := hae + simp only [bind_tc_ok] at hneg + have hq_ne : (parameters.FIELD_MODULUS : Std.U16).val ≠ 0 := by + rw [hqval]; decide + obtain ⟨w, hw_eq, hwval⟩ := uscalar_rem_ok_U16 i parameters.FIELD_MODULUS hq_ne + rw [hw_eq] at hneg; simp only [bind_tc_ok] at hneg + unfold parameters.FieldElement.new at hneg + simp at hneg + have hwbnd : w.val < 3329 := by + rw [hwval, hqval]; exact Nat.mod_lt _ (by decide) + unfold Canonical + rw [← hneg] + show w.val < parameters.FIELD_MODULUS.val + unfold parameters.FIELD_MODULUS + simp + exact hwbnd + | fail e => + rw [hqa] at hae; simp at hae + rw [hqval] at hae + omega + | div => rw [hqa] at hae; exact hae.elim + +/-! ## Poly-wrapper pure-projection side lemmas. + + Each `polynomial._eq_ok` pins the hacspec wrapper's `.ok` value to + the projected `polynomial._pure` value, on the corresponding + canonicity precondition (none for `add`, none for `poly_barrett_reduce`, + per-element canonicity for both inputs to `subtract_reduce`). These + are the input lemmas to `libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq` and to downstream M.2 + commute lemmas; + for why "panic-freedom" is the wrong framing. Proof shape: + * Unfold the wrapper through `parameters.createi` to + `core.array.from_fn`. + * Apply `libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq` with `f` the + pointwise pure projection. + * Discharge the per-element closure obligation by unfolding + `call_mut`/`call`, rewriting the leading `Array.index_usize` via + `Std.Array.index_usize_spec`, and matching the rest of the body + against the corresponding scalar `_eq_ok` lemma. + * Conclude by unfolding the `_pure` projection and rewriting through + the just-proved equation, reducing the match on `.ok`. -/ + +/-- Local helper: `Std.Array.index_usize` on a length-`n` array at index + `i < n` returns `.ok v.val[i.val]!`. Mirrors `array_index_usize_ok_eq` + in `Util.PortableVector` but avoids pulling that file's + LoopSpecs/PortableVector chain through `SpecPure`. -/ +private theorem array_index_usize_ok + {α : Type u} {n : Std.Usize} [Inhabited α] + (v : Std.Array α n) (i : Std.Usize) (h_bd : i.val < v.length) : + Aeneas.Std.Array.index_usize v i = .ok (v.val[i.val]!) := by + have hT := Aeneas.Std.Array.index_usize_spec v i h_bd + have h_ex := Aeneas.Std.WP.spec_imp_exists hT + obtain ⟨v', hveq, hPv'⟩ := h_ex + rw [hveq, hPv', getElem!_pos] + +/-- Pure-projection side lemma for `polynomial.add_to_ring_element` — + unconditional over ALL inputs. + + Proof: unfold wrapper to `core.array.from_fn`; apply + `from_fn_pure_eq` with the pointwise `FieldElement.add_pure`; the + pointwise closure body is the inlined `FieldElement.add` body + preceded by two `index_usize` ops — closed by `FieldElement.add_eq_ok`. -/ +theorem polynomial.add_to_ring_element_eq_ok + (lhs rhs : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + hacspec_ml_kem.polynomial.add_to_ring_element lhs rhs + = .ok (polynomial.add_to_ring_element_pure lhs rhs) := by + -- Step 1: pointwise function f. + set f : Nat → parameters.FieldElement := + fun k => FieldElement.add_pure (lhs.val[k]!) (rhs.val[k]!) with hf_def + -- Step 2: pointwise closure obligation. + have hpure : ∀ k : Nat, k < (256#usize : Std.Usize).val → + (hacspec_ml_kem.polynomial.add_to_ring_element.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement + : CoreModels.core.ops.function.Fn _ _ _).FnMutInst.call_mut + (lhs, rhs) ⟨BitVec.ofNat _ k⟩ + = .ok (f k, (lhs, rhs)) := by + intro k hk + have hk' : k < 256 := hk + show polynomial.add_to_ring_element.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + (lhs, rhs) ⟨BitVec.ofNat _ k⟩ = .ok (f k, (lhs, rhs)) + unfold polynomial.add_to_ring_element.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + unfold polynomial.add_to_ring_element.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement.call + -- Index-usize obligations. + have hk_us : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := by + show (BitVec.ofNat _ k).toNat = k + apply Nat.mod_eq_of_lt + have : k < 2^System.Platform.numBits := by + have hbits : 2^16 ≤ 2^System.Platform.numBits := + Nat.pow_le_pow_right (by decide) (by + cases System.Platform.numBits_eq with + | inl h => rw [h]; decide + | inr h => rw [h]; decide) + omega + exact this + have hlhs_len : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val < lhs.length := by + rw [hk_us]; show k < lhs.val.length + rw [lhs.property]; exact hk + have hrhs_len : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val < rhs.length := by + rw [hk_us]; show k < rhs.val.length + rw [rhs.property]; exact hk + have h_lhs_idx : + Std.Array.index_usize lhs (⟨BitVec.ofNat _ k⟩ : Std.Usize) + = .ok (lhs.val[k]!) := by + rw [array_index_usize_ok lhs _ hlhs_len, hk_us] + have h_rhs_idx : + Std.Array.index_usize rhs (⟨BitVec.ofNat _ k⟩ : Std.Usize) + = .ok (rhs.val[k]!) := by + rw [array_index_usize_ok rhs _ hrhs_len, hk_us] + -- The remainder of the closure body is exactly the body of + -- `parameters.FieldElement.add lhs[k]! rhs[k]!`. + have h_add := + FieldElement.add_eq_ok (lhs.val[k]!) (rhs.val[k]!) + -- The outer wrapper has shape `do let fe ← (let (a, a1) := (lhs, rhs); inner_call); ok (fe, ...)`. + -- Reduce the destructuring `let (a, a1) := (lhs, rhs)` to `a := lhs, a1 := rhs`. + change (do + let fe ← (do + let fe ← Std.Array.index_usize lhs ⟨BitVec.ofNat _ k⟩ + let i ← lift (Std.UScalar.cast .U32 fe.val) + let fe1 ← Std.Array.index_usize rhs ⟨BitVec.ofNat _ k⟩ + let i1 ← lift (Std.UScalar.cast .U32 fe1.val) + let i2 ← i + i1 + let i3 ← lift (Std.UScalar.cast .U32 parameters.FIELD_MODULUS) + let i4 ← i2 % i3 + let i5 ← lift (Std.UScalar.cast .U16 i4) + parameters.FieldElement.new i5) + Result.ok (fe, lhs, rhs)) = Result.ok (f k, lhs, rhs) + -- Now rewrite the two `index_usize`s and use `add_eq_ok` to collapse the rest. + rw [h_lhs_idx]; simp only [bind_tc_ok] + rw [h_rhs_idx]; simp only [bind_tc_ok] + unfold parameters.FieldElement.add at h_add + rw [h_add] + simp only [bind_tc_ok, hf_def] + -- Step 3: chain through `from_fn_pure_eq` to get the wrapper equation. + have h_from_fn := + libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq + (T := parameters.FieldElement) + (F := polynomial.add_to_ring_element.closure) + (N := 256#usize) + (inst := polynomial.add_to_ring_element.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement) + (c := (lhs, rhs)) + (f := f) + hpure + have h_wrap : hacspec_ml_kem.polynomial.add_to_ring_element lhs rhs + = .ok ⟨(List.range (256#usize : Std.Usize).val).map f, + by simp [List.length_map, List.length_range]⟩ := by + unfold hacspec_ml_kem.polynomial.add_to_ring_element + unfold hacspec_ml_kem.parameters.createi + exact h_from_fn + -- Step 4: reduce the `_pure` projection via the wrapper equation. + rw [h_wrap] + unfold polynomial.add_to_ring_element_pure + rw [h_wrap] + +/-- The pure-rem of a U16 by `parameters.FIELD_MODULUS` (= 3329 ≠ 0). + A noncomputable wrapper extracting the `.ok` witness from + `uscalar_rem_ok_U16`. Used as the pointwise function in + `poly_barrett_reduce_eq_ok`. -/ +private noncomputable def rem_q_U16 (z : Std.U16) : Std.U16 := + have hq_ne : (parameters.FIELD_MODULUS : Std.U16).val ≠ 0 := by + unfold parameters.FIELD_MODULUS; decide + Classical.choose (uscalar_rem_ok_U16 z parameters.FIELD_MODULUS hq_ne) + +private theorem rem_q_U16_eq (z : Std.U16) : + (z % parameters.FIELD_MODULUS : Result Std.U16) = .ok (rem_q_U16 z) := by + have hq_ne : (parameters.FIELD_MODULUS : Std.U16).val ≠ 0 := by + unfold parameters.FIELD_MODULUS; decide + unfold rem_q_U16 + exact (Classical.choose_spec + (uscalar_rem_ok_U16 z parameters.FIELD_MODULUS hq_ne)).1 + +/-- Pure-projection side lemma for `polynomial.poly_barrett_reduce` — + unconditional over ALL inputs. + + Proof: unfold wrapper to `core.array.from_fn`; apply + `from_fn_pure_eq` with `f k := ⟨rem_q_U16 (p.val[k]!).val⟩`. The + pointwise closure body is `index_usize p k; (fe.val % q); new` — the + `%` step is at U16 width (no widening) and `parameters.FIELD_MODULUS + ≠ 0`, so `uscalar_rem_ok_U16` discharges it. -/ +theorem polynomial.poly_barrett_reduce_eq_ok + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) : + hacspec_ml_kem.polynomial.poly_barrett_reduce p + = .ok (polynomial.poly_barrett_reduce_pure p) := by + -- Step 1: pointwise function f. + set f : Nat → parameters.FieldElement := + fun k => { val := rem_q_U16 (p.val[k]!).val } + with hf_def + -- Step 2: pointwise closure obligation. + have hpure : ∀ k : Nat, k < (256#usize : Std.Usize).val → + (hacspec_ml_kem.polynomial.poly_barrett_reduce.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement + : CoreModels.core.ops.function.Fn _ _ _).FnMutInst.call_mut + p ⟨BitVec.ofNat _ k⟩ + = .ok (f k, p) := by + intro k hk + show polynomial.poly_barrett_reduce.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + p ⟨BitVec.ofNat _ k⟩ = .ok (f k, p) + unfold polynomial.poly_barrett_reduce.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + unfold polynomial.poly_barrett_reduce.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement.call + -- Index-usize obligation. + have hk' : k < 256 := hk + have hk_us : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := by + show (BitVec.ofNat _ k).toNat = k + apply Nat.mod_eq_of_lt + have : k < 2^System.Platform.numBits := by + have hbits : 2^16 ≤ 2^System.Platform.numBits := + Nat.pow_le_pow_right (by decide) (by + cases System.Platform.numBits_eq with + | inl h => rw [h]; decide + | inr h => rw [h]; decide) + omega + exact this + have hp_len : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val < p.length := by + rw [hk_us]; show k < p.val.length + rw [p.property]; exact hk + have h_p_idx : + Std.Array.index_usize p (⟨BitVec.ofNat _ k⟩ : Std.Usize) + = .ok (p.val[k]!) := by + rw [array_index_usize_ok p _ hp_len, hk_us] + -- Close the closure body: index, then rem, then new (returned inline). + change (do + let fe ← (do + let fe ← Std.Array.index_usize p ⟨BitVec.ofNat _ k⟩ + let i ← (fe.val % parameters.FIELD_MODULUS : Result Std.U16) + parameters.FieldElement.new i) + Result.ok (fe, p)) = Result.ok (f k, p) + rw [h_p_idx]; simp only [bind_tc_ok] + rw [rem_q_U16_eq]; simp only [bind_tc_ok] + unfold parameters.FieldElement.new + simp only [bind_tc_ok, hf_def] + -- Step 3: apply from_fn_pure_eq. + have h_from_fn := + libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq + (T := parameters.FieldElement) + (F := polynomial.poly_barrett_reduce.closure) + (N := 256#usize) + (inst := polynomial.poly_barrett_reduce.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement) + (c := p) + (f := f) + hpure + have h_wrap : hacspec_ml_kem.polynomial.poly_barrett_reduce p + = .ok ⟨(List.range (256#usize : Std.Usize).val).map f, + by simp [List.length_map, List.length_range]⟩ := by + unfold hacspec_ml_kem.polynomial.poly_barrett_reduce + unfold hacspec_ml_kem.parameters.createi + exact h_from_fn + -- Step 4: reduce the `_pure` projection via h_wrap. + rw [h_wrap] + unfold polynomial.poly_barrett_reduce_pure + rw [h_wrap] + +/-- Identity-on-canonical bridge for `polynomial.poly_barrett_reduce_pure`. + + When every lane of `p` is canonical (`p.val[k]!.val.val < q`), the pure + projection is the identity: `poly_barrett_reduce_pure p = p`. Used by + L6.1 FC close where the input is `lift_poly self` (canonical by + `lift_fe`'s `feOfZMod` codomain). -/ +theorem polynomial.poly_barrett_reduce_pure_id_of_canonical + (p : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (hcan : ∀ k : Nat, k < 256 → Canonical (p.val[k]!)) : + polynomial.poly_barrett_reduce_pure p = p := by + -- Re-derive `h_wrap` with f := fun k => p.val[k]! (canonical identity). + set f : Nat → parameters.FieldElement := fun k => p.val[k]! with hf_def + have hpure : ∀ k : Nat, k < (256#usize : Std.Usize).val → + (hacspec_ml_kem.polynomial.poly_barrett_reduce.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement + : CoreModels.core.ops.function.Fn _ _ _).FnMutInst.call_mut + p ⟨BitVec.ofNat _ k⟩ + = .ok (f k, p) := by + intro k hk + show polynomial.poly_barrett_reduce.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + p ⟨BitVec.ofNat _ k⟩ = .ok (f k, p) + unfold polynomial.poly_barrett_reduce.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + unfold polynomial.poly_barrett_reduce.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement.call + have hk' : k < 256 := hk + have hk_us : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := by + show (BitVec.ofNat _ k).toNat = k + apply Nat.mod_eq_of_lt + have : k < 2^System.Platform.numBits := by + have hbits : 2^16 ≤ 2^System.Platform.numBits := + Nat.pow_le_pow_right (by decide) (by + cases System.Platform.numBits_eq with + | inl h => rw [h]; decide + | inr h => rw [h]; decide) + omega + exact this + have hp_len : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val < p.length := by + rw [hk_us]; show k < p.val.length + rw [p.property]; exact hk + have h_p_idx : + Std.Array.index_usize p (⟨BitVec.ofNat _ k⟩ : Std.Usize) + = .ok (p.val[k]!) := by + rw [array_index_usize_ok p _ hp_len, hk_us] + change (do + let fe ← (do + let fe ← Std.Array.index_usize p ⟨BitVec.ofNat _ k⟩ + let i ← (fe.val % parameters.FIELD_MODULUS : Result Std.U16) + parameters.FieldElement.new i) + Result.ok (fe, p)) = Result.ok (f k, p) + rw [h_p_idx]; simp only [bind_tc_ok] + rw [rem_q_U16_eq]; simp only [bind_tc_ok] + unfold parameters.FieldElement.new + simp only [bind_tc_ok] + -- Goal: .ok ({ val := rem_q_U16 (p.val[k]!).val }, p) = .ok (f k, p). + -- Build f k = p.val[k]! identity using canonicity. + have hcank : Canonical (p.val[k]!) := hcan k hk' + unfold Canonical at hcank + have hq_val : parameters.FIELD_MODULUS.val = 3329 := by + unfold parameters.FIELD_MODULUS; decide + have hcank_int : (p.val[k]!).val.val < 3329 := by + have : (p.val[k]!).val.val < parameters.FIELD_MODULUS.val := hcank + rw [hq_val] at this; exact this + -- rem_q_U16 (p.val[k]!).val .val = (p.val[k]!).val.val % 3329 = (p.val[k]!).val.val. + have hq_ne : (parameters.FIELD_MODULUS : Std.U16).val ≠ 0 := by + unfold parameters.FIELD_MODULUS; decide + have h_rem_val : (rem_q_U16 (p.val[k]!).val).val = (p.val[k]!).val.val % 3329 := by + have ⟨w, hw_eq, hw_val⟩ := + uscalar_rem_ok_U16 (p.val[k]!).val parameters.FIELD_MODULUS hq_ne + have h_rem_eq := rem_q_U16_eq (p.val[k]!).val + rw [hw_eq] at h_rem_eq + have h_w_eq_rem : w = rem_q_U16 (p.val[k]!).val := Result.ok.inj h_rem_eq + rw [← h_w_eq_rem, hw_val, hq_val] + have h_rem_val_eq : (rem_q_U16 (p.val[k]!).val).val = (p.val[k]!).val.val := + h_rem_val.trans (Nat.mod_eq_of_lt hcank_int) + -- The two U16s `rem_q_U16 (p.val[k]!).val` and `(p.val[k]!).val` have equal .val. + -- Use Std.U16's @[ext] from .val equality, which goes via .bv equality. + have h_u16_eq : rem_q_U16 (p.val[k]!).val = (p.val[k]!).val := by + apply Std.U16.bv_eq_imp_eq + -- .bv equality reduces to .val (= .bv.toNat) equality via BitVec.ext + width. + show (rem_q_U16 (p.val[k]!).val).bv = ((p.val[k]!).val).bv + apply BitVec.eq_of_toNat_eq + show (rem_q_U16 (p.val[k]!).val).val = ((p.val[k]!).val).val + exact h_rem_val_eq + -- Plug in: ⟨rem_q_U16 (p.val[k]!).val⟩ = ⟨(p.val[k]!).val⟩ = p.val[k]! = f k. + have h_fe_eq : ({ val := rem_q_U16 (p.val[k]!).val } : parameters.FieldElement) = f k := by + rw [h_u16_eq, hf_def] + rw [h_fe_eq] + -- Apply from_fn_pure_eq with this f. + have h_from_fn := + libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq + (T := parameters.FieldElement) + (F := polynomial.poly_barrett_reduce.closure) + (N := 256#usize) + (inst := polynomial.poly_barrett_reduce.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement) + (c := p) + (f := f) + hpure + have h_wrap : hacspec_ml_kem.polynomial.poly_barrett_reduce p + = .ok ⟨(List.range (256#usize : Std.Usize).val).map f, + by simp [List.length_map, List.length_range]⟩ := by + unfold hacspec_ml_kem.polynomial.poly_barrett_reduce + unfold hacspec_ml_kem.parameters.createi + exact h_from_fn + -- Unfold _pure via h_wrap. + unfold polynomial.poly_barrett_reduce_pure + rw [h_wrap] + -- Goal: ⟨(range 256).map f, _⟩ = p. Use Subtype.ext + list equality. + apply Subtype.ext + show (List.range 256).map f = p.val + have h_p_len : p.val.length = 256 := p.property + apply List.ext_getElem + · simp [h_p_len] + · intro k hk1 _hk2 + have hk : k < 256 := by + have : k < (List.range 256).length := by simpa using hk1 + simpa using this + rw [List.getElem_map, List.getElem_range] + show f k = p.val[k] + rw [hf_def] + show p.val[k]! = p.val[k] + -- Use getElem!_pos to align p.val[k]! with p.val[k]'_. + exact getElem!_pos p.val k (by rw [h_p_len]; exact hk) + +/-- Pure-projection side lemma for `polynomial.subtract_reduce` — valid + for per-element CANONICAL inputs. The closure body inlines + `parameters.FieldElement.sub`'s + do-block (with a moved `index_usize a1 k` in the middle); after + rewriting the two `index_usize` calls to `.ok` form, the remaining + body IS the body of `parameters.FieldElement.sub (a.val[k]!) (a1.val[k]!)` + so the per-element canonicity preconditions feed directly into + `FieldElement.sub_eq_ok`. -/ +theorem polynomial.subtract_reduce_eq_ok + (a b : Std.Array hacspec_ml_kem.parameters.FieldElement 256#usize) + (ha : ∀ k : Nat, k < 256 → Canonical (a.val[k]!)) + (hb : ∀ k : Nat, k < 256 → Canonical (b.val[k]!)) : + hacspec_ml_kem.polynomial.subtract_reduce a b + = .ok (polynomial.subtract_reduce_pure a b) := by + -- Step 1: pointwise function f. + set f : Nat → parameters.FieldElement := + fun k => FieldElement.sub_pure (a.val[k]!) (b.val[k]!) with hf_def + -- Step 2: pointwise closure obligation. + have hpure : ∀ k : Nat, k < (256#usize : Std.Usize).val → + (hacspec_ml_kem.polynomial.subtract_reduce.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement + : CoreModels.core.ops.function.Fn _ _ _).FnMutInst.call_mut + (a, b) ⟨BitVec.ofNat _ k⟩ + = .ok (f k, (a, b)) := by + intro k hk + have hk' : k < 256 := hk + show polynomial.subtract_reduce.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + (a, b) ⟨BitVec.ofNat _ k⟩ = .ok (f k, (a, b)) + unfold polynomial.subtract_reduce.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement.call_mut + unfold polynomial.subtract_reduce.closure.Insts.CoreOpsFunctionFnTupleUsizeFieldElement.call + -- Index-usize obligations. + have hk_us : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val = k := by + show (BitVec.ofNat _ k).toNat = k + apply Nat.mod_eq_of_lt + have : k < 2^System.Platform.numBits := by + have hbits : 2^16 ≤ 2^System.Platform.numBits := + Nat.pow_le_pow_right (by decide) (by + cases System.Platform.numBits_eq with + | inl h => rw [h]; decide + | inr h => rw [h]; decide) + omega + exact this + have ha_len : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val < a.length := by + rw [hk_us]; show k < a.val.length + rw [a.property]; exact hk + have hb_len : (⟨BitVec.ofNat _ k⟩ : Std.Usize).val < b.length := by + rw [hk_us]; show k < b.val.length + rw [b.property]; exact hk + have h_a_idx : + Std.Array.index_usize a (⟨BitVec.ofNat _ k⟩ : Std.Usize) + = .ok (a.val[k]!) := by + rw [array_index_usize_ok a _ ha_len, hk_us] + have h_b_idx : + Std.Array.index_usize b (⟨BitVec.ofNat _ k⟩ : Std.Usize) + = .ok (b.val[k]!) := by + rw [array_index_usize_ok b _ hb_len, hk_us] + -- The `sub_eq_ok` lemma needs canonicity of both operands. + have h_sub := + FieldElement.sub_eq_ok (a.val[k]!) (b.val[k]!) (ha k hk') (hb k hk') + unfold parameters.FieldElement.sub at h_sub + -- The outer wrapper has shape `do let fe ← (let (x, y) := (a, b); inner_call); ok (fe, ...)`. + -- Reduce the destructuring `let (x, y) := (a, b)`. + change (do + let fe ← (do + let fe ← Std.Array.index_usize a ⟨BitVec.ofNat _ k⟩ + let i ← lift (Std.UScalar.cast .U32 fe.val) + let i1 ← lift (Std.UScalar.cast .U32 parameters.FIELD_MODULUS) + let i2 ← i + i1 + let fe1 ← Std.Array.index_usize b ⟨BitVec.ofNat _ k⟩ + let i3 ← lift (Std.UScalar.cast .U32 fe1.val) + let i4 ← i2 - i3 + let i5 ← lift (Std.UScalar.cast .U32 parameters.FIELD_MODULUS) + let i6 ← i4 % i5 + let i7 ← lift (Std.UScalar.cast .U16 i6) + parameters.FieldElement.new i7) + Result.ok (fe, a, b)) = Result.ok (f k, a, b) + rw [h_a_idx]; simp only [bind_tc_ok] + rw [h_b_idx]; simp only [bind_tc_ok] + -- The inner block is now exactly the body of + -- `parameters.FieldElement.sub (a.val[k]!) (b.val[k]!)` (with `b` re-ordered), + -- which equals `.ok (sub_pure …)` by `h_sub`. + rw [h_sub] + simp only [bind_tc_ok, hf_def] + -- Step 3: apply from_fn_pure_eq. + have h_from_fn := + libcrux_iot_ml_kem.Util.CreateI.from_fn_pure_eq + (T := parameters.FieldElement) + (F := polynomial.subtract_reduce.closure) + (N := 256#usize) + (inst := polynomial.subtract_reduce.closure.Insts.CoreOpsFunctionFnMutTupleUsizeFieldElement) + (c := (a, b)) + (f := f) + hpure + have h_wrap : hacspec_ml_kem.polynomial.subtract_reduce a b + = .ok ⟨(List.range (256#usize : Std.Usize).val).map f, + by simp [List.length_map, List.length_range]⟩ := by + unfold hacspec_ml_kem.polynomial.subtract_reduce + unfold hacspec_ml_kem.parameters.createi + exact h_from_fn + -- Step 4: reduce the `_pure` projection via h_wrap. + rw [h_wrap] + unfold polynomial.subtract_reduce_pure + rw [h_wrap] + +end libcrux_iot_ml_kem.Spec.Pure \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/StateIso.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/StateIso.lean new file mode 100644 index 00000000..edfbe87f --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Spec/StateIso.lean @@ -0,0 +1,305 @@ +/- + # `Spec/StateIso.lean` — M.3 impl ↔ MontPoly round-trip lemmas. + + Companion to `Spec.lean` (M.1) and `Spec/Commute.lean` (M.2). + This file ships the **Decision Point I.6** injectivity pair plus an + `lift_id` round-trip showing `to_spec_poly_plain (canonicalUnlift m) = m` + for any `m : MontPoly`. + + ## Theorems landed (this file) + + - `to_spec_poly_mont_injective_canonical` — under the bound + `|.val| < 1665` per lane, `to_spec_poly_mont` injective on Ints. + - `to_spec_poly_mont_extended` — the functorial direction: + per-lane equality of `i16_to_spec_fe_mont` ⇔ `to_spec_poly_mont` + equality (the `←` direction; the `→` direction is M.1's + `lemma_to_spec_poly_mont_eq_of_coeffs`). + - `canonicalUnlift` + `to_spec_poly_plain_canonicalUnlift` — option (a) + of the arch plan: nonneg `Fin 3329` representative carried through + `Std.I16.ofIntCore`, then back through the plain lift gives `id`. + - `to_spec_poly_mont_of_zero` — the zero-PolynomialRingElement lifts + to the all-zeros `MontPoly`. + + ## Decision Point I.6 — bound choice (arch plan §D.1) + + Canonical injectivity uses `|lane.val| < 1665`. Each `ZMod 3329` + element has a unique Int representative in `(-1664, 1664]` once we + restrict to that bound — `i16_to_spec_fe_mont` is then a bijection + with the integer lane value (after stripping the `· 169` factor). + + The cancellation step needs `169 · 2285 ≡ 1 (mod 3329)` (the inverse + of `R⁻¹ = 169` is `R = 2285` in `ZMod 3329`; this is exactly + `mont_R_inv_q` from `Util/NumericKeystones` after the `2^16 ↦ 2285` + conversion in B.3). We pull it as `mul_169_2285_eq_one` below. + + ## Discipline + + - No `@[scoped grind]` — these are one-shot lemmas for downstream + explicit use, not a grind set. + - No `sorry`, no `admit`. + Mathlib is imported here for `ZMod 3329` injectivity arguments + (cancel `169` via the explicit inverse). +-/ +import LibcruxIotMlKem.Spec +import LibcruxIotMlKem.Spec.NumericKeystones +import Mathlib.Data.ZMod.Basic +import Mathlib.Tactic.Ring +import Mathlib.Tactic.FieldSimp + +namespace libcrux_iot_ml_kem.Spec.StateIso +open CoreModels Aeneas Aeneas.Std +open libcrux_iot_ml_kem.Spec + +/-! ### Local `Inhabited` instances (mirror of `Spec.lean`). -/ + +local instance instInhabitedPortableVector_stateIso : + Inhabited libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + ⟨{ elements := Std.Array.make 16#usize (List.replicate 16 (0#i16 : Std.I16)) + (by simp) }⟩ + +local instance instInhabitedPolynomialRingElement_stateIso + {Vector : Type} [Inhabited Vector] : + Inhabited (libcrux_iot_ml_kem.polynomial.PolynomialRingElement Vector) := + ⟨{ coefficients := Std.Array.make 16#usize (List.replicate 16 default) (by simp) }⟩ + +/-! ## §D.1 Inverse keystone — `169 · 2285 ≡ 1 (mod 3329)`. -/ + +/-- The Mont-factor inverse identity: `(169 : ZMod 3329) * 2285 = 1`. + Used to cancel the `· 169` factor in `i16_to_spec_fe_mont` when + proving canonical injectivity. -/ +theorem mul_169_2285_eq_one : (169 : ZMod 3329) * 2285 = 1 := by decide + +/-- `169` is nonzero in `ZMod 3329`. -/ +theorem ne_zero_169 : (169 : ZMod 3329) ≠ 0 := by decide + +/-- Cancellation lemma: if `a * 169 = b * 169` in `ZMod 3329`, then + `a = b`. -/ +theorem mul_169_cancel {a b : ZMod 3329} (h : a * 169 = b * 169) : a = b := by + -- Multiply both sides on the right by 2285, then use the inverse identity. + have := congrArg (· * 2285) h + simp only [mul_assoc, mul_169_2285_eq_one, mul_one] at this + exact this + +/-! ## §D.1 — extended (functorial) equivalence. -/ + +/-- **Extended functorial equivalence.** `to_spec_poly_mont` equality + is equivalent to lane-by-lane equality of `i16_to_spec_fe_mont` + (i.e. of `(.val : ZMod 3329) * 169`). The `→` direction is the + point — `lemma_to_spec_poly_mont_eq_of_coeffs` from M.1 already + gives the `←` direction, and this is its converse. -/ +theorem to_spec_poly_mont_extended + (re re' : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_eq : to_spec_poly_mont re = to_spec_poly_mont re') : + ∀ i j : Fin 16, + i16_to_spec_fe_mont + ((re.coefficients.val[i.val]!).elements.val[j.val]!) + = i16_to_spec_fe_mont + ((re'.coefficients.val[i.val]!).elements.val[j.val]!) := by + intro i j + -- Apply M.1's unfold lemma at index 16*i + j on both sides. + have h_at := congrArg (fun v : MontPoly => v[16 * i.val + j.val]'(by + have hi : i.val < 16 := i.isLt + have hj : j.val < 16 := j.isLt + omega)) h_eq + simp only at h_at + rw [lemma_to_spec_poly_mont_unfold, lemma_to_spec_poly_mont_unfold] at h_at + exact h_at + +/-! ## §D.1 — canonical injectivity. -/ + +/-- **Canonical-bounded injectivity.** If both polys have all lanes + with `|.val|` bounded by `< 1665`, then `to_spec_poly_mont` + equality lifts to lane-by-lane `Int`-equality. + + Proof shape: + 1. Drop to per-lane `i16_to_spec_fe_mont` equality via + `to_spec_poly_mont_extended`. + 2. Cancel `· 169` via `mul_169_cancel`. + 3. Use the `< 1665` bound to lift `ZMod 3329` equality to + `Int`-equality (each ZMod 3329 has a unique canonical + representative in `(-1664, 1664]`). -/ +theorem to_spec_poly_mont_injective_canonical + (re re' : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_re : ∀ i j : Fin 16, + (re.coefficients.val[i.val]!.elements.val[j.val]!).val.natAbs < 1665) + (h_re' : ∀ i j : Fin 16, + (re'.coefficients.val[i.val]!.elements.val[j.val]!).val.natAbs < 1665) + (h_eq : to_spec_poly_mont re = to_spec_poly_mont re') : + ∀ i j : Fin 16, + (re.coefficients.val[i.val]!.elements.val[j.val]!).val = + (re'.coefficients.val[i.val]!.elements.val[j.val]!).val := by + intro i j + -- Step 1: per-lane Mont-lift equality. + have h_lane := to_spec_poly_mont_extended re re' h_eq i j + -- Step 2: unfold + cancel `· 169`. + rw [i16_to_spec_fe_mont_unfold, i16_to_spec_fe_mont_unfold] at h_lane + have h_zmod := mul_169_cancel h_lane + -- Step 3: cast back to Int. We have: + -- (a.val : ZMod 3329) = (b.val : ZMod 3329) with |a.val|, |b.val| < 1665. + -- This forces a.val = b.val on Int because the canonical signed + -- representative of any ZMod 3329 element in (-1664, 1664] is unique. + set a := (re.coefficients.val[i.val]!.elements.val[j.val]!).val with ha_def + set b := (re'.coefficients.val[i.val]!.elements.val[j.val]!).val with hb_def + have ha : a.natAbs < 1665 := h_re i j + have hb : b.natAbs < 1665 := h_re' i j + -- `(a : ZMod 3329) = (b : ZMod 3329)` ⇔ `a ≡ b [ZMOD 3329]`. + rw [ZMod.intCast_eq_intCast_iff a b 3329] at h_zmod + -- From the bound: |a - b| < 3330, and 3329 ∣ (a - b), so a - b = 0. + have h_diff_bound : (a - b).natAbs < 3330 := by + have : (a - b).natAbs ≤ a.natAbs + b.natAbs := Int.natAbs_sub_le a b + omega + -- `Int.ModEq.dvd` gives `3329 ∣ b - a`; we want `a - b`. + have h_div : (3329 : Int) ∣ (a - b) := by + have h := h_zmod.dvd + rwa [show (b - a) = -(a - b) by ring, dvd_neg] at h + -- A multiple of 3329 with absolute value < 3330 is 0. + have h_zero : a - b = 0 := by + rcases h_div with ⟨k, hk⟩ + by_contra hne + have hk_ne : k ≠ 0 := by + rintro rfl + simp at hk + exact hne hk + have hk_abs : k.natAbs ≥ 1 := Int.natAbs_pos.mpr hk_ne + have : (a - b).natAbs = 3329 * k.natAbs := by + rw [hk, Int.natAbs_mul]; rfl + omega + omega + +/-! ## §D.3 — `canonicalUnlift` and `lift_id` (option (a)). -/ + +/-- Helper: cast a `ZMod 3329` element to its canonical nonneg + `Std.I16` representative. `z.val < 3329 < 32768 = 2^15`, so the + `Std.I16.ofIntCore` bounds-check goal is `decide`. -/ +def i16OfZMod (z : ZMod 3329) : Std.I16 := + Std.I16.ofIntCore (z.val : Int) (by + refine ⟨?_, ?_⟩ + · have h_neg : -(2:Int)^(IScalarTy.I16.numBits-1) ≤ 0 := by decide + have h_nn : (0 : Int) ≤ (z.val : Int) := Int.natCast_nonneg _ + linarith + · have h_lt : z.val < 3329 := ZMod.val_lt _ + have h_lt_int : (z.val : Int) < 3329 := Int.ofNat_lt.mpr h_lt + have h_bd : (3329 : Int) < (2:Int)^(IScalarTy.I16.numBits-1) := by decide + linarith) + +/-- `Std.I16.val (i16OfZMod z) = z.val` on the `Int` side. -/ +theorem i16OfZMod_val (z : ZMod 3329) : (i16OfZMod z).val = (z.val : Int) := by + unfold i16OfZMod + exact I16.ofInt_val_eq _ + +/-- `canonicalUnlift : MontPoly → PolynomialRingElement PortableVector`. + Each lane becomes the nonneg `Fin 3329` representative cast to + `Std.I16` via `i16OfZMod`. + + Used by `to_spec_poly_plain_canonicalUnlift` below. Note: this is + matched against the **PLAIN** lift, not the Mont lift — + `canonicalUnlift` does NOT preinject a Mont factor, so feeding it + through `to_spec_poly_mont` would multiply by `169` (off by R⁻¹). + M.4 may want a `canonicalUnliftMont` variant; out of scope here. -/ +def canonicalUnlift (m : MontPoly) : + libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + { coefficients := Std.Array.make 16#usize + (List.ofFn (n := 16) fun i => + { elements := Std.Array.make 16#usize + (List.ofFn (n := 16) fun j => + i16OfZMod (m[16 * i.val + j.val]'(by + have hi : i.val < 16 := i.isLt + have hj : j.val < 16 := j.isLt + omega))) + (by simp) }) + (by simp) } + +/-- Helper: indexing into `canonicalUnlift m`'s coefficients-chunk + at position `i < 16` yields the inner chunk with elements + `i16OfZMod m[16*i + j]` for `j < 16`. Drives + `to_spec_poly_plain_canonicalUnlift`. -/ +theorem canonicalUnlift_chunk (m : MontPoly) (i : Fin 16) : + (canonicalUnlift m).coefficients.val[i.val]! = + { elements := Std.Array.make 16#usize + (List.ofFn (n := 16) fun j : Fin 16 => + i16OfZMod (m[16 * i.val + j.val]'(by + have hi : i.val < 16 := i.isLt + have hj : j.val < 16 := j.isLt + omega))) + (by simp) } := by + unfold canonicalUnlift + show (List.ofFn _)[i.val]! = _ + rw [getElem!_pos _ _ (by rw [List.length_ofFn]; exact i.isLt), + List.getElem_ofFn] + +/-- Helper: indexing into the chunk at lane `j < 16` yields the + `i16OfZMod` lane. -/ +theorem canonicalUnlift_lane (m : MontPoly) (i j : Fin 16) : + ((canonicalUnlift m).coefficients.val[i.val]!.elements.val[j.val]!) = + i16OfZMod (m[16 * i.val + j.val]'(by + have hi : i.val < 16 := i.isLt + have hj : j.val < 16 := j.isLt + omega)) := by + rw [canonicalUnlift_chunk] + show (List.ofFn _)[j.val]! = _ + rw [getElem!_pos _ _ (by rw [List.length_ofFn]; exact j.isLt), + List.getElem_ofFn] + +/-- **`lift_id` (option (a)).** After `canonicalUnlift` and re-lifting + through `to_spec_poly_plain`, we recover the original `MontPoly`. -/ +theorem to_spec_poly_plain_canonicalUnlift (m : MontPoly) : + to_spec_poly_plain (canonicalUnlift m) = m := by + apply Vector.ext + intro k hk + unfold to_spec_poly_plain + simp only [Vector.getElem_ofFn] + have hdiv_lt : k / 16 < 16 := by omega + have hmod_lt : k % 16 < 16 := Nat.mod_lt k (by decide) + have hk_eq : k = 16 * (k / 16) + (k % 16) := by omega + -- Apply the helper at i = ⟨k/16, ...⟩, j = ⟨k%16, ...⟩. + rw [canonicalUnlift_lane m ⟨k / 16, hdiv_lt⟩ ⟨k % 16, hmod_lt⟩] + unfold i16_to_spec_fe_plain + rw [i16OfZMod_val] + -- Goal: (m[16*(k/16) + k%16].val : ZMod 3329) = m[k]. Reduce the + -- index and apply `ZMod.natCast_zmod_val`. + have hkidx : (16 * (k / 16) + k % 16) = k := by omega + -- Use Vector.getElem proof-irrelevance: m at equal indices are equal. + have hget : m[16 * (k / 16) + k % 16]'(by omega) = m[k] := by + -- This is `getElem` proof-irrelevance + index equality. + congr 1 + show ((m[16 * (k / 16) + k % 16]'(by omega)).val : ZMod 3329) = m[k] + rw [hget] + exact_mod_cast ZMod.natCast_zmod_val _ + +/-! ## §D.4 — auxiliary lemma. -/ + +/-- Helper: indexing into a `List.replicate n a` at any in-bound + position returns `a`, even with `getElem!`. -/ +private theorem List_replicate_getElem!_eq {α : Type*} [Inhabited α] + {n : Nat} (a : α) {i : Nat} (hi : i < n) : + (List.replicate n a)[i]! = a := by + rw [getElem!_pos _ _ (by rw [List.length_replicate]; exact hi)] + exact List.getElem_replicate _ + +/-- The zero polynomial (every i16 lane = 0) lifts through + `to_spec_poly_mont` to the all-zeros `MontPoly`. + + Phrased generically so Lean's elaborator does not eagerly unfold + the 16×16 literal — the zero impl is constructed via `Std.Array.make` + against a `List.replicate`-flattened input. -/ +theorem to_spec_poly_mont_of_zero + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_zero : ∀ i j : Fin 16, + ((re.coefficients.val[i.val]!).elements.val[j.val]!) = (0#i16 : Std.I16)) : + to_spec_poly_mont re = Vector.replicate 256 (0 : ZMod 3329) := by + apply Vector.ext + intro k hk + unfold to_spec_poly_mont + simp only [Vector.getElem_ofFn, Vector.getElem_replicate] + have hdiv_lt : k / 16 < 16 := by omega + have hmod_lt : k % 16 < 16 := Nat.mod_lt k (by decide) + rw [h_zero ⟨k / 16, hdiv_lt⟩ ⟨k % 16, hmod_lt⟩] + unfold i16_to_spec_fe_mont + show ((0#i16).val : ZMod 3329) * 169 = 0 + simp + +end libcrux_iot_ml_kem.Spec.StateIso \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Util/CreateI.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Util/CreateI.lean new file mode 100644 index 00000000..9526a74b --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Util/CreateI.lean @@ -0,0 +1,239 @@ +/- + # `Util/CreateI.lean` — Pure-closure `createi` / `from_fn` Triples + + Merges the two pure-closure-array-builder Triples from libcrux-iot + SHA-3's tree (which were split between two unrelated files for + historical reasons) into one Util module with consistent naming. + + ## Lifted from + + - `createi_pure_eq` + `createi_pure_spec` + (`LibcruxIotSha3/Equivalence/HacspecBridge.lean:627,663`) — + the `Fn`-wrapped variant. `createi N inst c` (where `inst` is a + `core.ops.function.Fn` instance) is the hax extraction of + `core::array::from_fn` over a *shared* closure. + + - `from_fn_pure_eq` + `from_fn_pure_spec` + (`LibcruxIotSha3/Sponge/XorBlockSpec.lean:103,137`) — + the *direct* `FnMut` variant. `core.array.from_fn N inst c` + (where `inst` is `core.ops.function.FnMut`) is the hax + extraction of `core::array::from_fn` over a mutable-via-`FnMut` + closure (e.g. `sponge.xor_block_into_state`). + + Both produce a length-`N` array whose `i`-th cell is `f i`, when the + closure body is pure (`call_mut` returns `(f i, c)` for input `i`). + The two specs differ only in the wrapper around `call_mut`. + + Use whichever spec matches the extraction: + - The hax extraction of `core::array::from_fn` over a *`Fn`*-bounded + closure goes through `createi` and the `createi_*_spec` should be + referenced. + - The hax extraction of `core::array::from_fn` over a *`FnMut`*-bounded + closure (or one that internally captures by mutable reference) + goes directly through `core.array.from_fn` and + `from_fn_*_spec` should be referenced. + + ## Naming (renamed for `Util` namespace) + + - `createi_pure_eq` (was `libcrux_iot_sha3.Equivalence.createi_pure_eq`) + - `createi_pure_spec` (was `libcrux_iot_sha3.Equivalence.createi_pure_spec`) + - `from_fn_pure_eq` (was `libcrux_iot_sha3.Sponge.from_fn_pure_eq`) + - `from_fn_pure_spec` (was `libcrux_iot_sha3.Sponge.from_fn_pure_spec`) +-/ +import LibcruxIotMlKem.Util.SliceSpecs +-- `hacspec_sha3.createi` is the canonical `core::array::from_fn` wrapper +-- (extracted from `specs/sha3/src/lib.rs:21`). The same wrapper covers +-- ML-KEM's `Fn`-bounded array builders. Importing brings the symbol +-- into scope; the Triples below are stated on it directly. +import HacspecSha3.Extraction.Funs +import HacspecMlKem.Extraction.Funs + +open CoreModels Aeneas Aeneas.Std Result Std.Do +open hacspec_ml_kem.parameters (createi) + +namespace libcrux_iot_ml_kem.Util.CreateI +open libcrux_iot_ml_kem.Util.SliceSpecs +set_option mvcgen.warning false +set_option linter.unusedVariables false + +/-! ## `Fn`-wrapped variant: `createi N inst c` -/ + +/-- Per-element foldlM evaluation for pure closures. The closure state `c` + is invariant; the result list is `acc ++ l.map f`. -/ +private theorem createi_foldlM_pure_aux + {T F : Type} + (inst : CoreModels.core.ops.function.FnMut F Std.Usize T) (c : F) (f : Nat → T) + (l : List Nat) (acc : List T) + (hpure : ∀ k ∈ l, + inst.call_mut c ⟨BitVec.ofNat _ k⟩ = .ok (f k, c)) : + l.foldlM + (fun (s : List T × F) (i : Nat) => do + let (v, f') ← inst.call_mut s.2 ⟨BitVec.ofNat _ i⟩ + Result.ok (s.1 ++ [v], f')) + (acc, c) = .ok (acc ++ l.map f, c) := by + induction l generalizing acc with + | nil => + simp only [List.foldlM_nil, List.map_nil, List.append_nil] + rfl + | cons h t ih => + have hh : inst.call_mut c ⟨BitVec.ofNat _ h⟩ = .ok (f h, c) := + hpure h List.mem_cons_self + have ht : ∀ k ∈ t, inst.call_mut c ⟨BitVec.ofNat _ k⟩ = .ok (f k, c) := + fun k hk => hpure k (List.mem_cons_of_mem _ hk) + have hih := ih (acc ++ [f h]) ht + simp only [List.foldlM_cons, hh, bind_tc_ok, List.map_cons] + rw [hih] + simp [List.append_assoc] + +/-- Lean-level equation for `createi` over pure closures. Used to power + `createi_pure_spec` (Triple form). -/ +theorem createi_pure_eq + {T F : Type} (N : Std.Usize) + (inst : CoreModels.core.ops.function.Fn F Std.Usize T) (c : F) (f : Nat → T) + (hpure : ∀ k : Nat, k < N.val → + inst.FnMutInst.call_mut c ⟨BitVec.ofNat _ k⟩ = .ok (f k, c)) : + createi N inst c = + .ok ⟨(List.range N.val).map f, + by simp [List.length_map, List.length_range]⟩ := by + have hf : ∀ k ∈ List.range N.val, + inst.FnMutInst.call_mut c ⟨BitVec.ofNat _ k⟩ = .ok (f k, c) := by + intro k hk; exact hpure k (List.mem_range.mp hk) + have h_fold := + createi_foldlM_pure_aux inst.FnMutInst c f (List.range N.val) [] hf + simp only [List.nil_append] at h_fold + unfold createi core.array.from_fn rust_primitives.slice.array_from_fn + split + · rename_i e heq + rw [h_fold] at heq; exact absurd heq (by simp) + · rename_i heq + rw [h_fold] at heq; exact absurd heq (by simp) + · rename_i result heq + rw [h_fold] at heq + have hres : result = ((List.range N.val).map f, c) := + (Result.ok.inj heq).symm + subst hres + rfl + +/-- **Generic pure-closure `[spec]` for `createi`.** + +For any closure whose `call_mut` is pure (doesn't mutate captured state), +`createi N inst c` succeeds and its `i`-th cell is `f i`. The hypothesis +`hpure` is a Triple over each call_mut so `hax_mvcgen` can recurse into +it via per-closure `@[spec]` lemmas. + +Tagged `@[spec]` so `hax_mvcgen` chains through nested `createi` calls. -/ +@[spec] +theorem createi_pure_spec + {T F : Type} [Inhabited T] (N : Std.Usize) + (inst : CoreModels.core.ops.function.Fn F Std.Usize T) (c : F) (f : Nat → T) + (hpure : ∀ k : Nat, k < N.val → + ⦃ ⌜ True ⌝ ⦄ + inst.FnMutInst.call_mut c ⟨BitVec.ofNat _ k⟩ + ⦃ ⇓ r => ⌜ r = (f k, c) ⌝ ⦄) : + ⦃ ⌜ True ⌝ ⦄ + createi N inst c + ⦃ ⇓ a => ⌜ ∀ i : Nat, i < N.val → a.val[i]! = f i ⌝ ⦄ := by + have hpure_eq : ∀ k : Nat, k < N.val → + inst.FnMutInst.call_mut c ⟨BitVec.ofNat _ k⟩ = .ok (f k, c) := + fun k hk => result_eq_of_triple (hpure k hk) + have heq := createi_pure_eq N inst c f hpure_eq + rw [heq] + simp only [Triple, WP.wp] + apply SPred.pure_intro + intro i hi + show ((List.range N.val).map f)[i]! = f i + rw [List.getElem!_eq_getElem?_getD, List.getElem?_map, + List.getElem?_range hi] + rfl + +/-! ## `FnMut`-direct variant: `core.array.from_fn N inst c` + +Analogous to `createi_*` but takes a `core.ops.function.FnMut` +instance directly (no `Fn` wrapper). Required when the hax extraction +calls `core.array.from_fn` directly with the `FnMut` instance of +its closure (e.g. SHA-3's `sponge.xor_block_into_state`; ML-KEM +matrix/poly constructors). -/ + +private theorem from_fn_foldlM_pure_aux + {T F : Type} + (inst : CoreModels.core.ops.function.FnMut F Std.Usize T) (c : F) (f : Nat → T) + (l : List Nat) (acc : List T) + (hpure : ∀ k ∈ l, + inst.call_mut c ⟨BitVec.ofNat _ k⟩ = .ok (f k, c)) : + l.foldlM + (fun (s : List T × F) (i : Nat) => do + let (v, f') ← inst.call_mut s.2 ⟨BitVec.ofNat _ i⟩ + Result.ok (s.1 ++ [v], f')) + (acc, c) = .ok (acc ++ l.map f, c) := by + induction l generalizing acc with + | nil => + simp only [List.foldlM_nil, List.map_nil, List.append_nil]; rfl + | cons h t ih => + have hh : inst.call_mut c ⟨BitVec.ofNat _ h⟩ = .ok (f h, c) := + hpure h List.mem_cons_self + have ht : ∀ k ∈ t, inst.call_mut c ⟨BitVec.ofNat _ k⟩ = .ok (f k, c) := + fun k hk => hpure k (List.mem_cons_of_mem _ hk) + have hih := ih (acc ++ [f h]) ht + simp only [List.foldlM_cons, hh, bind_tc_ok, List.map_cons] + rw [hih] + simp [List.append_assoc] + +/-- Lean-level equation for `from_fn` over pure closures. -/ +theorem from_fn_pure_eq + {T F : Type} (N : Std.Usize) + (inst : CoreModels.core.ops.function.FnMut F Std.Usize T) (c : F) (f : Nat → T) + (hpure : ∀ k : Nat, k < N.val → + inst.call_mut c ⟨BitVec.ofNat _ k⟩ = .ok (f k, c)) : + core.array.from_fn N inst c = + .ok ⟨(List.range N.val).map f, + by simp [List.length_map, List.length_range]⟩ := by + have hf : ∀ k ∈ List.range N.val, + inst.call_mut c ⟨BitVec.ofNat _ k⟩ = .ok (f k, c) := by + intro k hk; exact hpure k (List.mem_range.mp hk) + have h_fold := + from_fn_foldlM_pure_aux inst c f (List.range N.val) [] hf + simp only [List.nil_append] at h_fold + unfold core.array.from_fn rust_primitives.slice.array_from_fn + split + · rename_i e heq + rw [h_fold] at heq; exact absurd heq (by simp) + · rename_i heq + rw [h_fold] at heq; exact absurd heq (by simp) + · rename_i result heq + rw [h_fold] at heq + have hres : result = ((List.range N.val).map f, c) := + (Result.ok.inj heq).symm + subst hres + rfl + +/-- **Generic pure-closure `[spec]` for `core.array.from_fn`.** + +For any closure whose `call_mut` is pure (doesn't mutate state), +`from_fn N inst c` succeeds and its `i`-th cell is `f i`. `hpure` is a +Triple over each `call_mut` so `hax_mvcgen` can recurse through it via +per-closure `@[spec]` lemmas. -/ +@[spec] +theorem from_fn_pure_spec + {T F : Type} [Inhabited T] (N : Std.Usize) + (inst : CoreModels.core.ops.function.FnMut F Std.Usize T) (c : F) (f : Nat → T) + (hpure : ∀ k : Nat, k < N.val → + ⦃ ⌜ True ⌝ ⦄ + inst.call_mut c ⟨BitVec.ofNat _ k⟩ + ⦃ ⇓ r => ⌜ r = (f k, c) ⌝ ⦄) : + ⦃ ⌜ True ⌝ ⦄ + core.array.from_fn N inst c + ⦃ ⇓ a => ⌜ ∀ i : Nat, i < N.val → a.val[i]! = f i ⌝ ⦄ := by + have hpure_eq : ∀ k : Nat, k < N.val → + inst.call_mut c ⟨BitVec.ofNat _ k⟩ = .ok (f k, c) := + fun k hk => result_eq_of_triple (hpure k hk) + have heq := from_fn_pure_eq N inst c f hpure_eq + rw [heq] + simp only [Triple, WP.wp] + apply SPred.pure_intro + intro i hi + show ((List.range N.val).map f)[i]! = f i + rw [List.getElem!_eq_getElem?_getD, List.getElem?_map, + List.getElem?_range hi] + rfl + +end libcrux_iot_ml_kem.Util.CreateI \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Util/LoopSpecs.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Util/LoopSpecs.lean new file mode 100644 index 00000000..ff839883 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Util/LoopSpecs.lean @@ -0,0 +1,425 @@ +/- + # `Util/LoopSpecs.lean` — Generic Aeneas `loop` + iterator-range Triples + + Generic loop combinators: + + - `result_eq_of_triple` (re-exported from `Util.SliceSpecs`). + - `IteratorRange_next_spec_i32` + `loop_range_spec_i32`. + - `IteratorRange_next_spec_usize` + `loop_range_spec_usize`. + - `array_from_fn_eq_unfold5` + the `bv_ofNat_eq_usize_lit_{0..4}` helpers. +-/ +import LibcruxIotMlKem.Util.SliceSpecs +import Hax + +open CoreModels Aeneas Aeneas.Std Result ControlFlow Std.Do + +namespace libcrux_iot_ml_kem.Util.LoopSpecs +open libcrux_iot_ml_kem.Util.SliceSpecs +set_option mvcgen.warning false +set_option linter.unusedVariables false + +/-! ## Re-export `result_eq_of_triple` + +`Util.SliceSpecs` already proves this; we re-export the name into this +module's namespace so downstream `import Util.LoopSpecs` callers +automatically see it. -/ + +-- (`libcrux_iot_ml_kem.Util.SliceSpecs.result_eq_of_triple` is in scope via `Util.SliceSpecs`.) + +/-! ## `Usize.val` conversion for small-Nat constants + +The `array.from_fn` extraction uses raw `BitVec.ofNat _ n` for iteration +indices; we need to convert these to the standard `n#usize` form. -/ + +private theorem bv_ofNat_val_eq (n : Nat) (hn : n < 2^32) : + (⟨BitVec.ofNat System.Platform.numBits n⟩ : Std.Usize).val = n := by + show (BitVec.ofNat _ n).toNat = n + simp only [BitVec.toNat_ofNat] + apply Nat.mod_eq_of_lt + have h32 : (32 : Nat) ≤ System.Platform.numBits := by + have := System.Platform.numBits_eq; omega + calc n < 2^32 := hn + _ ≤ 2^System.Platform.numBits := Nat.pow_le_pow_right (by decide) h32 + +private theorem bv_ofNat_eq_usize_lit_0 : + (⟨BitVec.ofNat _ 0⟩ : Std.Usize) = 0#usize := by + apply Std.UScalar.eq_of_val_eq; exact bv_ofNat_val_eq 0 (by omega) +private theorem bv_ofNat_eq_usize_lit_1 : + (⟨BitVec.ofNat _ 1⟩ : Std.Usize) = 1#usize := by + apply Std.UScalar.eq_of_val_eq; exact bv_ofNat_val_eq 1 (by omega) +private theorem bv_ofNat_eq_usize_lit_2 : + (⟨BitVec.ofNat _ 2⟩ : Std.Usize) = 2#usize := by + apply Std.UScalar.eq_of_val_eq; exact bv_ofNat_val_eq 2 (by omega) +private theorem bv_ofNat_eq_usize_lit_3 : + (⟨BitVec.ofNat _ 3⟩ : Std.Usize) = 3#usize := by + apply Std.UScalar.eq_of_val_eq; exact bv_ofNat_val_eq 3 (by omega) +private theorem bv_ofNat_eq_usize_lit_4 : + (⟨BitVec.ofNat _ 4⟩ : Std.Usize) = 4#usize := by + apply Std.UScalar.eq_of_val_eq; exact bv_ofNat_val_eq 4 (by omega) + +/-! ## `array.from_fn 5` unfolding lemma + +`rust_primitives.slice.array_from_fn 5#usize inst f0` unfolds to a chain +of 5 `inst.call_mut` calls building an `Array.make 5 [v0,v1,v2,v3,v4]`. -/ + +set_option maxHeartbeats 400000000 in +theorem array_from_fn_eq_unfold5 + {T F : Type} (inst : CoreModels.core.ops.function.FnMut F Std.Usize T) (f0 : F) + (v0 v1 v2 v3 v4 : T) (f1 f2 f3 f4 f5 : F) + (h0 : inst.call_mut f0 0#usize = .ok (v0, f1)) + (h1 : inst.call_mut f1 1#usize = .ok (v1, f2)) + (h2 : inst.call_mut f2 2#usize = .ok (v2, f3)) + (h3 : inst.call_mut f3 3#usize = .ok (v3, f4)) + (h4 : inst.call_mut f4 4#usize = .ok (v4, f5)) : + rust_primitives.slice.array_from_fn 5#usize inst f0 = + .ok (Std.Array.make 5#usize [v0, v1, v2, v3, v4]) := by + have h_fold : + List.foldlM + (fun (s : List T × F) (i : Nat) => do + let __discr ← inst.call_mut s.2 ⟨BitVec.ofNat _ i⟩ + match __discr with + | (v, f') => Result.ok (s.1 ++ [v], f')) + ([], f0) (List.range (5#usize).val) + = .ok ([v0, v1, v2, v3, v4], f5) := by + show List.foldlM _ ([], f0) (List.range 5) = _ + rw [show (List.range 5) = [0, 1, 2, 3, 4] from by decide] + simp only [List.foldlM_cons, List.foldlM_nil, + bv_ofNat_eq_usize_lit_0, bv_ofNat_eq_usize_lit_1, + bv_ofNat_eq_usize_lit_2, bv_ofNat_eq_usize_lit_3, + bv_ofNat_eq_usize_lit_4, h0, h1, h2, h3, h4, bind_tc_ok] + rfl + unfold rust_primitives.slice.array_from_fn + split + · rename_i e heq_match + rw [h_fold] at heq_match + exact absurd heq_match (by simp) + · rename_i heq_match + rw [h_fold] at heq_match + exact absurd heq_match (by simp) + · rename_i result heq_match + rw [h_fold] at heq_match + have hres : result = ([v0, v1, v2, v3, v4], f5) := (Result.ok.inj heq_match).symm + subst hres + rfl + +/-! ## `Usize` iterator-next spec + +The `Usize.Insts.CoreIterRangeStep` instance is an abbrev for +`core.iter.range.StepUsize` (see Aeneas-Std `FunsPrologue.lean`). + +Splits on `i.val < e.val`: if so, returns `some i` plus the incremented +range; otherwise returns `none`. -/ + +theorem IteratorRange_next_spec_usize (i e : Std.Usize) {Q} + (h_lt : (h : i.val < e.val) → + ∀ (s : Std.Usize), s.val = i.val + 1 → + (Q.1 (some i, { start := s, «end» := e })).down) + (h_ge : i.val ≥ e.val → + (Q.1 (none, { start := i, «end» := e })).down) : + ⦃ ⌜ True ⌝ ⦄ + core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + { start := i, «end» := e } + ⦃ Q ⦄ := by + rcases lt_or_ge i.val e.val with hlt | hge + · -- i < e: forward_checked succeeds with no overflow, start advances by 1. + have hUB : i.val + 1 < 2 ^ System.Platform.numBits := by + have he := e.hBounds + rcases System.Platform.numBits_eq with hN | hN <;> + simp only [Std.UScalarTy.Usize_numBits_eq, hN] at he <;> + rw [hN] <;> omega + have hno_ovf : BitVec.uaddOverflow i.bv (1#System.Platform.numBits) = false := by + have h1 : (1#System.Platform.numBits : BitVec _).toNat = 1 := by + rcases System.Platform.numBits_eq with h | h <;> rw [h] <;> rfl + simp [BitVec.uaddOverflow, h1, hUB] + have h_eq : + CoreModels.core.iter.range.IteratorRange.next + CoreModels.core.Usize.Insts.CoreIterRangeStep { start := i, «end» := e } + = .ok (CoreModels.core.option.Option.Some i, + { start := ⟨i.bv + 1#System.Platform.numBits⟩, «end» := e }) := by + unfold CoreModels.core.iter.range.IteratorRange.next + simp only [ CoreModels.core.Usize.Insts.CoreCmpPartialOrdUsize, + CoreModels.core.mkUPartialOrd, + CoreModels.core.Usize.Insts.CoreCloneClone.clone, + CoreModels.core.Usize.Insts.CoreIterRangeStep.forward_checked, + CoreModels.core.convert.TryFromUTInfallible.Blanket.try_from, + CoreModels.core.convert.From.Blanket.from, + CoreModels.core.num.Usize.checked_add, + CoreModels.core.num.Usize.overflowing_add, + CoreModels.rust_primitives.arithmetic.overflowing_add_usize, + Std.UScalar.overflowing_add] + have hcmp : compare i.val e.val = Ordering.lt := by + rw [Nat.compare_eq_lt]; exact hlt + simp [hcmp, hno_ovf] + rw [h_eq] + have h_step : (⟨i.bv + 1#System.Platform.numBits⟩ : Std.Usize).val = i.val + 1 := by + show (i.bv + 1#System.Platform.numBits).toNat = i.val + 1 + rw [BitVec.toNat_add] + have h1 : (1#System.Platform.numBits : BitVec _).toNat = 1 := by + rcases System.Platform.numBits_eq with h | h <;> rw [h] <;> rfl + rw [h1] + show (i.bv.toNat + 1) % _ = i.val + 1 + exact Nat.mod_eq_of_lt hUB + simp [Triple, WP.wp, PredTrans.apply] + exact h_lt hlt _ h_step + · -- i ≥ e: next returns none. + have h_eq : + CoreModels.core.iter.range.IteratorRange.next + CoreModels.core.Usize.Insts.CoreIterRangeStep { start := i, «end» := e } + = .ok (CoreModels.core.option.Option.None, { start := i, «end» := e }) := by + unfold CoreModels.core.iter.range.IteratorRange.next + simp only [ CoreModels.core.Usize.Insts.CoreCmpPartialOrdUsize, + CoreModels.core.mkUPartialOrd] + have hcmp : compare i.val e.val ≠ Ordering.lt := by + intro h; rw [Nat.compare_eq_lt] at h; omega + cases h : compare i.val e.val <;> simp_all + rw [h_eq] + simp [Triple, WP.wp, PredTrans.apply] + exact h_ge hge + +/-! ## `Usize` loop-over-range spec + +Specialized to `loop` over `core.ops.range.Range Usize`. An invariant +`inv : Usize → β → Result Prop` is preserved by each step. Induction on +`(e.val - start.val)`. -/ + +section loop_range_usize_helpers + +private abbrev ResultPSU := PostShape.except Error (PostShape.except PUnit PostShape.pure) + +private theorem triple_noThrow_elim_usize {α : Type} {x : Result α} {Q : α → Assertion ResultPSU} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ PostCond.noThrow Q ⦄) {v : α} (hv : x = ok v) : + (Q v).down := by + subst hv; simpa [Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h + +private theorem triple_noThrow_exists_ok_usize {α : Type} {x : Result α} + {Q : α → Assertion ResultPSU} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ PostCond.noThrow Q ⦄) : ∃ v, x = ok v := by + match x, h with + | .ok v, _ => exact ⟨v, rfl⟩ + | .fail _, h => exact absurd h (by simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div, h => exact absurd h (by simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +private theorem triple_of_ok_usize {α : Type} {x : Result α} {v : α} {P : α → Prop} + (hx : x = ok v) (hp : P v) : + (⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) := by + subst hx; simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +end loop_range_usize_helpers + +set_option maxHeartbeats 2000000 in +theorem loop_range_spec_usize {β : Type} + (body : (CoreModels.core.ops.range.Range Std.Usize × β) → + Result (ControlFlow (CoreModels.core.ops.range.Range Std.Usize × β) β)) + (init : β) (s e : Std.Usize) (inv : Std.Usize → β → Result Prop) + (h_le : s.val ≤ e.val) + (h_init : (inv s init).holds) + (h_step : ∀ acc (i : Std.Usize), s.val ≤ i.val → i.val ≤ e.val → + (inv i acc).holds → + ⦃ ⌜ True ⌝ ⦄ + body ({ start := i, «end» := e }, acc) + ⦃ ⇓ r => match r with + | .cont (iter', acc') => + ⌜ i.val < e.val ∧ iter'.«end» = e ∧ iter'.start.val = i.val + 1 + ∧ (inv iter'.start acc').holds ⌝ + | .done y => ⌜ (inv e y).holds ⌝ ⦄) : + ⦃ ⌜ True ⌝ ⦄ + loop body ({ start := s, «end» := e }, init) + ⦃ ⇓ r => ⌜ (inv e r).holds ⌝ ⦄ := by + suffices gen : ∀ (n : Nat) (acc : β) (start : Std.Usize), + e.val - start.val = n → + s.val ≤ start.val → start.val ≤ e.val → + (inv start acc).holds → + ⦃ ⌜ True ⌝ ⦄ loop body ({ start := start, «end» := e }, acc) + ⦃ ⇓ r => ⌜ (inv e r).holds ⌝ ⦄ by + exact gen _ init s rfl (Nat.le_refl _) h_le h_init + intro n + induction n with + | zero => + intro acc start hn hs_le hse_le hinv + have hs := h_step acc start hs_le hse_le hinv + obtain ⟨r, hbody⟩ := triple_noThrow_exists_ok_usize hs + have hpost := triple_noThrow_elim_usize hs hbody + rw [loop.eq_def, hbody] + match r with + | .cont (iter', acc') => + simp at hpost; exact absurd hpost.1 (by omega) + | .done y => + simp at hpost; exact triple_of_ok_usize rfl hpost + | succ n ih => + intro acc start hn hs_le hse_le hinv + have hs := h_step acc start hs_le hse_le hinv + obtain ⟨r, hbody⟩ := triple_noThrow_exists_ok_usize hs + have hpost := triple_noThrow_elim_usize hs hbody + rw [loop.eq_def, hbody] + match r with + | .done y => + simp at hpost; exact triple_of_ok_usize rfl hpost + | .cont (iter', acc') => + simp at hpost + obtain ⟨hlt, hend, hstart, hinv'⟩ := hpost + have hiter : iter' = { start := iter'.start, «end» := e } := by + cases iter'; cases hend; rfl + rw [hiter] + exact ih acc' iter'.start + (by rw [hstart]; omega) (by rw [hstart]; omega) (by rw [hstart]; omega) hinv' + +/-! ## `I32` iterator-next spec + +The `core.I32.Insts.CoreIterRangeStep` instance uses +`IScalar.tryMk .I32 (start.val + 1)` for `forward_checked`. For ranges +within `[-2^31, 2^31)` this always succeeds. -/ + +theorem IteratorRange_next_spec_i32 (i e : Std.I32) + (h_e_lt_max : e.val < 2^31) {Q} + (h_lt : (h : i.val < e.val) → + ∀ (s : Std.I32), s.val = i.val + 1 → + (Q.1 (some i, { start := s, «end» := e })).down) + (h_ge : i.val ≥ e.val → + (Q.1 (none, { start := i, «end» := e })).down) : + ⦃ ⌜ True ⌝ ⦄ + core.iter.range.IteratorRange.next core.I32.Insts.CoreIterRangeStep + { start := i, «end» := e } + ⦃ Q ⦄ := by + rcases lt_or_ge i.val e.val with hlt | hge + · -- i < e: forward_checked succeeds with no overflow, start advances by 1. + have hcmp : compare i.val e.val = Ordering.lt := Int.compare_eq_lt.mpr hlt + have hmin : (-2147483648 : Int) ≤ i.val := by scalar_tac + have h1val : (Std.UScalar.hcast Std.IScalarTy.I32 + (Std.UScalar.cast Std.UScalarTy.U32 1#usize)).val = 1 := by + simp only [Std.UScalar.hcast, Std.UScalar.cast, Std.IScalarTy.I32_numBits_eq, + Std.UScalarTy.U32_numBits_eq] + simp; decide + have hwval : (Std.I32.wrapping_add i + (Std.UScalar.hcast Std.IScalarTy.I32 + (Std.UScalar.cast Std.UScalarTy.U32 1#usize))).val = i.val + 1 := by + rw [Std.I32.wrapping_add_val_eq, h1val, Int.bmod_eq_emod] + simp only [Nat.reducePow] + split <;> omega + have hbmod : ((i.val : Int) + 1).bmod 4294967296 = i.val + 1 := by + rw [Int.bmod_eq_emod]; split <;> omega + have htry : CoreModels.core.U32.Insts.CoreConvertTryFromUsizeTryFromIntError.try_from 1#usize + = .ok (CoreModels.core.result.Result.Ok + (Std.UScalar.cast Std.UScalarTy.U32 1#usize)) := by + unfold CoreModels.core.U32.Insts.CoreConvertTryFromUsizeTryFromIntError.try_from + simp [Aeneas.Std.lift, Std.U32.rMax] + have h_eq : CoreModels.core.iter.range.IteratorRange.next + CoreModels.core.I32.Insts.CoreIterRangeStep { start := i, «end» := e } + = .ok (CoreModels.core.option.Option.Some i, + { start := Std.I32.wrapping_add i + (Std.UScalar.hcast Std.IScalarTy.I32 + (Std.UScalar.cast Std.UScalarTy.U32 1#usize)), + «end» := e }) := by + unfold CoreModels.core.iter.range.IteratorRange.next + simp [CoreModels.core.I32.Insts.CoreCmpPartialOrdI32, + CoreModels.core.mkIPartialOrd, + CoreModels.core.I32.Insts.CoreCloneClone.clone, + CoreModels.core.I32.Insts.CoreIterRangeStep.forward_checked, + CoreModels.core.num.I32.wrapping_add, + CoreModels.rust_primitives.arithmetic.wrapping_add_i32, + Aeneas.Std.lift, hcmp, htry, h1val, hbmod] + rw [h_eq] + simp [Triple, WP.wp, PredTrans.apply] + exact h_lt hlt _ hwval + · -- i ≥ e: next returns none. + have h_eq : CoreModels.core.iter.range.IteratorRange.next + CoreModels.core.I32.Insts.CoreIterRangeStep { start := i, «end» := e } + = .ok (CoreModels.core.option.Option.None, { start := i, «end» := e }) := by + unfold CoreModels.core.iter.range.IteratorRange.next + simp only [CoreModels.core.I32.Insts.CoreCmpPartialOrdI32, + CoreModels.core.mkIPartialOrd] + have hcmp : compare i.val e.val ≠ Ordering.lt := Int.compare_ne_lt.mpr hge + cases h : compare i.val e.val <;> simp_all + rw [h_eq] + simp [Triple, WP.wp, PredTrans.apply] + exact h_ge hge + +/-! ## `I32` loop-over-range spec + +Specialized to `loop` over `core.ops.range.Range I32`. Same shape as the +`Usize` version. -/ + +section loop_range_i32_helpers + +private abbrev ResultPS := PostShape.except Error (PostShape.except PUnit PostShape.pure) + +private theorem triple_noThrow_elim_i32 {α : Type} {x : Result α} {Q : α → Assertion ResultPS} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ PostCond.noThrow Q ⦄) {v : α} (hv : x = ok v) : + (Q v).down := by + subst hv; simpa [Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h + +private theorem triple_noThrow_exists_ok_i32 {α : Type} {x : Result α} + {Q : α → Assertion ResultPS} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ PostCond.noThrow Q ⦄) : ∃ v, x = ok v := by + match x, h with + | .ok v, _ => exact ⟨v, rfl⟩ + | .fail _, h => exact absurd h (by simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div, h => exact absurd h (by simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +private theorem triple_of_ok_i32 {α : Type} {x : Result α} {v : α} {P : α → Prop} + (hx : x = ok v) (hp : P v) : + (⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) := by + subst hx; simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +end loop_range_i32_helpers + +set_option maxHeartbeats 2000000 in +theorem loop_range_spec_i32 {β : Type} + (body : (CoreModels.core.ops.range.Range Std.I32 × β) → + Result (ControlFlow (CoreModels.core.ops.range.Range Std.I32 × β) β)) + (init : β) (s e : Std.I32) (inv : Std.I32 → β → Result Prop) + (h_le : s.val ≤ e.val) + (h_init : (inv s init).holds) + (h_step : ∀ acc (i : Std.I32), s.val ≤ i.val → i.val ≤ e.val → + (inv i acc).holds → + ⦃ ⌜ True ⌝ ⦄ + body ({ start := i, «end» := e }, acc) + ⦃ ⇓ r => match r with + | .cont (iter', acc') => + ⌜ i.val < e.val ∧ iter'.«end» = e ∧ iter'.start.val = i.val + 1 + ∧ (inv iter'.start acc').holds ⌝ + | .done y => ⌜ (inv e y).holds ⌝ ⦄) : + ⦃ ⌜ True ⌝ ⦄ + loop body ({ start := s, «end» := e }, init) + ⦃ ⇓ r => ⌜ (inv e r).holds ⌝ ⦄ := by + -- Generalize over the current iteration's `start` and induct on the number + -- of remaining steps `n = (e.val - start.val).toNat`. + suffices gen : ∀ (n : Nat) (acc : β) (start : Std.I32), + (e.val - start.val).toNat = n → + s.val ≤ start.val → start.val ≤ e.val → + (inv start acc).holds → + ⦃ ⌜ True ⌝ ⦄ loop body ({ start := start, «end» := e }, acc) + ⦃ ⇓ r => ⌜ (inv e r).holds ⌝ ⦄ by + exact gen _ init s rfl (Int.le_refl _) h_le h_init + intro n + induction n with + | zero => + intro acc start hn hs_le hse_le hinv + have hs := h_step acc start hs_le hse_le hinv + obtain ⟨r, hbody⟩ := triple_noThrow_exists_ok_i32 hs + have hpost := triple_noThrow_elim_i32 hs hbody + rw [loop.eq_def, hbody] + match r with + | .cont (iter', acc') => + simp at hpost; exact absurd hpost.1 (by omega) + | .done y => + simp at hpost; exact triple_of_ok_i32 rfl hpost + | succ n ih => + intro acc start hn hs_le hse_le hinv + have hs := h_step acc start hs_le hse_le hinv + obtain ⟨r, hbody⟩ := triple_noThrow_exists_ok_i32 hs + have hpost := triple_noThrow_elim_i32 hs hbody + rw [loop.eq_def, hbody] + match r with + | .done y => + simp at hpost; exact triple_of_ok_i32 rfl hpost + | .cont (iter', acc') => + simp at hpost + obtain ⟨hlt, hend, hstart, hinv'⟩ := hpost + have hiter : iter' = { start := iter'.start, «end» := e } := by + cases iter'; cases hend; rfl + rw [hiter] + exact ih acc' iter'.start + (by rw [hstart]; omega) (by rw [hstart]; omega) (by rw [hstart]; omega) hinv' + +end libcrux_iot_ml_kem.Util.LoopSpecs \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Util/SliceSpecs.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Util/SliceSpecs.lean new file mode 100644 index 00000000..e66dca06 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Util/SliceSpecs.lean @@ -0,0 +1,442 @@ +/- + # `Util/SliceSpecs.lean` — Aeneas Std byte / slice `@[spec]` bridges + + Verbatim lift from libcrux-iot SHA-3's + `LibcruxIotSha3/Sponge/SliceSpecs.lean`, with namespace rewritten + from `libcrux_iot_sha3.Sponge` → `libcrux_iot_ml_kem.Util` and the + SHA-3-specific `attribute [local irreducible] keccak.…` / `open … + libcrux_iot_sha3 hacspec_sha3` lines dropped. + + Installs reusable `@[spec]` Triples for the generic Aeneas/`core_models` + byte/slice operations that every ML-KEM byte encode/decode and matrix + routine routes through. The `result_eq_of_triple` helper that the + original file pulled from `LibcruxIotSha3.Equivalence.I32LoopSpec` is + re-defined here as a private theorem (it is also re-exported by + `Util/LoopSpecs.lean` for cross-Util use). + + ## Installed (10 Triples) + + - `core_models_slice_Slice_len_spec` — `core.slice.Slice.len` + returns the underlying list length. + - `massert_spec` — `Aeneas.Std.massert b` succeeds (with `()`) when `b`. + - `core_models_num_U32_from_le_bytes_spec`, + `core_models_num_U32_to_le_bytes_spec` — byte ↔ u32 LE. + - `core_models_num_U64_from_le_bytes_spec`, + `core_models_num_U64_to_le_bytes_spec` — byte ↔ u64 LE. + - `core_models_Slice_Insts_index_RangeUsize_spec` — slice subindexing + over `Range`. + - `core_models_Slice_Insts_index_mut_RangeUsize_spec` — mutable slice + subindexing over `Range`. + - `core_models_result_Result_unwrap_spec` — `result.Result.unwrap` on + `.Ok v` yields `v`. + - `core_models_slice_Slice_copy_from_slice_spec` — write-into-slice; + the impl model returns the source slice outright when lengths match. + - `core_models_array_try_from_slice_spec` — `Slice T → Array T N` + coercion; total when `s.val.length = N.val`. + - `core_models_try_from_unwrap_spec` — fused + `try_from + Result.unwrap` Triple. +-/ +import LibcruxIotMlKem.Extraction.Funs + +open CoreModels Aeneas +open Aeneas.Std hiding namespace core alloc +open Result Std.Do + +namespace libcrux_iot_ml_kem.Util.SliceSpecs +set_option mvcgen.warning false +set_option linter.unusedVariables false + +/-! ## Local helper: Triple → Result-equation converter + +When each `call_mut`'s purity is stated as a Triple (natural for +`hax_mvcgen`-driven proofs), the Result equation needed by the +`try_from`/`createi`/`from_fn` pure-closure pattern follows +directly. This file uses it once (in `core_models_try_from_unwrap_spec`); +`Util/LoopSpecs.lean` and `Util/CreateI.lean` use the same helper. -/ + +theorem result_eq_of_triple {α : Type} {x : Result α} {v : α} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ r = v ⌝ ⦄) : x = .ok v := by + match hx : x, h with + | .ok v', h => + have hv' : v' = v := by + simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply] at h + exact h + rw [hv'] + | .fail e, h => exact absurd h (by simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div, h => exact absurd h (by simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +/-! ## Aeneas Std byte/slice `@[spec]` lemmas. -/ + +/-! ### `core.slice.Slice.len` -/ + +/-- The hax `core.slice.Slice.len` is a thin `pure`-wrapper around + `Aeneas.Std.Slice.len`. Always succeeds with the underlying list length + (as a `Usize`). -/ +@[spec] +theorem core_models_slice_Slice_len_spec {T : Type} (s : Slice T) : + ⦃ ⌜ True ⌝ ⦄ + core.slice.Slice.len s + ⦃ ⇓ r => ⌜ r.val = s.val.length ⌝ ⦄ := by + simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply, core.slice.Slice.len] + +/-! ### `Aeneas.Std.massert` -/ + +/-- `massert b` succeeds with `()` iff `b` holds. -/ +@[spec] +theorem massert_spec (b : Prop) [Decidable b] (h : b) : + ⦃ ⌜ True ⌝ ⦄ + massert b + ⦃ ⇓ r => ⌜ r = () ⌝ ⦄ := by + simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply, h] + +/-! ### `core.num.U32.from_le_bytes` / `U32.to_le_bytes` -/ + +/-- The four-byte LE-load `U32.from_le_bytes` always succeeds with + `core.num.U32.from_le_bytes` applied to the input array. -/ +@[spec] +theorem core_models_num_U32_from_le_bytes_spec (bytes : Std.Array Std.U8 4#usize) : + ⦃ ⌜ True ⌝ ⦄ + core.num.U32.from_le_bytes bytes + ⦃ ⇓ r => ⌜ r = Std.core.num.U32.from_le_bytes bytes ⌝ ⦄ := by + simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply, + core.num.U32.from_le_bytes, rust_primitives.arithmetic.from_le_bytes_u32] + +/-- The four-byte LE-store `U32.to_le_bytes` always succeeds with + `core.num.U32.to_le_bytes` applied to the input integer. -/ +@[spec] +theorem core_models_num_U32_to_le_bytes_spec (x : Std.U32) : + ⦃ ⌜ True ⌝ ⦄ + core.num.U32.to_le_bytes x + ⦃ ⇓ r => ⌜ r = Std.core.num.U32.to_le_bytes x ⌝ ⦄ := by + simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply, + core.num.U32.to_le_bytes, rust_primitives.arithmetic.to_le_bytes_u32] + +/-! ### `core.num.U64.from_le_bytes` / `U64.to_le_bytes` -/ + +/-- The eight-byte LE-load `U64.from_le_bytes` always succeeds with + `core.num.U64.from_le_bytes` applied to the input array. -/ +@[spec] +theorem core_models_num_U64_from_le_bytes_spec (bytes : Std.Array Std.U8 8#usize) : + ⦃ ⌜ True ⌝ ⦄ + core.num.U64.from_le_bytes bytes + ⦃ ⇓ r => ⌜ r = Std.core.num.U64.from_le_bytes bytes ⌝ ⦄ := by + simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply, + core.num.U64.from_le_bytes, rust_primitives.arithmetic.from_le_bytes_u64] + +/-- The eight-byte LE-store `U64.to_le_bytes` always succeeds with + `core.num.U64.to_le_bytes` applied to the input integer. -/ +@[spec] +theorem core_models_num_U64_to_le_bytes_spec (x : Std.U64) : + ⦃ ⌜ True ⌝ ⦄ + core.num.U64.to_le_bytes x + ⦃ ⇓ r => ⌜ r = Std.core.num.U64.to_le_bytes x ⌝ ⦄ := by + simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply, + core.num.U64.to_le_bytes, rust_primitives.arithmetic.to_le_bytes_u64] + +/-! ### `core.Slice.Insts.CoreOpsIndexIndex.index` over `Range Usize` -/ + +/-- Slice subindexing over a `Range` succeeds whenever the range is + in bounds, returning the sub-`Slice` whose `val` is the contiguous + slice `s.val[start..end]`. -/ +@[spec] +theorem core_models_Slice_Insts_index_RangeUsize_spec + {T : Type} (s : Slice T) (r : core.ops.range.Range Std.Usize) + (h0 : r.start.val ≤ r.end.val) (h1 : r.end.val ≤ s.val.length) : + ⦃ ⌜ True ⌝ ⦄ + core.Slice.Insts.CoreOpsIndexIndex.index + (core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice T) s r + ⦃ ⇓ r' => ⌜ r'.val = s.val.slice r.start.val r.end.val ∧ + r'.val.length = r.end.val - r.start.val ⌝ ⦄ := by + unfold core.Slice.Insts.CoreOpsIndexIndex.index + core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice + Aeneas.Std.core.slice.index.Slice.index + Aeneas.Std.core.slice.index.SliceIndexRangeUsizeSlice.index + have h0' : (⟨r.start, r.end⟩ : core.ops.range.Range Std.Usize).start + ≤ (⟨r.start, r.end⟩ : core.ops.range.Range Std.Usize).end := by + simpa [UScalar.le_equiv] using h0 + have h1' : (⟨r.start, r.end⟩ : core.ops.range.Range Std.Usize).end.val ≤ (Slice.length s) := by + simpa [Slice.length] using h1 + simp only [Triple, WP.wp] + simp [h0', h1', Slice.length] + simp [List.slice] + simp [PredTrans.apply] + omega + +/-! ### `core.result.Result.unwrap` + +The hax `Result.unwrap` on the `core_models` `result.Result` enum panics on +`Err` and returns the inner `T` on `Ok`. We give a Triple-style spec under +the precondition `r = .Ok v`. -/ + +/-- `Result.unwrap` of a `.Ok`-valued `r` returns the inner value. + + We state both the precondition (`∃ v, r = .Ok v`) and the post + (`r = .Ok r'`), leaving `v` quantified inside `mvcgen`'s assertion + bag. This avoids the mvcgen unification quirk where the explicit + `v` argument gets eagerly bound to the first matching local of the + right type. -/ +@[spec] +theorem core_models_result_Result_unwrap_spec + {T E : Type} (dbg : core.fmt.Debug E) + (r : core.result.Result T E) + (h : ∃ v, r = .Ok v) : + ⦃ ⌜ True ⌝ ⦄ + core.result.Result.unwrap dbg r + ⦃ ⇓ r' => ⌜ r = .Ok r' ⌝ ⦄ := by + obtain ⟨v, hv⟩ := h + unfold core.result.Result.unwrap + subst hv + simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply] + + +/-! ### `core.Slice.Insts.CoreOpsIndexIndexMut.index_mut` over `Range Usize` + +Used by Aeneas-extracted loops that obtain a mutable sub-slice and a +write-back closure (e.g. SHA-3's `state.store_block_2u32_loop.body`, and +the ML-KEM byte-encode loops). -/ + +/-- Mutable slice subindexing over a `Range` returns both the + sub-slice (same `val` as the non-mut `index`) and a write-back + closure that overwrites `s.val[r.start.val..]` with the argument's + `val`. -/ +@[spec] +theorem core_models_Slice_Insts_index_mut_RangeUsize_spec + {T : Type} (s : Slice T) (r : core.ops.range.Range Std.Usize) + (h0 : r.start.val ≤ r.end.val) (h1 : r.end.val ≤ s.val.length) : + ⦃ ⌜ True ⌝ ⦄ + core.Slice.Insts.CoreOpsIndexIndexMut.index_mut + (core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice T) s r + ⦃ ⇓ p => ⌜ p.1.val = s.val.slice r.start.val r.end.val ∧ + p.1.val.length = r.end.val - r.start.val ∧ + ∀ s', (p.2 s').val = s.val.setSlice! r.start.val s'.val ⌝ ⦄ := by + unfold core.Slice.Insts.CoreOpsIndexIndexMut.index_mut + core.ops.range.RangeUsize.Insts.CoreSliceIndexSliceIndexSliceSlice + Aeneas.Std.core.slice.index.Slice.index_mut + Aeneas.Std.core.slice.index.SliceIndexRangeUsizeSlice.index_mut + have h0' : (⟨r.start, r.end⟩ : core.ops.range.Range Std.Usize).start + ≤ (⟨r.start, r.end⟩ : core.ops.range.Range Std.Usize).end := by + simpa [UScalar.le_equiv] using h0 + have h1' : (⟨r.start, r.end⟩ : core.ops.range.Range Std.Usize).end.val ≤ (Slice.length s) := by + simpa [Slice.length] using h1 + simp only [Triple, WP.wp] + simp [h0', h1', Slice.length] + simp [List.slice] + simp [PredTrans.apply] + omega + +/-! ### `core.slice.Slice.copy_from_slice` -/ + +/-- `copy_from_slice dst src` succeeds with the source slice `src` + whenever both slices have the same length (the impl model returns + `src` outright when lengths match). -/ +@[spec] +theorem core_models_slice_Slice_copy_from_slice_spec + {T : Type} (cpy : core.marker.Copy T) (dst src : Slice T) + (h : dst.val.length = src.val.length) : + ⦃ ⌜ True ⌝ ⦄ + core.slice.Slice.copy_from_slice cpy dst src + ⦃ ⇓ r => ⌜ r = src ⌝ ⦄ := by + unfold core.slice.Slice.copy_from_slice + have h' : dst.len = src.len := by + apply Std.UScalar.eq_of_val_eq + simp [h] + simp [Triple, WP.wp, h', PostCond.noThrow, PredTrans.apply] + +/-! ### `core.Array.Insts.CoreConvertTryFromShared0SliceTryFromSliceError.try_from` + +The body invokes `rust_primitives.slice.array_from_fn` on the `try_from` +closure (whose state is just the source `Slice T`). The proof has three +parts: + +1. Closure step: `call_mut s i = .ok (s.val[i.val]!, s)` for `i.val < + s.length` — the closure reads the slice and preserves its state. +2. `foldlM` invariant: induction on `k` shows that folding the closure + over `List.range' 0 k` (starting from `([], s)`) returns + `(s.val.take k, s)` when `k ≤ s.length`. +3. Final assembly: `array_from_fn N closure s = .ok (Array.make N s.val)` + when `s.length = N.val`, hence `try_from N inst s = .ok (.Ok a)` with + `a.val = s.val`. -/ + +/-- Numeric helper: `(⟨BitVec.ofNat _ n⟩ : Usize).val = n` when + `n ≤ Std.Usize.max` (equivalently, `n < 2^Usize.numBits`). + We state the bit-vector size as `UScalarTy.Usize.numBits` since this + is the form `Usize.val` unfolds to. -/ +private theorem bv_ofNat_usize_val_eq (n : Nat) (hn : n ≤ Std.Usize.max) : + (⟨BitVec.ofNat Std.UScalarTy.Usize.numBits n⟩ : Std.Usize).val = n := by + show (BitVec.ofNat _ n).toNat = n + simp only [BitVec.toNat_ofNat] + apply Nat.mod_eq_of_lt + -- `Std.Usize.max = 2^Std.Usize.numBits - 1`, hence `n ≤ max` means `n < 2^numBits`. + have hmax : Std.Usize.max + 1 = 2 ^ Std.UScalarTy.Usize.numBits := by + simp [Std.Usize.max, Std.Usize.numBits] + omega + +/-- Closure step lemma. The `try_from` closure's state is the source + `Slice T`; `call_mut` reads the `i`-th element and preserves state. -/ +private theorem try_from_closure_call_mut_eq + {T : Type} [Inhabited T] {N : Std.Usize} (cpy : core.marker.Copy T) + (s : Slice T) (i : Std.Usize) (h : i.val < s.val.length) : + core.convert.TryFromArrayShared0SliceTryFromSliceError.try_from.closure.Insts.CoreOpsFunctionFnMutTupleUsizeT.call_mut + (T := T) (N := N) cpy s i = + .ok (s.val[i.val]!, s) := by + -- Reduces to `do let t ← slice_index s i; ok (t, s)`. + unfold core.convert.TryFromArrayShared0SliceTryFromSliceError.try_from.closure.Insts.CoreOpsFunctionFnMutTupleUsizeT.call_mut + unfold rust_primitives.slice.slice_index Std.Slice.index_usize + -- Now `s[i]?` matches; for `i.val < s.length`, `s[i]? = some s.val[i.val]!`. + have hsome : s[i]? = some s.val[i.val]! := by + simp only [Std.Slice.getElem?_Usize_eq] + rw [List.getElem?_eq_getElem h, List.getElem!_eq_getElem?_getD, + List.getElem?_eq_getElem h] + rfl + rw [hsome] + rfl + +/-- The closure-fold accumulator at step `k` is `s.val.take k`. We prove + a slightly stronger invariant: starting from any accumulator `acc` + with the closure state `s`, folding over `List.range' acc.length k` + yields `(acc ++ s.val.slice acc.length (acc.length + k), s)` when + `acc.length + k ≤ s.length` and acc lines up with the slice prefix. -/ +private theorem foldlM_try_from_closure_invariant + {T : Type} [Inhabited T] {N : Std.Usize} (cpy : core.marker.Copy T) + (s : Slice T) + (_hN : s.val.length ≤ Std.Usize.max) : + ∀ (k start : Nat) (acc : List T), + acc = s.val.take start → + start + k ≤ s.val.length → + start + k ≤ Std.Usize.max → + (List.range' start k).foldlM + (fun (p : List T × Slice T) (i : Nat) => do + let (v, f') ← + core.convert.TryFromArrayShared0SliceTryFromSliceError.try_from.closure.Insts.CoreOpsFunctionFnMutTupleUsizeT.call_mut + (T := T) (N := N) cpy p.2 ⟨BitVec.ofNat _ i⟩ + ok (p.1 ++ [v], f')) + (acc, s) + = .ok (s.val.take (start + k), s) := by + intro k + induction k with + | zero => + intro start acc hacc hk1 hk2 + show List.foldlM _ (acc, s) (List.range' start 0) = _ + rw [show List.range' start 0 = [] from rfl] + rw [List.foldlM_nil] + show Result.ok (acc, s) = Result.ok (s.val.take (start + 0), s) + rw [hacc, Nat.add_zero] + | succ k ih => + intro start acc hacc hk1 hk2 + -- `List.range' start (k+1) = start :: List.range' (start+1) k` + rw [show List.range' start (k + 1) = start :: List.range' (start + 1) k from rfl] + simp only [List.foldlM_cons] + -- The step at `start` calls `call_mut s ⟨BitVec.ofNat _ start⟩`. + have hstart_lt : start < s.val.length := by omega + have hstart_max : start ≤ Std.Usize.max := by omega + have hval : (⟨BitVec.ofNat Std.UScalarTy.Usize.numBits start⟩ : Std.Usize).val = start := + bv_ofNat_usize_val_eq start hstart_max + have hcall := try_from_closure_call_mut_eq (T := T) (N := N) cpy s + ⟨BitVec.ofNat _ start⟩ (by rw [hval]; exact hstart_lt) + -- Rewrite both the closure-call output's `.val` and the `i` arg uniformly. + rw [hval] at hcall + rw [hcall] + simp only [bind_tc_ok] + -- New accumulator is `acc ++ [s.val[start]!] = s.val.take (start + 1)`. + have hacc' : acc ++ [s.val[start]!] = s.val.take (start + 1) := by + rw [hacc] + have : start < s.val.length := hstart_lt + rw [List.take_add_one] + simp [List.getElem?_eq_getElem this, List.getElem!_eq_getElem?_getD] + -- Now apply IH at `start := start + 1`. Note `(start + 1) + k = start + (k + 1)`. + have ih' := ih (start + 1) (acc ++ [s.val[start]!]) hacc' (by omega) (by omega) + have h_assoc : (start + 1) + k = start + (k + 1) := by omega + rw [h_assoc] at ih' + exact ih' + +/-- `array_from_fn N (try_from closure) s = .ok (Array.make N s.val)` + when `s.length = N.val`. -/ +private theorem array_from_fn_try_from_eq_ok + {T : Type} [Inhabited T] {N : Std.Usize} (cpy : core.marker.Copy T) + (s : Slice T) (hlen : s.val.length = N.val) : + rust_primitives.slice.array_from_fn N + (core.convert.TryFromArrayShared0SliceTryFromSliceError.try_from.closure.Insts.CoreOpsFunctionFnMutTupleUsizeT + (T := T) (N := N) cpy) s + = .ok (Std.Array.make N s.val (by simp [hlen])) := by + -- Foldl invariant at start=0, k=N.val, acc=[]. + have hN_max : s.val.length ≤ Std.Usize.max := by + have := s.property; exact this + have hN_max' : N.val ≤ Std.Usize.max := by + rw [← hlen]; exact hN_max + have h_fold := + foldlM_try_from_closure_invariant (T := T) (N := N) cpy s hN_max + N.val 0 [] (by simp) (by omega) (by omega) + -- Normalize `0 + N.val = N.val` and reduce `take N.val s.val = s.val`. + simp only [Nat.zero_add] at h_fold + have h_take : s.val.take N.val = s.val := + List.take_of_length_le (by omega) + rw [h_take] at h_fold + -- Match `range N.val` with `range' 0 N.val` (`range` is defined as `range' 0 _`). + have hrange : (List.range N.val) = List.range' 0 N.val := List.range_eq_range' + -- The `array_from_fn` definition is a `match` on the foldlM result. + rw [← hrange] at h_fold + unfold rust_primitives.slice.array_from_fn + -- Now transport the foldlM equation through the `split`. + split + · rename_i e heq + rw [h_fold] at heq; exact absurd heq (by simp) + · rename_i heq + rw [h_fold] at heq; exact absurd heq (by simp) + · rename_i result heq + rw [h_fold] at heq + have hres : result = (s.val, s) := (Result.ok.inj heq).symm + subst hres + rfl + +/-- The main Triple: `try_from N cpy s` succeeds with `Ok (Array.make N s.val _)`, + whenever `s.val.length = N.val`. -/ +@[spec] +theorem core_models_array_try_from_slice_spec + {T : Type} [Inhabited T] {N : Std.Usize} (cpy : core.marker.Copy T) + (s : Slice T) (hlen : s.val.length = N.val) : + ⦃ ⌜ True ⌝ ⦄ + core.Array.Insts.CoreConvertTryFromShared0SliceTryFromSliceError.try_from + N cpy s + ⦃ ⇓ r => ⌜ r = core.result.Result.Ok + (Std.Array.make N s.val (by simp [hlen])) ⌝ ⦄ := by + -- Unfold try_from and reduce the `do` chain step-by-step. + unfold core.Array.Insts.CoreConvertTryFromShared0SliceTryFromSliceError.try_from + -- `core.slice.Slice.len x` is `pure (Slice.len x)`, returns `.ok (Slice.len s)`. + unfold core.slice.Slice.len + -- The if-decision: `Slice.len s = N` reduces to `s.val.length = N.val`. + have hi_eq : (Std.Slice.len s) = N := by + apply Std.UScalar.eq_of_val_eq + simp [hlen] + -- Reduce the array_from_fn call to .ok. + have h_afn := array_from_fn_try_from_eq_ok (T := T) (N := N) cpy s hlen + simp only [Triple, WP.wp, pure, Pure.pure, bind_tc_ok, hi_eq, if_true, h_afn] + intro _ + trivial + +/-- Fused `try_from + Result.unwrap` Triple. The two-step pattern + `let r ← try_from N cpy s; let a ← Result.unwrap dbg r` is the + canonical Aeneas idiom for slice → array coercion; we provide a + direct equation that mvcgen can chain without intermediate metavars. -/ +theorem core_models_try_from_unwrap_spec + {T : Type} [Inhabited T] {N : Std.Usize} (cpy : core.marker.Copy T) + (dbg : core.fmt.Debug core.array.TryFromSliceError) + (s : Slice T) (hlen : s.val.length = N.val) : + ⦃ ⌜ True ⌝ ⦄ + (do + let r ← core.Array.Insts.CoreConvertTryFromShared0SliceTryFromSliceError.try_from + N cpy s + core.result.Result.unwrap dbg r) + ⦃ ⇓ a => ⌜ a = Std.Array.make N s.val (by simp [hlen]) ⌝ ⦄ := by + -- Establish `try_from ... = .ok (.Ok (Array.make N s.val _))` outright. + have h_try := core_models_array_try_from_slice_spec (T := T) (N := N) cpy s hlen + -- Then unfold Result.unwrap and reduce. + unfold core.result.Result.unwrap + -- Reduce `try_from` to its known .ok form via the local Triple → eq helper. + have h_eq : (core.Array.Insts.CoreConvertTryFromShared0SliceTryFromSliceError.try_from + N cpy s) + = .ok (.Ok (Std.Array.make N s.val (by simp [hlen]))) := + result_eq_of_triple h_try + rw [h_eq] + simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply] + +end libcrux_iot_ml_kem.Util.SliceSpecs \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Arithmetic/BvMasks.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Arithmetic/BvMasks.lean new file mode 100644 index 00000000..91fd26ff --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Arithmetic/BvMasks.lean @@ -0,0 +1,22 @@ +/- + # `Vector/Portable/Arithmetic/BvMasks.lean` — bitvector-mask identities + used by `arithmetic.rs` Triples. + + Currently exports a single identity needed by + `get_n_least_significant_bits_spec`. +-/ +import Mathlib.Tactic.IntervalCases + +namespace libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks +/-- The 32-bit BV mask `(1 <<< n) - 1` has `.toNat = 2^n - 1` for any + `n ≤ 16` (and in fact for `n < 32`, but 16 is what L0.1 needs). + + Proof: enumerate the 17 cases `n ∈ {0, …, 16}` and discharge each + by `decide` on the closed BV expression. Mediates the + `interval_cases` use that was previously inline in + `Equivalence/L0_FieldArith.lean`. -/ +theorem mask_pow2_minus_one_toNat (n : Nat) (h : n ≤ 16) : + ((1#32 <<< n) - 1#32).toNat = 2 ^ n - 1 := by + interval_cases n <;> decide + +end libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Arithmetic/Element.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Arithmetic/Element.lean new file mode 100644 index 00000000..2f9720e9 --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Arithmetic/Element.lean @@ -0,0 +1,3180 @@ +/- + # `Equivalence/L1_VectorElementOps.lean` — Layer 1 elementwise PortableVector Triples. + + Layer-1 Triples for the per-element-of-PortableVector ops. Each L1.x + Triple proves that running the loop produces an output vector where + every element satisfies the corresponding L0.x per-element post: + + - **L1.3 `barrett_reduce_spec`** — instantiates `elementwise_unary_spec` + with `barrett_reduce_element_spec` from L0.2. + - **L1.4 `montgomery_multiply_by_constant_spec`** — instantiates + `elementwise_unary_spec` with `montgomery_multiply_fe_by_fer_spec` + (L0.4), post-processed to expose `(r * 2^16) % 3329 = (x * c) % 3329`. + - **L1.5 `cond_subtract_3329_spec`** — uses a conditional-body + variant `elementwise_cond_unary_spec` (the else-branch returns + the input vec unchanged, so the canonical `unary_loop_body` + shape doesn't apply). + - **L1.6 `negate_spec`** — instantiates `elementwise_unary_spec` + with `core.num.I16.wrapping_neg` (per-element `.bv = -x.bv`). + + L1.1, L1.2, L1.7-L1.10 will follow the same pattern (instantiate the + per-element L0.x Triple via `elementwise_unary_spec` / its + conditional / binary variants). +-/ +import LibcruxIotMlKem.Vector.Portable.Arithmetic.LoopHelper +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement + +set_option mvcgen.warning false +set_option linter.unusedVariables false +set_option linter.unusedSectionVars false + +namespace libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element +open libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper + +/-! ## L1.1 — `add_spec` + + The Vector.Portable.Arithmetic.add impl is a 16-iter loop that + calls `core.num.I16.wrapping_add lhs[i] rhs[i]` and writes + the result back to `lhs[i]`. Under the per-element no-overflow + bound `|lhs.val + rhs.val| ≤ 2^15 - 1`, the wrap is the identity + and `(wrapping_add lhs rhs).val = lhs.val + rhs.val`. -/ + +/-- Per-element predicate (guarded form): given the no-overflow bound + on `x + y`, the output value equals the sum and is in range. -/ +private def add_per_elem_P (x y r : Std.I16) : Prop := + ((x.val + y.val : Int)).natAbs ≤ 2 ^ 15 - 1 → + r.val = x.val + y.val ∧ r.val.natAbs ≤ 2 ^ 15 - 1 + +/-- Per-element Triple: `core.num.I16.wrapping_add x y` reduces + to `.ok (Std.I16.wrapping_add x y)`, whose `.val` is the bmod of + `x.val + y.val` mod `2^16`. Under the no-overflow bound, + `Int.bmod` is the identity. -/ +private theorem add_per_elem_spec (x y : Std.I16) : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.num.I16.wrapping_add x y + ⦃ ⇓ r => ⌜ add_per_elem_P x y r ⌝ ⦄ := by + have h_ok : + CoreModels.core.num.I16.wrapping_add x y + = .ok (Aeneas.Std.I16.wrapping_add x y) := by + unfold CoreModels.core.num.I16.wrapping_add + unfold rust_primitives.arithmetic.wrapping_add_i16 + rfl + rw [h_ok] + simp only [Std.Do.Triple, WP.wp] + intro _ + show add_per_elem_P x y (Aeneas.Std.I16.wrapping_add x y) + unfold add_per_elem_P + intro hb + have h_val := Aeneas.Std.I16.wrapping_add_val_eq x y + have h_lb : -(2 ^ 15 : Int) ≤ x.val + y.val := by + have h_abs : ((x.val + y.val : Int)).natAbs ≤ 2 ^ 15 - 1 := hb + omega + have h_ub : x.val + y.val < (2 ^ 15 : Int) := by + have h_abs : ((x.val + y.val : Int)).natAbs ≤ 2 ^ 15 - 1 := hb + omega + have h_bmod : Int.bmod (x.val + y.val) (2 ^ 16) = x.val + y.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + refine ⟨?_, ?_⟩ + · rw [h_val, h_bmod] + · rw [h_val, h_bmod]; exact hb + +@[spec] +theorem add_spec + (lhs rhs : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hpre : ∀ i : Nat, i < 16 → + ((lhs.elements.val[i]!).val + (rhs.elements.val[i]!).val : Int).natAbs ≤ 2 ^ 15 - 1) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.add lhs rhs + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + (r.elements.val[i]!).val + = (lhs.elements.val[i]!).val + (rhs.elements.val[i]!).val + ∧ (r.elements.val[i]!).val.natAbs ≤ 2 ^ 15 - 1 ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.add + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.add_loop + have h_field : libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + = (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR; rfl + rw [h_field] + -- Bridge `add_loop.body rhs` to + -- `binary_loop_body CoreModels.core.num.I16.wrapping_add rhs`. + have h_body_eq : + (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + libcrux_iot_ml_kem.vector.portable.arithmetic.add_loop.body rhs p.1 p.2) + = (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + binary_loop_body CoreModels.core.num.I16.wrapping_add rhs p.1 p.2) := by + funext p + rcases p with ⟨iter1, vec1⟩ + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.add_loop.body + unfold binary_loop_body + rfl + rw [h_body_eq] + apply Std.Do.Triple.of_entails_right _ + (elementwise_binary_spec + CoreModels.core.num.I16.wrapping_add + add_per_elem_P + add_per_elem_spec + lhs rhs) + rw [PostCond.entails_noThrow] + intro r hh j hj + obtain ⟨rj, _h_eq, h_acc, h_P⟩ := hh j hj + rw [h_acc] + exact h_P (hpre j hj) + +/-! ## L1.2 — `sub_spec` + + The Vector.Portable.Arithmetic.sub impl is a 16-iter loop that + calls `core.num.I16.wrapping_sub lhs[i] rhs[i]`. Same + structure as `add_spec` but with `-` instead of `+`. -/ + +/-- Per-element predicate (guarded form): given the no-overflow bound + on `x - y`, the output value equals the difference and is in range. -/ +private def sub_per_elem_P (x y r : Std.I16) : Prop := + ((x.val - y.val : Int)).natAbs ≤ 2 ^ 15 - 1 → + r.val = x.val - y.val ∧ r.val.natAbs ≤ 2 ^ 15 - 1 + +/-- Per-element Triple: `core.num.I16.wrapping_sub x y` reduces + to `.ok (Std.I16.wrapping_sub x y)`, whose `.val` is the bmod of + `x.val - y.val` mod `2^16`. Under the no-overflow bound, + `Int.bmod` is the identity. -/ +private theorem sub_per_elem_spec (x y : Std.I16) : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.num.I16.wrapping_sub x y + ⦃ ⇓ r => ⌜ sub_per_elem_P x y r ⌝ ⦄ := by + have h_ok : + CoreModels.core.num.I16.wrapping_sub x y + = .ok (Aeneas.Std.I16.wrapping_sub x y) := by + unfold CoreModels.core.num.I16.wrapping_sub + unfold rust_primitives.arithmetic.wrapping_sub_i16 + rfl + rw [h_ok] + simp only [Std.Do.Triple, WP.wp] + intro _ + show sub_per_elem_P x y (Aeneas.Std.I16.wrapping_sub x y) + unfold sub_per_elem_P + intro hb + have h_val := Aeneas.Std.I16.wrapping_sub_val_eq x y + have h_lb : -(2 ^ 15 : Int) ≤ x.val - y.val := by + have h_abs : ((x.val - y.val : Int)).natAbs ≤ 2 ^ 15 - 1 := hb + omega + have h_ub : x.val - y.val < (2 ^ 15 : Int) := by + have h_abs : ((x.val - y.val : Int)).natAbs ≤ 2 ^ 15 - 1 := hb + omega + have h_bmod : Int.bmod (x.val - y.val) (2 ^ 16) = x.val - y.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + refine ⟨?_, ?_⟩ + · rw [h_val, h_bmod] + · rw [h_val, h_bmod]; exact hb + +@[spec] +theorem sub_spec + (lhs rhs : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hpre : ∀ i : Nat, i < 16 → + ((lhs.elements.val[i]!).val - (rhs.elements.val[i]!).val : Int).natAbs ≤ 2 ^ 15 - 1) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.sub lhs rhs + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + (r.elements.val[i]!).val + = (lhs.elements.val[i]!).val - (rhs.elements.val[i]!).val + ∧ (r.elements.val[i]!).val.natAbs ≤ 2 ^ 15 - 1 ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.sub + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.sub_loop + have h_field : libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + = (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR; rfl + rw [h_field] + have h_body_eq : + (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + libcrux_iot_ml_kem.vector.portable.arithmetic.sub_loop.body rhs p.1 p.2) + = (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + binary_loop_body CoreModels.core.num.I16.wrapping_sub rhs p.1 p.2) := by + funext p + rcases p with ⟨iter1, vec1⟩ + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.sub_loop.body + unfold binary_loop_body + rfl + rw [h_body_eq] + apply Std.Do.Triple.of_entails_right _ + (elementwise_binary_spec + CoreModels.core.num.I16.wrapping_sub + sub_per_elem_P + sub_per_elem_spec + lhs rhs) + rw [PostCond.entails_noThrow] + intro r hh j hj + obtain ⟨rj, _h_eq, h_acc, h_P⟩ := hh j hj + rw [h_acc] + exact h_P (hpre j hj) + +/-! ## L1.3 — `barrett_reduce_spec` + + Implements the upstream `Vector.Portable.Arithmetic.barrett_reduce` + correctness theorem. The impl is a 16-iteration + `for i in 0..16` loop that calls L0.2 `barrett_reduce_element` on each + element. The post asserts each output element is congruent to its + input mod 3329 and bounded in absolute value by 3328. + + F* pre: `∀ i < 16, is_i16b 32767 vec.elements[i]` + F* post: `∀ i < 16, is_i16b 3328 r.elements[i] + ∧ v r.elements[i] % 3329 = v vec.elements[i] % 3329` -/ + +/-- Per-element predicate threading the L0.2 bound precondition into + an implication. -/ +private def barrett_per_elem_P (x y : Std.I16) : Prop := + x.val.natAbs ≤ 32767 → + libcrux_iot_ml_kem.Spec.ModularArith.modq_eq y.val x.val 3329 + ∧ y.val.natAbs ≤ 3328 + +/-- Per-element Triple: an unconditional Triple over + `barrett_reduce_element` with the guarded post. The function is + total (see `barrett_reduce_element_eq_ok`); the in-bounds case + invokes the L0.2 `barrett_reduce_element_spec`. -/ +private theorem barrett_per_elem_spec (x : Std.I16) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce_element x + ⦃ ⇓ r => ⌜ barrett_per_elem_P x r ⌝ ⦄ := by + -- The function is total: it returns `.ok (barrett_reduce_impl_value x)` + -- unconditionally (`barrett_reduce_element_eq_ok`). Reduce the Triple + -- to a pure goal via that .ok equation. + have h_ok := barrett_reduce_element_eq_ok x + rw [show libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce_element x + = .ok (barrett_reduce_impl_value x) from h_ok] + -- Triple shape: ⦃True⦄ .ok v ⦃⇓ r => ⌜barrett_per_elem_P x r⌝⦄. + simp only [Std.Do.Triple, WP.wp] + intro _ + -- Goal: barrett_per_elem_P x (barrett_reduce_impl_value x) + show barrett_per_elem_P x (barrett_reduce_impl_value x) + unfold barrett_per_elem_P + intro hb + -- Now use the L0.2 spec. + have hT := barrett_reduce_element_spec x hb + -- hT : ⦃True⦄ barrett_reduce_element x ⦃⇓ r => ⌜...⌝⦄ + -- Reduce to the .ok form via h_ok. + rw [h_ok] at hT + -- hT : ⦃True⦄ .ok (barrett_reduce_impl_value x) ⦃⇓ r => ⌜...⌝⦄ + -- Extract the post-condition. + simp only [Std.Do.Triple, WP.wp] at hT + exact hT trivial + +@[spec] +theorem barrett_reduce_spec + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_bounds : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 32767) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce vec + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + libcrux_iot_ml_kem.Spec.ModularArith.modq_eq + (r.elements.val[i]!).val (vec.elements.val[i]!).val 3329 + ∧ (r.elements.val[i]!).val.natAbs ≤ 3328 ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce_loop + have h_field : libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + = (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR; rfl + rw [h_field] + -- Replace `barrett_reduce_loop.body` with `unary_loop_body + -- barrett_reduce_element`. Both have identical definitions modulo + -- the per_elem variable, so funext is trivial. + have h_body_eq : + (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce_loop.body p.1 p.2) + = (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + unary_loop_body + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce_element + p.1 p.2) := by + funext p + rcases p with ⟨iter1, vec1⟩ + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce_loop.body + unfold unary_loop_body + rfl + rw [h_body_eq] + apply Std.Do.Triple.of_entails_right _ + (elementwise_unary_spec + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce_element + barrett_per_elem_P + barrett_per_elem_spec + vec) + rw [PostCond.entails_noThrow] + intro r hh j hj + -- hh : ∀ i < 16, ∃ ri, barrett_reduce_element (vec[i]) = .ok ri + -- ∧ r.elements[i]! = ri ∧ barrett_per_elem_P (vec[i]) ri + obtain ⟨rj, _h_eq, h_acc, h_P⟩ := hh j hj + rw [h_acc] + exact h_P (h_bounds j hj) + +/-! ## L1.4 — `montgomery_multiply_by_constant_spec` + + The Vector.Portable.Arithmetic.montgomery_multiply_by_constant impl + is a 16-iteration loop that calls L0.4's `montgomery_multiply_fe_by_fer` + on each element (the constant `c` is captured by the body lambda). + + Conversion from L0.4's `modq_eq r (fe*c*169) 3329` to L1.4's + `(r*2^16) % 3329 = (fe*c) % 3329`: multiply both sides by 2^16, + and use `169 * 65536 ≡ 1 (mod 3329)` (which is the Montgomery + inversion identity at this q). -/ + +/-- Per-element predicate (unconditional form): the L0.4 bound + + the L1.4-shaped Montgomery congruence. -/ +private def montgomery_per_elem_P (c x y : Std.I16) : Prop := + y.val.natAbs ≤ 3328 + ∧ (y.val * (2 ^ 16 : Int)) % 3329 = (x.val * c.val) % 3329 + +/-- Per-element Triple: invokes the L0.4 spec under the captured `hc`, + then weakens the post via the Montgomery-inversion identity + `169 * 2^16 ≡ 1 (mod 3329)`. -/ +private theorem montgomery_per_elem_spec + (c : Std.I16) (hc : c.val.natAbs ≤ 1664) (x : Std.I16) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_fe_by_fer x c + ⦃ ⇓ r => ⌜ montgomery_per_elem_P c x r ⌝ ⦄ := by + -- Invoke L0.4 and weaken its post to the L1.4-shaped form. + apply Std.Do.Triple.of_entails_right _ (montgomery_multiply_fe_by_fer_spec x c hc) + rw [PostCond.entails_noThrow] + intro r hh + have h_inner : r.val.natAbs ≤ 3328 + ∧ libcrux_iot_ml_kem.Spec.ModularArith.modq_eq r.val (x.val * c.val * 169) 3329 := by + simpa [Std.Do.SPred.down_pure] using hh + obtain ⟨h_bd, h_mod⟩ := h_inner + show montgomery_per_elem_P c x r + unfold montgomery_per_elem_P + refine ⟨h_bd, ?_⟩ + -- From `modq_eq r (x * c * 169) 3329`, derive `(r * 2^16) % 3329 = (x * c) % 3329`. + unfold libcrux_iot_ml_kem.Spec.ModularArith.modq_eq at h_mod + -- h_mod : (r.val - x.val * c.val * 169) % 3329 = 0 + -- Goal: (r.val * 2^16) % 3329 = (x.val * c.val) % 3329 + have h_dvd : (3329 : Int) ∣ (r.val - x.val * c.val * 169) := + Int.dvd_of_emod_eq_zero h_mod + -- 3329 ∣ (r*2^16 - x*c*169*2^16) (multiply prev by 2^16). + have h_dvd2 : (3329 : Int) + ∣ (r.val * (2 ^ 16 : Int) - x.val * c.val * 169 * (2 ^ 16 : Int)) := by + have h_eq : (r.val * (2 ^ 16 : Int) - x.val * c.val * 169 * (2 ^ 16 : Int)) + = (r.val - x.val * c.val * 169) * (2 ^ 16 : Int) := by ring + rw [h_eq]; exact Dvd.dvd.mul_right h_dvd _ + -- 169 * 65536 = 11075584 = 3329 * 3327 + 1, so 169 * 2^16 - 1 = 3329 * 3327. + have h_inv : (169 : Int) * (2 ^ 16 : Int) - 1 = 3329 * 3327 := by decide + have h_dvd3 : (3329 : Int) + ∣ (x.val * c.val * 169 * (2 ^ 16 : Int) - x.val * c.val) := by + have h_eq : (x.val * c.val * 169 * (2 ^ 16 : Int) - x.val * c.val) + = (x.val * c.val) * ((169 : Int) * (2 ^ 16 : Int) - 1) := by ring + rw [h_eq, h_inv] + exact ⟨(x.val * c.val) * 3327, by ring⟩ + have h_dvd4 : (3329 : Int) ∣ (r.val * (2 ^ 16 : Int) - x.val * c.val) := by + have h_sum : (r.val * (2 ^ 16 : Int) - x.val * c.val) + = (r.val * (2 ^ 16 : Int) - x.val * c.val * 169 * (2 ^ 16 : Int)) + + (x.val * c.val * 169 * (2 ^ 16 : Int) - x.val * c.val) := by ring + rw [h_sum]; exact dvd_add h_dvd2 h_dvd3 + -- (a - b) % q = 0 ⇒ a % q = b % q. + rw [Int.emod_eq_emod_iff_emod_sub_eq_zero] + exact Int.emod_eq_zero_of_dvd h_dvd4 + +@[spec] +theorem montgomery_multiply_by_constant_spec + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (c : Std.I16) (hc : c.val.natAbs ≤ 1664) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_by_constant vec c + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + (r.elements.val[i]!).val.natAbs ≤ 3328 + ∧ ((r.elements.val[i]!).val * (2 ^ 16 : Int)) % 3329 + = ((vec.elements.val[i]!).val * c.val) % 3329 ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_by_constant + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_by_constant_loop + have h_field : libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + = (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR; rfl + rw [h_field] + -- Bridge the impl's body shape (captured-c) to `unary_loop_body + -- (λ x => montgomery_multiply_fe_by_fer x c)`. + have h_body_eq : + (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_by_constant_loop.body + c p.1 p.2) + = (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + unary_loop_body + (fun x => libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_fe_by_fer x c) + p.1 p.2) := by + funext p + rcases p with ⟨iter1, vec1⟩ + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_by_constant_loop.body + unfold unary_loop_body + rfl + rw [h_body_eq] + apply Std.Do.Triple.of_entails_right _ + (elementwise_unary_spec + (fun x => libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_fe_by_fer x c) + (montgomery_per_elem_P c) + (fun x => montgomery_per_elem_spec c hc x) + vec) + rw [PostCond.entails_noThrow] + intro r hh j hj + obtain ⟨rj, _h_eq, h_acc, h_P⟩ := hh j hj + rw [h_acc] + exact h_P + +/-! ## L1.6 — `negate_spec` + + The Vector.Portable.Arithmetic.negate impl is a 16-iter loop + calling `core.num.I16.wrapping_neg` on each element. The + per-element op is the missing-stub `wrapping_sub 0 x`. -/ + +/-- Per-element predicate: `.bv = -x.bv`. -/ +private def negate_per_elem_P (x y : Std.I16) : Prop := + y.bv = -x.bv + +/-- Per-element Triple: `wrapping_neg x = .ok (wrapping_sub 0 x)` and + `(wrapping_sub 0 x).bv = 0 - x.bv = -x.bv`. -/ +private theorem negate_per_elem_spec (x : Std.I16) : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.num.I16.wrapping_neg x + ⦃ ⇓ r => ⌜ negate_per_elem_P x r ⌝ ⦄ := by + -- The missing-stub def: `wrapping_neg x = wrapping_sub_i16 0 x = .ok (I16.wrapping_sub 0 x)`. + have h_ok : + CoreModels.core.num.I16.wrapping_neg x = .ok (Aeneas.Std.I16.wrapping_sub (0#i16) x) := by + unfold CoreModels.core.num.I16.wrapping_neg + unfold rust_primitives.arithmetic.wrapping_sub_i16 + rfl + rw [h_ok] + simp only [Std.Do.Triple, WP.wp] + intro _ + show negate_per_elem_P x (Aeneas.Std.I16.wrapping_sub (0#i16) x) + unfold negate_per_elem_P + -- (wrapping_sub 0 x).bv = 0.bv - x.bv = -x.bv. + rw [Aeneas.Std.I16.wrapping_sub_bv_eq] + -- (0#i16).bv = 0 (BitVec definitional); reduce LHS to `0 - x.bv` then + -- apply BitVec.zero_sub. Use simp to normalize the `IScalarTy.I16.numBits` + -- vs `16` type-level reduction (per SKILL §5.1.1). + simp only [show (0#i16 : Std.I16).bv = (0 : BitVec 16) from rfl] + exact BitVec.zero_sub x.bv + +@[spec] +theorem negate_spec + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.negate vec + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + (r.elements.val[i]!).bv = -(vec.elements.val[i]!).bv ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.negate + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.negate_loop + have h_field : libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + = (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR; rfl + rw [h_field] + -- Bridge body to `unary_loop_body CoreModels.core.num.I16.wrapping_neg`. + have h_body_eq : + (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + libcrux_iot_ml_kem.vector.portable.arithmetic.negate_loop.body p.1 p.2) + = (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + unary_loop_body CoreModels.core.num.I16.wrapping_neg p.1 p.2) := by + funext p + rcases p with ⟨iter1, vec1⟩ + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.negate_loop.body + unfold unary_loop_body + rfl + rw [h_body_eq] + apply Std.Do.Triple.of_entails_right _ + (elementwise_unary_spec + CoreModels.core.num.I16.wrapping_neg + negate_per_elem_P + negate_per_elem_spec + vec) + rw [PostCond.entails_noThrow] + intro r hh j hj + obtain ⟨rj, _h_eq, h_acc, h_P⟩ := hh j hj + rw [h_acc] + exact h_P + +/-! ## L1.5 — `cond_subtract_3329_spec` + + The Vector.Portable.Arithmetic.cond_subtract_3329 impl is a + 16-iter loop where each iter conditionally subtracts 3329 from + the element (if `x ≥ 3329`) or passes through unchanged. The + canonical `unary_loop_body` macro doesn't fit because the + else-branch returns the input `vec` unchanged (no `Array.update`). + + Plan-B: prove from first principles via `loop_range_spec_usize`, + mirroring `elementwise_unary_spec`'s shape (2-conjunct invariant, + body-reduction-to-Result-equation, step lemma) with the + conditional branching inlined. -/ + +namespace CondSubtract3329 + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Result ControlFlow + +private theorem triple_of_ok_l1 + {α : Type} {x : Result α} {v : α} {P : α → Prop} + (hx : x = .ok v) (hp : P v) : + ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄ := by + subst hx; simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +private theorem of_pure_prop_holds_l1 {P : Prop} + (h : (pure P : Result Prop).holds) : P := by + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, + PredTrans.apply] at h + exact h trivial + +private theorem pure_prop_holds_l1 {P : Prop} (h : P) : (pure P : Result Prop).holds := by + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp]; intro _; exact h + +/-- Per-element invariant for `cond_subtract_3329`. -/ +private def cond_inv + (input : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector → + Result Prop := + fun k acc => pure ( + (∀ j : Nat, j < k.val → + (((input.elements.val[j]!).val ≥ 3329 ∧ + (acc.elements.val[j]!) = Std.I16.wrapping_sub (input.elements.val[j]!) 3329#i16) + ∨ ((input.elements.val[j]!).val < 3329 ∧ + acc.elements.val[j]! = input.elements.val[j]!))) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.elements.val[j]! = input.elements.val[j]!)) + +/-- Per-iteration post as a top-level def (avoids inline match issue + with `loop_range_spec_usize`'s named match constants — see SKILL + §13 / the analogous `unary_step_post` in PortableVector.lean). -/ +private def cond_step_post + (input : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (cond_inv input iter'.start acc').holds + | .done y => (cond_inv input 16#usize y).holds + +set_option maxHeartbeats 8000000 in +private theorem cond_step + (input : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (acc : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (cond_inv input k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.cond_subtract_3329_loop.body + { start := k, «end» := 16#usize } acc + ⦃ ⇓ r => ⌜ cond_step_post input k r ⌝ ⦄ := by + obtain ⟨h_acc_done, h_acc_undone⟩ := of_pure_prop_holds_l1 h_inv + have h_acc_len : acc.elements.length = 16 := PortableVector_elements_length acc + have h_16 : (16#usize : Std.Usize).val = 16 := rfl + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.cond_subtract_3329_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- Some i = k branch. + have hk_16 : k.val < 16 := by rw [h_16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := iter_next_some_eq k h_lt + have h_idx : + Aeneas.Std.Array.index_usize acc.elements k = .ok (acc.elements.val[k.val]!) := + array_index_usize_ok_eq acc.elements k (by rw [h_acc_len]; exact hk_16) + -- declassify x = .ok x. + have h_decl : + libcrux_secrets.traits.Declassify.Blanket.declassify (acc.elements.val[k.val]!) + = .ok (acc.elements.val[k.val]!) := rfl + -- Two sub-cases for the ≥3329 branch. + set xk : Std.I16 := acc.elements.val[k.val]! with hxk_def + -- The key element at index k. + have h_acc_xk : acc.elements.val[k.val]! = input.elements.val[k.val]! := + h_acc_undone k.val (Nat.le_refl _) hk_16 + by_cases h_ge : xk.val ≥ 3329 + · -- ≥ 3329: write back wrapping_sub. + have h_ge_lit : xk ≥ 3329#i16 := by + -- `≥` on I16 is `.val ≥ .val`. + change (3329#i16 : Std.I16).val ≤ xk.val + have : (3329#i16 : Std.I16).val = 3329 := by decide + rw [this]; exact h_ge + have h_decide : decide (xk ≥ 3329#i16) = true := decide_eq_true h_ge_lit + -- wrapping_sub xk 3329#i16 reduces to .ok (Std.I16.wrapping_sub xk 3329#i16). + have h_wsub : + CoreModels.core.num.I16.wrapping_sub xk 3329#i16 + = .ok (Std.I16.wrapping_sub xk 3329#i16) := by + unfold CoreModels.core.num.I16.wrapping_sub + unfold rust_primitives.arithmetic.wrapping_sub_i16 + rfl + have h_upd : + Aeneas.Std.Array.update acc.elements k (Std.I16.wrapping_sub xk 3329#i16) + = .ok (acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16)) := + array_update_ok_eq acc.elements k _ (by rw [h_acc_len]; exact hk_16) + have h_body : + (do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize) + match o with + | core.option.Option.None => + (Result.ok (ControlFlow.done acc) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + | core.option.Option.Some i => + let i1 ← Aeneas.Std.Array.index_usize acc.elements i + let i2 ← libcrux_secrets.traits.Declassify.Blanket.declassify i1 + if i2 >= 3329#i16 + then + let i3 ← core.num.I16.wrapping_sub i1 3329#i16 + let a ← Aeneas.Std.Array.update acc.elements i i3 + ok (cont (iter1, { elements := a })) + else ok (cont (iter1, acc))) + = .ok (cont + (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + { elements := acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16) })) := by + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [bind_tc_ok] + rw [h_idx] + simp only [bind_tc_ok] + rw [show libcrux_secrets.traits.Declassify.Blanket.declassify xk = .ok xk from rfl] + simp only [bind_tc_ok] + rw [if_pos h_ge_lit] + rw [h_wsub] + simp only [bind_tc_ok] + rw [h_upd] + rfl + apply triple_of_ok_l1 h_body + show cond_step_post input k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + { elements := acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16) })) + unfold cond_step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + apply pure_prop_holds_l1 + refine ⟨?_, ?_⟩ + · intro j hj + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · -- j < k: invariant carries over (set at index k, j ≠ k). + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne : + (acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16))[j]! + = (acc.elements)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.elements k j _ h_ne + have h_set_eq_val : + (acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16)).val[j]! + = acc.elements.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + have h_old := h_acc_done j hj_lt_k + rcases h_old with ⟨h_in_ge, h_acc_eq⟩ | ⟨h_in_lt, h_acc_eq⟩ + · left; refine ⟨h_in_ge, ?_⟩; rw [h_set_eq_val]; exact h_acc_eq + · right; refine ⟨h_in_lt, ?_⟩; rw [h_set_eq_val]; exact h_acc_eq + · -- j = k.val: use the just-set element. + subst hj_eq_k + have h_lt'' : k.val < acc.elements.length := by rw [h_acc_len]; exact hk_16 + have h_set_eq : + (acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16))[k.val]! + = Std.I16.wrapping_sub xk 3329#i16 := + Aeneas.Std.Array.getElem!_Nat_set_eq acc.elements k k.val _ ⟨rfl, h_lt''⟩ + have h_set_eq_val : + (acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16)).val[k.val]! + = Std.I16.wrapping_sub xk 3329#i16 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_eq + left + refine ⟨?_, ?_⟩ + · rw [← h_acc_xk]; exact h_ge + · rw [h_set_eq_val, ← h_acc_xk] + · intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne : + (acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16))[j]! + = (acc.elements)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.elements k j _ h_ne + have h_set_eq_val : + (acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16)).val[j]! + = acc.elements.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + rw [h_set_eq_val] + exact h_acc_undone j h_ge' hj_lt + · -- < 3329: pass through unchanged. + have h_not_ge : ¬ (3329#i16 : Std.I16).val ≤ xk.val := by + have h_eq : (3329#i16 : Std.I16).val = 3329 := by decide + rw [h_eq]; exact h_ge + have h_not_ge' : ¬ (xk ≥ 3329#i16) := h_not_ge + have h_body : + (do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize) + match o with + | core.option.Option.None => + (Result.ok (ControlFlow.done acc) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + | core.option.Option.Some i => + let i1 ← Aeneas.Std.Array.index_usize acc.elements i + let i2 ← libcrux_secrets.traits.Declassify.Blanket.declassify i1 + if i2 >= 3329#i16 + then + let i3 ← core.num.I16.wrapping_sub i1 3329#i16 + let a ← Aeneas.Std.Array.update acc.elements i i3 + ok (cont (iter1, { elements := a })) + else ok (cont (iter1, acc))) + = .ok (cont + (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc)) := by + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [bind_tc_ok] + rw [h_idx] + simp only [bind_tc_ok] + rw [show libcrux_secrets.traits.Declassify.Blanket.declassify xk = .ok xk from rfl] + simp only [bind_tc_ok] + rw [if_neg h_not_ge'] + apply triple_of_ok_l1 h_body + show cond_step_post input k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc)) + unfold cond_step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + apply pure_prop_holds_l1 + refine ⟨?_, ?_⟩ + · intro j hj + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · exact h_acc_done j hj_lt_k + · subst hj_eq_k + right + refine ⟨?_, ?_⟩ + · rw [← h_acc_xk]; show xk.val < 3329 + push Not at h_ge; exact h_ge + · exact h_acc_xk + · intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ge' : k.val ≤ j := by omega + exact h_acc_undone j h_ge' hj_lt + · -- None branch. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h_16] at hk_ge; omega + have h_iter_none := iter_next_none_eq k hk_ge + have h_body : + (do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize) + match o with + | core.option.Option.None => + (Result.ok (ControlFlow.done acc) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + | core.option.Option.Some i => + let i1 ← Aeneas.Std.Array.index_usize acc.elements i + let i2 ← libcrux_secrets.traits.Declassify.Blanket.declassify i1 + if i2 >= 3329#i16 + then + let i3 ← core.num.I16.wrapping_sub i1 3329#i16 + let a ← Aeneas.Std.Array.update acc.elements i i3 + ok (cont (iter1, { elements := a })) + else ok (cont (iter1, acc))) + = .ok (done acc) := by + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_l1 h_body + show cond_step_post input k (.done acc) + unfold cond_step_post + apply pure_prop_holds_l1 + refine ⟨?_, ?_⟩ + · intro j hj + apply h_acc_done j + rw [hk_eq]; rw [h_16] at hj; exact hj + · intro j hj_ge hj_lt + apply h_acc_undone j _ hj_lt + rw [hk_eq]; rw [h_16] at hj_ge; exact hj_ge + +end CondSubtract3329 + +@[spec] +theorem cond_subtract_3329_spec + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_bounds : ∀ i : Nat, i < 16 → 0 ≤ (vec.elements.val[i]!).val + ∧ (vec.elements.val[i]!).val < 2 * 3329) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.cond_subtract_3329 vec + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + 0 ≤ (r.elements.val[i]!).val + ∧ (r.elements.val[i]!).val < 3329 + ∧ (r.elements.val[i]!).val % 3329 = (vec.elements.val[i]!).val % 3329 ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.cond_subtract_3329 + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.cond_subtract_3329_loop + have h_field : libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + = (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR; rfl + rw [h_field] + apply Std.Do.Triple.of_entails_right _ + (loop_range_spec_usize + (fun (iter1, vec1) => + libcrux_iot_ml_kem.vector.portable.arithmetic.cond_subtract_3329_loop.body + iter1 vec1) + vec 0#usize 16#usize + (CondSubtract3329.cond_inv vec) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (CondSubtract3329.pure_prop_holds_l1 ⟨ + fun j hj => by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + rw [h0] at hj; exact absurd hj (Nat.not_lt_zero j), + fun _ _ _ => rfl⟩) + ?_) + · rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_done, _h_undone⟩ := CondSubtract3329.of_pure_prop_holds_l1 h + intro j hj + -- Per-element: derive 0 ≤ r[j] < 3329 and r[j] % 3329 = vec[j] % 3329. + obtain ⟨h_vec_ge, h_vec_lt⟩ := h_bounds j hj + have h_done_j := h_done j (by rw [show (16#usize : Std.Usize).val = 16 from rfl]; exact hj) + -- Pinned x := vec[j] (since the inv carries `input := vec`). + set xj : Std.I16 := vec.elements.val[j]! with hxj_def + rcases h_done_j with ⟨h_ge, h_eq⟩ | ⟨h_lt, h_eq⟩ + · -- ≥ 3329 branch: r[j] = wrapping_sub xj 3329. + -- (wrapping_sub xj 3329).val = xj.val - 3329, since 3329 ≤ xj.val < 6658. + have h_wsub_val : + (Std.I16.wrapping_sub xj (3329#i16 : Std.I16)).val = xj.val - 3329 := by + rw [Std.I16.wrapping_sub_val_eq] + -- bmod xj.val - 3329 by 2^16 = xj.val - 3329 when in range + have h_diff_lb : -(2^15 : Int) ≤ xj.val - 3329 := by + have h_xj_lb : (0 : Int) ≤ xj.val := h_vec_ge + have : -(2^15 : Int) ≤ -3329 := by decide + grind + have h_diff_ub : xj.val - 3329 < (2^15 : Int) := by + have h_xj_ub : xj.val < 2 * 3329 := h_vec_lt + have h_step : (2 * 3329 - 3329 : Int) < (2^15 : Int) := by decide + grind + have h_3329_val : (3329#i16 : Std.I16).val = 3329 := by decide + rw [h_3329_val] + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int)^(16-1)) ≤ -(2^15 : Int) := by decide + exact le_trans h_const h_diff_lb + · have h_const : (2^15 : Int) ≤ (2 : Int)^(16-1) := by decide + exact lt_of_lt_of_le h_diff_ub h_const + refine ⟨?_, ?_, ?_⟩ + · -- 0 ≤ xj.val - 3329 (since xj.val ≥ 3329). + rw [h_eq, h_wsub_val] + have : (0 : Int) ≤ xj.val - 3329 := by grind + exact this + · -- xj.val - 3329 < 3329 (since xj.val < 2 * 3329). + rw [h_eq, h_wsub_val] + have : xj.val - 3329 < (3329 : Int) := by grind + exact this + · rw [h_eq, h_wsub_val] + -- (xj.val - 3329) % 3329 = xj.val % 3329 (subtracting a multiple of 3329). + have : (xj.val - 3329) % 3329 = xj.val % 3329 := by + have h := Int.sub_emod xj.val 3329 3329 + rw [h] + have h_self : (3329 : Int) % 3329 = 0 := by decide + rw [h_self] + simp [Int.emod_emod_of_dvd] + exact this + · -- < 3329 branch: r[j] = xj. + refine ⟨?_, ?_, ?_⟩ + · rw [h_eq]; exact h_vec_ge + · rw [h_eq]; exact h_lt + · rw [h_eq] + · intro acc k h_ge h_le hinv + have h_step := CondSubtract3329.cond_step vec acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : CondSubtract3329.cond_step_post vec k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [CondSubtract3329.cond_step_post] using hP + · have hP : CondSubtract3329.cond_step_post vec k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [CondSubtract3329.cond_step_post] using hP + +/-! ## L1.7 — `multiply_by_constant_spec` + + The Vector.Portable.Arithmetic.multiply_by_constant impl is a + 16-iteration loop that calls `core.num.I16.wrapping_mul` + on each element, with the constant `c` captured by the body + lambda (same structure as L1.4's `montgomery_multiply_by_constant`). + Under the per-element no-overflow bound + `|x.val * c.val| ≤ 2^15 - 1`, the wrap is a no-op and + `(wrapping_mul x c).val = x.val * c.val`. -/ + +/-- Per-element predicate (guarded form): given the no-overflow bound + on `x * c`, the output value equals the product and is in range. -/ +private def multiply_by_constant_per_elem_P (c x y : Std.I16) : Prop := + (x.val * c.val : Int).natAbs ≤ 2 ^ 15 - 1 → + y.val = x.val * c.val ∧ y.val.natAbs ≤ 2 ^ 15 - 1 + +/-- Per-element Triple: `core.num.I16.wrapping_mul x c` reduces + to `.ok (Std.I16.wrapping_mul x c)`, whose `.val` is the bmod of + `x.val * c.val` mod `2^16`. Under the no-overflow bound, + `Int.bmod` is the identity. -/ +private theorem multiply_by_constant_per_elem_spec + (c : Std.I16) (x : Std.I16) : + ⦃ ⌜ True ⌝ ⦄ + CoreModels.core.num.I16.wrapping_mul x c + ⦃ ⇓ r => ⌜ multiply_by_constant_per_elem_P c x r ⌝ ⦄ := by + -- Reduce `wrapping_mul` to `.ok (Std.I16.wrapping_mul x c)`. + have h_ok : + CoreModels.core.num.I16.wrapping_mul x c + = .ok (Aeneas.Std.I16.wrapping_mul x c) := by + unfold CoreModels.core.num.I16.wrapping_mul + unfold rust_primitives.arithmetic.wrapping_mul_i16 + rfl + rw [h_ok] + simp only [Std.Do.Triple, WP.wp] + intro _ + show multiply_by_constant_per_elem_P c x (Aeneas.Std.I16.wrapping_mul x c) + unfold multiply_by_constant_per_elem_P + intro hb + -- `(wrapping_mul x c).val = Int.bmod (x.val * c.val) (2^16)`. + have h_val := Aeneas.Std.I16.wrapping_mul_val_eq x c + -- Under `|x*c| ≤ 2^15 - 1`, Int.bmod is the identity. + -- omega handles `Int.natAbs ≤ N → -N ≤ a ∧ a ≤ N` directly. + have h_lb : -(2 ^ 15 : Int) ≤ x.val * c.val := by + have h_abs : (x.val * c.val : Int).natAbs ≤ 2 ^ 15 - 1 := hb + omega + have h_ub : x.val * c.val < (2 ^ 15 : Int) := by + have h_abs : (x.val * c.val : Int).natAbs ≤ 2 ^ 15 - 1 := hb + omega + have h_bmod : Int.bmod (x.val * c.val) (2 ^ 16) = x.val * c.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + -- Combine. + refine ⟨?_, ?_⟩ + · -- (wrapping_mul x c).val = x.val * c.val + rw [h_val, h_bmod] + · -- (wrapping_mul x c).val.natAbs ≤ 2^15 - 1 + rw [h_val, h_bmod]; exact hb + +@[spec] +theorem multiply_by_constant_spec + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (c : Std.I16) + (hpre : ∀ i : Nat, i < 16 → + ((vec.elements.val[i]!).val * c.val : Int).natAbs ≤ 2 ^ 15 - 1) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.multiply_by_constant vec c + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + (r.elements.val[i]!).val = (vec.elements.val[i]!).val * c.val + ∧ (r.elements.val[i]!).val.natAbs ≤ 2 ^ 15 - 1 ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.multiply_by_constant + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.multiply_by_constant_loop + have h_field : libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + = (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR; rfl + rw [h_field] + -- Bridge `multiply_by_constant_loop.body c` to `unary_loop_body + -- (fun x => core.num.I16.wrapping_mul x c)`. + have h_body_eq : + (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + libcrux_iot_ml_kem.vector.portable.arithmetic.multiply_by_constant_loop.body + c p.1 p.2) + = (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + unary_loop_body + (fun x => core.num.I16.wrapping_mul x c) + p.1 p.2) := by + funext p + rcases p with ⟨iter1, vec1⟩ + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.multiply_by_constant_loop.body + unfold unary_loop_body + rfl + rw [h_body_eq] + apply Std.Do.Triple.of_entails_right _ + (elementwise_unary_spec + (fun x => core.num.I16.wrapping_mul x c) + (multiply_by_constant_per_elem_P c) + (fun x => multiply_by_constant_per_elem_spec c x) + vec) + rw [PostCond.entails_noThrow] + intro r hh j hj + obtain ⟨rj, _h_eq, h_acc, h_P⟩ := hh j hj + rw [h_acc] + exact h_P (hpre j hj) + +/-! ## L1.8 — `bitwise_and_with_constant_spec` + + The Vector.Portable.Arithmetic.bitwise_and_with_constant impl is a + 16-iter loop where each iter computes `i1 &&& c` via the + `lift`-then-bv operation. The per-element op is pure + (no `Result`-level branching beyond `.ok`), so the Triple closes by + direct reduction. -/ + +/-- Per-element predicate: `.bv = x.bv &&& c.bv`. -/ +private def bitwise_and_per_elem_P (c x y : Std.I16) : Prop := + y.bv = x.bv &&& c.bv + +/-- Per-element Triple: `lift (x &&& c)` reduces to `.ok (x &&& c)`, + whose `.bv` is `x.bv &&& c.bv` by definition of `IScalar.and`. -/ +private theorem bitwise_and_per_elem_spec (c : Std.I16) (x : Std.I16) : + ⦃ ⌜ True ⌝ ⦄ + lift (x &&& c) + ⦃ ⇓ r => ⌜ bitwise_and_per_elem_P c x r ⌝ ⦄ := by + -- `lift v = .ok v`, definitionally. + have h_ok : (lift (x &&& c) : Result Std.I16) = .ok (x &&& c) := rfl + rw [h_ok] + simp only [Std.Do.Triple, WP.wp] + intro _ + show bitwise_and_per_elem_P c x (x &&& c) + unfold bitwise_and_per_elem_P + -- `(x &&& c).bv = x.bv &&& c.bv` by definition of `IScalar.and`. + rfl + +@[spec] +theorem bitwise_and_with_constant_spec + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (c : Std.I16) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.bitwise_and_with_constant vec c + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + (r.elements.val[i]!).bv = (vec.elements.val[i]!).bv &&& c.bv ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.bitwise_and_with_constant + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.bitwise_and_with_constant_loop + have h_field : libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + = (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR; rfl + rw [h_field] + have h_body_eq : + (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + libcrux_iot_ml_kem.vector.portable.arithmetic.bitwise_and_with_constant_loop.body + c p.1 p.2) + = (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + unary_loop_body + (fun x => lift (x &&& c)) + p.1 p.2) := by + funext p + rcases p with ⟨iter1, vec1⟩ + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.bitwise_and_with_constant_loop.body + unfold unary_loop_body + rfl + rw [h_body_eq] + apply Std.Do.Triple.of_entails_right _ + (elementwise_unary_spec + (fun x => lift (x &&& c)) + (bitwise_and_per_elem_P c) + (fun x => bitwise_and_per_elem_spec c x) + vec) + rw [PostCond.entails_noThrow] + intro r hh j hj + obtain ⟨rj, _h_eq, h_acc, h_P⟩ := hh j hj + rw [h_acc] + exact h_P + +/-! ## L1.9 — `shift_right_spec` + + The Vector.Portable.Arithmetic.shift_right impl is a 16-iter loop + where each body lifts the captured `SHIFT_BY : I32` to `U32` via + `IScalar.hcast`, then `>>>`-shifts the i-th lane. The per-element + op is `i2 >>> (IScalar.hcast .U32 SHIFT_BY)` = + `IScalar.shiftRight_UScalar i2 _`. The body's structure differs + from `unary_loop_body` in that the `lift (IScalar.hcast …)` + statement is the first do-binding (before `index_usize`); but since + `lift v = .ok v` is a pure no-op with no side effects, swapping + the order is `rfl` after a `simp only [lift]`. -/ + +/-- Per-element predicate: `.bv = x.bv.sshiftRight SHIFT_BY.val.toNat`. -/ +private def shift_right_per_elem_P (SHIFT_BY : Std.I32) (x y : Std.I16) : Prop := + y.bv = x.bv.sshiftRight SHIFT_BY.val.toNat + +/-- Per-element Triple: `x >>> (IScalar.hcast .U32 SHIFT_BY)` reduces + via `IScalar.shiftRight_UScalar_bv_eq` to + `.ok ⟨x.bv.sshiftRight (hcast SHIFT_BY).val⟩`; the hcast preserves + the value when `0 ≤ SHIFT_BY.val < 2^32`, so its `.val = SHIFT_BY.val.toNat`. -/ +private theorem shift_right_per_elem_spec + (SHIFT_BY : Std.I32) (hs : 0 ≤ SHIFT_BY.val ∧ SHIFT_BY.val < 16) (x : Std.I16) : + ⦃ ⌜ True ⌝ ⦄ + (x >>> (IScalar.hcast .U32 SHIFT_BY) : Result Std.I16) + ⦃ ⇓ r => ⌜ shift_right_per_elem_P SHIFT_BY x r ⌝ ⦄ := by + -- `x >>> u` unfolds to `IScalar.shiftRight_UScalar x u`. + show ⦃ ⌜ True ⌝ ⦄ + Aeneas.Std.IScalar.shiftRight_UScalar x (IScalar.hcast .U32 SHIFT_BY) + ⦃ ⇓ r => ⌜ shift_right_per_elem_P SHIFT_BY x r ⌝ ⦄ + -- `(hcast .U32 SHIFT_BY).val ≤ 32` needs to be derived from `SHIFT_BY.val < 16`. + obtain ⟨hs_nn, hs_lt⟩ := hs + -- `(hcast .U32 SHIFT_BY).val = SHIFT_BY.val.toNat`. Use the in-bounds + -- spec for `hcast`: under `0 ≤ SHIFT_BY.val ≤ U32.max`, hcast preserves value. + have h_max : SHIFT_BY.val ≤ Aeneas.Std.UScalar.max .U32 := by + -- scalar_tac knows UScalar.max .U32 = 4294967295 (via max_eq simp). + scalar_tac + have h_hcast_spec := Aeneas.Std.IScalar.hcast_inBounds_spec .U32 SHIFT_BY ⟨hs_nn, h_max⟩ + -- The `lift` post-only spec is `spec (lift v) p`, and `lift v = .ok v`, + -- so `spec_ok` collapses it to `p v` directly. + have h_eq_int : ((IScalar.hcast .U32 SHIFT_BY : Std.U32).val : Int) = SHIFT_BY.val := by + -- h_hcast_spec : spec (lift (hcast .U32 SHIFT_BY)) (fun y => y.val = SHIFT_BY.val) + -- Reduce lift → .ok, then spec_ok. + have h_ok_lift : (lift (Aeneas.Std.IScalar.hcast .U32 SHIFT_BY) + : Result Std.U32) + = .ok (Aeneas.Std.IScalar.hcast .U32 SHIFT_BY) := rfl + rw [h_ok_lift] at h_hcast_spec + rw [Aeneas.Std.WP.spec_ok] at h_hcast_spec + exact h_hcast_spec + -- Bridge: `((n : Nat) : Int) = SHIFT_BY.val` and `0 ≤ SHIFT_BY.val` → `n = SHIFT_BY.val.toNat`. + have h_hcast_val : (IScalar.hcast .U32 SHIFT_BY : Std.U32).val = SHIFT_BY.val.toNat := by + have h_inj : ((IScalar.hcast .U32 SHIFT_BY : Std.U32).val : Int).toNat + = SHIFT_BY.val.toNat := by rw [h_eq_int] + simpa using h_inj + -- Now invoke `IScalar.shiftRight_UScalar_bv_eq`. + have h_lt_numBits : (IScalar.hcast .U32 SHIFT_BY : Std.U32).val + < Aeneas.Std.IScalarTy.I16.numBits := by + rw [h_hcast_val] + have h_red : (Aeneas.Std.IScalarTy.I16.numBits : Nat) = 16 := by decide + rw [h_red] + -- `SHIFT_BY.val.toNat < 16` since `SHIFT_BY.val < 16` and `0 ≤ SHIFT_BY.val`. + have h_le : (SHIFT_BY.val.toNat : Int) = SHIFT_BY.val := Int.toNat_of_nonneg hs_nn + omega + have h_sr := IScalar.shiftRight_UScalar_bv_eq x (IScalar.hcast .U32 SHIFT_BY) h_lt_numBits + rw [h_sr] + simp only [Std.Do.Triple, WP.wp] + intro _ + show shift_right_per_elem_P SHIFT_BY x ⟨x.bv.sshiftRight (IScalar.hcast .U32 SHIFT_BY : Std.U32).val⟩ + unfold shift_right_per_elem_P + show (⟨x.bv.sshiftRight (IScalar.hcast .U32 SHIFT_BY : Std.U32).val⟩ : Std.I16).bv + = x.bv.sshiftRight SHIFT_BY.val.toNat + show x.bv.sshiftRight (IScalar.hcast .U32 SHIFT_BY : Std.U32).val + = x.bv.sshiftRight SHIFT_BY.val.toNat + rw [h_hcast_val] + +@[spec] +theorem shift_right_spec + (SHIFT_BY : Std.I32) (hs : 0 ≤ SHIFT_BY.val ∧ SHIFT_BY.val < 16) + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.shift_right SHIFT_BY vec + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + (r.elements.val[i]!).bv = + (vec.elements.val[i]!).bv.sshiftRight SHIFT_BY.val.toNat ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.shift_right + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.shift_right_loop + have h_field : libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + = (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR; rfl + rw [h_field] + -- Bridge `shift_right_loop.body SHIFT_BY` to + -- `unary_loop_body (fun x => x >>> (IScalar.hcast .U32 SHIFT_BY))`. + -- The two body shapes differ in that the impl computes the hcast + -- *before* the index_usize. Since `lift v = .ok v` is pure, + -- the bind chain reorders by `bind_tc_ok` rewriting. + have h_body_eq : + (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + libcrux_iot_ml_kem.vector.portable.arithmetic.shift_right_loop.body + SHIFT_BY p.1 p.2) + = (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + unary_loop_body + (fun x => x >>> (IScalar.hcast .U32 SHIFT_BY)) + p.1 p.2) := by + funext p + rcases p with ⟨iter1, vec1⟩ + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.shift_right_loop.body + unfold unary_loop_body + -- After unfolding, the impl body has `let i1 ← lift (hcast …); let i2 ← index; let i3 ← i2 >>> i1`, + -- and `unary_loop_body` has `let i1 ← index; let vi ← (i1 >>> hcast …); …`. + -- `lift v = .ok v` is definitional; reduce the leading `lift` bind. + rfl + rw [h_body_eq] + apply Std.Do.Triple.of_entails_right _ + (elementwise_unary_spec + (fun x => x >>> (IScalar.hcast .U32 SHIFT_BY)) + (shift_right_per_elem_P SHIFT_BY) + (fun x => shift_right_per_elem_spec SHIFT_BY hs x) + vec) + rw [PostCond.entails_noThrow] + intro r hh j hj + obtain ⟨rj, _h_eq, h_acc, h_P⟩ := hh j hj + rw [h_acc] + exact h_P + +/-! ## L1.10 — `reducing_from_i32_array_spec` + + The Vector.Portable.Arithmetic.reducing_from_i32_array impl is a + 16-iteration loop that reads `array[i] : I32`, runs L0.3 + `montgomery_reduce_element` (I32 → I16), and writes the result to + `out.elements[i] : I16`. Differs from the unary family in that the + input/output types differ — uses the `io_loop_*` macro family from + `Util/PortableVector.lean`. -/ + +/-- Per-element predicate (guarded form): under the L0.3 precondition + `|x| ≤ 3328 * 2^16`, the output satisfies the L1.10-shaped bound + and Montgomery congruence. -/ +private def reducing_per_elem_P (x : Std.I32) (y : Std.I16) : Prop := + x.val.natAbs ≤ 3328 * 2 ^ 16 → + y.val.natAbs ≤ 3328 + 1665 + ∧ libcrux_iot_ml_kem.Spec.ModularArith.modq_eq + (y.val * (2 ^ 16 : Int)) x.val 3329 + +/-- Per-element Triple: `montgomery_reduce_element` is total (returns + `.ok` unconditionally via `mont_reduce_element_eq_ok`); under the + L0.3 precondition `|x| ≤ 3328 * 2^16`, the post-weakening to L1.10's + Montgomery congruence uses the same identity as L1.4 + (`169 * 2^16 ≡ 1 (mod 3329)`). -/ +private theorem reducing_per_elem_spec (x : Std.I32) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_reduce_element x + ⦃ ⇓ r => ⌜ reducing_per_elem_P x r ⌝ ⦄ := by + -- Reduce to the .ok form via the totality theorem. + have h_ok := mont_reduce_element_eq_ok x + rw [show libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_reduce_element x + = .ok (mont_reduce_impl_value x) from h_ok] + simp only [Std.Do.Triple, WP.wp] + intro _ + show reducing_per_elem_P x (mont_reduce_impl_value x) + unfold reducing_per_elem_P + intro hb + -- Now invoke L0.3 under hb to extract the post. + have hT := montgomery_reduce_element_spec x hb + rw [h_ok] at hT + simp only [Std.Do.Triple, WP.wp] at hT + have h_inner := hT trivial + obtain ⟨h_weak, _h_tight, h_modq⟩ := h_inner + refine ⟨h_weak, ?_⟩ + -- Convert L0.3's `modq_eq r.val (x.val * 169) 3329` to + -- `modq_eq (r.val * 2^16) x.val 3329` (same algebra as L1.4). + unfold libcrux_iot_ml_kem.Spec.ModularArith.modq_eq at h_modq ⊢ + have h_dvd : (3329 : Int) ∣ ((mont_reduce_impl_value x).val - x.val * 169) := + Int.dvd_of_emod_eq_zero h_modq + have h_dvd2 : (3329 : Int) + ∣ ((mont_reduce_impl_value x).val * (2 ^ 16 : Int) + - x.val * 169 * (2 ^ 16 : Int)) := by + have h_eq2 : ((mont_reduce_impl_value x).val * (2 ^ 16 : Int) + - x.val * 169 * (2 ^ 16 : Int)) + = ((mont_reduce_impl_value x).val - x.val * 169) * (2 ^ 16 : Int) := by ring + rw [h_eq2]; exact Dvd.dvd.mul_right h_dvd _ + have h_inv : (169 : Int) * (2 ^ 16 : Int) - 1 = 3329 * 3327 := by decide + have h_dvd3 : (3329 : Int) + ∣ (x.val * 169 * (2 ^ 16 : Int) - x.val) := by + have h_eq3 : (x.val * 169 * (2 ^ 16 : Int) - x.val) + = x.val * ((169 : Int) * (2 ^ 16 : Int) - 1) := by ring + rw [h_eq3, h_inv] + exact ⟨x.val * 3327, by ring⟩ + have h_dvd4 : (3329 : Int) + ∣ ((mont_reduce_impl_value x).val * (2 ^ 16 : Int) - x.val) := by + have h_sum : ((mont_reduce_impl_value x).val * (2 ^ 16 : Int) - x.val) + = ((mont_reduce_impl_value x).val * (2 ^ 16 : Int) + - x.val * 169 * (2 ^ 16 : Int)) + + (x.val * 169 * (2 ^ 16 : Int) - x.val) := by ring + rw [h_sum]; exact dvd_add h_dvd2 h_dvd3 + exact Int.emod_eq_zero_of_dvd h_dvd4 + +@[spec] +theorem reducing_from_i32_array_spec + (array : Aeneas.Std.Slice Std.I32) + (out : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_len : array.val.length = 16) + (hpre : ∀ i : Nat, i < 16 → (array.val[i]!).val.natAbs ≤ 3328 * 2 ^ 16) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.reducing_from_i32_array array out + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + (r.elements.val[i]!).val.natAbs ≤ 3328 + 1665 + ∧ libcrux_iot_ml_kem.Spec.ModularArith.modq_eq + ((r.elements.val[i]!).val * (2 ^ 16 : Int)) + (array.val[i]!).val 3329 ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.reducing_from_i32_array + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.reducing_from_i32_array_loop + -- Bridge `reducing_from_i32_array_loop.body array` to `io_loop_body`. + have h_body_eq : + (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + libcrux_iot_ml_kem.vector.portable.arithmetic.reducing_from_i32_array_loop.body + array p.1 p.2) + = (fun (p : (CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) => + io_loop_body + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_reduce_element + array p.1 p.2) := by + funext p + rcases p with ⟨iter1, vec1⟩ + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.reducing_from_i32_array_loop.body + unfold io_loop_body + rfl + rw [h_body_eq] + have h_len_ge : 16 ≤ array.val.length := by rw [h_len] + apply Std.Do.Triple.of_entails_right _ + (elementwise_io_spec + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_reduce_element + reducing_per_elem_P + reducing_per_elem_spec + array out h_len_ge) + rw [PostCond.entails_noThrow] + intro r hh j hj + obtain ⟨rj, _h_eq, h_acc, h_P⟩ := hh j hj + rw [h_acc] + exact h_P (hpre j hj) + +end libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element +/-! ### Extracted from FCTargets.lean (§vector_arith_hi). -/ + +namespace libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element +open libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec + +/-! ## §L1 — chunk-level vector ops (10 theorems). -/ + +/-! ### L1.1 — `add` on 16-lane PortableVector chunks. + + Proof sketch: + 1. Bridge `add_pure_val_eq`: `(FieldElement.add_pure a b).val.val + = (a.val.val + b.val.val) % 3329`. Mirrors `Canonical_add_pure`'s + trace through the hacspec U32-widen + add + mod q + U16-narrow body. + 2. Bridge `lift_fe_add_pure_eq`: under the no-overflow bound on + `a.val + b.val` (Int), the `lift_fe` of any i16 carrying that + sum equals `FieldElement.add_pure (lift_fe a) (lift_fe b)`. Both + sides reduce to `feOfZMod ((a.val + b.val : Int) : ZMod 3329)` + via canonical-FE round-trip + `ZMod.natCast_mod`. + 3. Main: extract `add_spec` via `triple_exists_ok_fc` to get + per-element `r[i].val = lhs[i].val + rhs[i].val ∧ bound`. + Apply `triple_of_ok_fc` to close monadic shell. Reduce array + equality to `List.ext_get!` + per-index, then apply Bridge 2. -/ + +/-- Local helper (mirrors `Spec.Pure.uscalar_rem_ok_U32` which is + file-private). Establishes that U32 modular remainder by a non-zero + divisor is always `.ok`, and exposes the underlying value. -/ +theorem uscalar_rem_ok_U32_local (z m : Std.U32) (hm : m.val ≠ 0) : + ∃ w : Std.U32, (z % m : Result Std.U32) = .ok w ∧ w.val = z.val % m.val := by + have heq : (z % m : Result Std.U32) = Std.UScalar.rem z m := rfl + unfold Std.UScalar.rem at heq + simp [hm] at heq + refine ⟨_, heq, ?_⟩ + show (BitVec.umod z.bv m.bv).toNat = z.val % m.val + unfold BitVec.umod + simp only [BitVec.toNat_ofNatLT] + rfl + +/-- Bridge lemma: the `.val.val` of `FieldElement.add_pure` is the + impl's U16 modular-reduced sum. Proof structure mirrors + `Canonical_add_pure` in `Spec.Pure.lean` — chain through the U32 + do-block via `add_equiv` + `uscalar_rem_ok_U32_local` + the U16 + narrowing cast. Pure-projection side lemma, NOT panic-freedom. -/ +theorem add_pure_val_eq + (a b : hacspec_ml_kem.parameters.FieldElement) : + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure a b).val.val + = (a.val.val + b.val.val) % 3329 := by + have hadd : + hacspec_ml_kem.parameters.FieldElement.add a b + = .ok (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure a b) := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_eq_ok a b + unfold hacspec_ml_kem.parameters.FieldElement.add at hadd + simp only [Aeneas.Std.lift, Aeneas.Std.bind_tc_ok] at hadd + have hA := a.val.hBounds; have hB := b.val.hBounds + simp [Aeneas.Std.UScalarTy.numBits] at hA hB + set x : Std.U32 := Std.UScalar.cast .U32 a.val + set y : Std.U32 := Std.UScalar.cast .U32 b.val + have hxval : x.val = a.val.val := Std.U16.cast_U32_val_eq a.val + have hyval : y.val = b.val.val := Std.U16.cast_U32_val_eq b.val + have hae := Std.UScalar.add_equiv x y + cases hxy : (x + y) with + | ok z => + rw [hxy] at hae hadd; simp at hae + obtain ⟨_, hzval, _⟩ := hae + simp only [Aeneas.Std.bind_tc_ok] at hadd + have hmod_val : + (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS).val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; simp + have hmod_ne : + (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS).val ≠ 0 := by + rw [hmod_val]; decide + set m : Std.U32 := Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS + obtain ⟨w, hw_eq, hwval⟩ := uscalar_rem_ok_U32_local z m hmod_ne + rw [hw_eq] at hadd; simp only [Aeneas.Std.bind_tc_ok] at hadd + unfold hacspec_ml_kem.parameters.FieldElement.new at hadd + simp at hadd + have hwbnd : w.val < 3329 := by + rw [hwval, hmod_val]; exact Nat.mod_lt _ (by decide) + have hwcast : (Std.UScalar.cast .U16 w).val = w.val := by + apply Std.UScalar.cast_val_mod_pow_of_inBounds_eq + simp [Aeneas.Std.UScalarTy.numBits]; omega + rw [← hadd] + show (Std.UScalar.cast .U16 w).val = (a.val.val + b.val.val) % 3329 + rw [hwcast, hwval, hmod_val, hzval, hxval, hyval] + | fail e => + rw [hxy] at hae; simp [Std.UScalar.inBounds] at hae + rw [hxval, hyval] at hae; omega + | div => rw [hxy] at hae; exact hae.elim + +/-- Canonical-FE round-trip: a canonical `FieldElement` (i.e. + `fe.val.val < 3329`) is recovered exactly by `feOfZMod ∘ zmodOfFE`. + The forward direction `zmodOfFE_feOfZMod` lives in `Spec.lean`; + this lemma is the canonicity-constrained converse, used to bridge + `FieldElement.add_pure (lift_fe a) (lift_fe b)` (canonical by + `Canonical_add_pure`) to its `feOfZMod`-image normal form. -/ +theorem feOfZMod_zmodOfFE_of_canonical + (fe : hacspec_ml_kem.parameters.FieldElement) + (h : fe.val.val < 3329) : + feOfZMod (zmodOfFE fe) = fe := by + unfold feOfZMod zmodOfFE + -- Goal: ⟨⟨BitVec.ofNat 16 ((fe.val.val : ZMod 3329)).val⟩⟩ = fe. + -- ZMod.val_natCast_of_lt: ((fe.val.val : ZMod 3329)).val = fe.val.val + -- given fe.val.val < 3329. + have hzval : ((fe.val.val : ZMod 3329)).val = fe.val.val := + ZMod.val_natCast_of_lt h + rw [hzval] + -- Goal: ⟨⟨BitVec.ofNat 16 fe.val.val⟩⟩ = fe + -- Both have the same .val.val (= fe.val.val < 65536), so the BV's match. + have hfeval : fe.val.val < 2 ^ 16 := by + have h_p : (3329 : Nat) ≤ 2 ^ 16 := by decide + omega + have hfebv : BitVec.ofNat 16 fe.val.val = fe.val.bv := by + apply BitVec.eq_of_toNat_eq + rw [BitVec.toNat_ofNat] + show fe.val.val % 2 ^ 16 = fe.val.bv.toNat + rw [Nat.mod_eq_of_lt hfeval] + rfl + show ({ val := ⟨BitVec.ofNat 16 fe.val.val⟩ } : + hacspec_ml_kem.parameters.FieldElement) = fe + rw [hfebv] + +/-- Helper: `(lift_fe x).val.val = ((x.val : Int) : ZMod 3329).val`. -/ +theorem lift_fe_val_val (x : Std.I16) : + (lift_fe x).val.val = (((x.val : Int) : ZMod 3329)).val := by + unfold lift_fe i16_to_spec_fe_plain feOfZMod + show (BitVec.ofNat 16 (((x.val : Int) : ZMod 3329)).val).toNat + = (((x.val : Int) : ZMod 3329)).val + rw [BitVec.toNat_ofNat] + have h_lt : (((x.val : Int) : ZMod 3329)).val < 2 ^ 16 := + Nat.lt_of_lt_of_le (ZMod.val_lt _) (by decide) + exact Nat.mod_eq_of_lt h_lt + +/-- Bridge lemma: under the no-overflow bound on `a.val + b.val` + (Int, |·| ≤ 2^15-1), any `r : Std.I16` carrying that sum lifts to + `FieldElement.add_pure (lift_fe a) (lift_fe b)`. + + Pure-projection content: both sides reduce to + `feOfZMod ((a.val + b.val : Int) : ZMod 3329)`. The LHS is direct + from `lift_fe`'s definition. The RHS uses `add_pure_val_eq` plus + canonical round-trip — the result is canonical by + `Canonical_add_pure`, so equals `feOfZMod (zmodOfFE …)`, and the + `zmodOfFE`-projection reduces by `ZMod.natCast_mod` to the desired + cast sum. -/ +theorem lift_fe_add_pure_eq + (a b r : Std.I16) (hrv : r.val = a.val + b.val) : + lift_fe r + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe a) (lift_fe b) := by + -- LHS reduction: lift_fe r = feOfZMod ((r.val : Int) : ZMod 3329) + -- = feOfZMod ((a.val + b.val : Int) : ZMod 3329). + have h_lhs : lift_fe r + = feOfZMod (((a.val + b.val : Int)) : ZMod 3329) := by + unfold lift_fe i16_to_spec_fe_plain + rw [hrv] + -- RHS reduction: FieldElement.add_pure (lift_fe a) (lift_fe b) + -- = feOfZMod (zmodOfFE …) (canonical round-trip) + -- = feOfZMod ((a.val + b.val : Int) : ZMod 3329). + set s : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe a) (lift_fe b) with hs_def + have h_canon : s.val.val < 3329 := by + have h_cs := libcrux_iot_ml_kem.Spec.Pure.Canonical_add_pure + (lift_fe a) (lift_fe b) + show s.val.val < 3329 + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cs + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at h_cs + exact h_cs + have h_round_trip : feOfZMod (zmodOfFE s) = s := + feOfZMod_zmodOfFE_of_canonical s h_canon + -- zmodOfFE s = ((a.val + b.val : Int) : ZMod 3329). + have h_zmod_s : zmodOfFE s = (((a.val + b.val : Int)) : ZMod 3329) := by + unfold zmodOfFE + -- (s.val.val : ZMod 3329) with s.val.val = ((lift_fe a).val.val + (lift_fe b).val.val) % 3329 + rw [add_pure_val_eq] + -- Goal: ((((lift_fe a).val.val + (lift_fe b).val.val) % 3329 : Nat) : ZMod 3329) + -- = ((a.val + b.val : Int) : ZMod 3329). + rw [ZMod.natCast_mod] + push_cast + rw [lift_fe_val_val a, lift_fe_val_val b] + -- Goal: (((a.val : Int) : ZMod 3329).val : ZMod 3329) + -- + (((b.val : Int) : ZMod 3329).val : ZMod 3329) + -- = ((a.val + b.val : Int) : ZMod 3329). + rw [ZMod.natCast_zmod_val, ZMod.natCast_zmod_val] + rw [h_lhs, ← h_round_trip, h_zmod_s] + +/-- L1.1 — `add` on 16-lane PortableVector chunks. -/ +@[spec high] +theorem add_fc + (lhs rhs : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hpre : ∀ i : Nat, i < 16 → + ((lhs.elements.val[i]!).val + (rhs.elements.val[i]!).val : Int).natAbs ≤ 2^15 - 1) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.add lhs rhs + ⦃ ⇓ r => ⌜ lift_chunk r = Spec.chunk_add_pure (lift_chunk lhs) (lift_chunk rhs) ⌝ ⦄ := by + -- 1. Extract per-element value-equation from legacy bounds Triple. + have h_legacy := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.add_spec lhs rhs hpre + obtain ⟨r0, h_eq, h_per⟩ := triple_exists_ok_fc h_legacy + apply triple_of_ok_fc (v := r0) h_eq + -- 2. Reduce array equality to list equality, then to per-index lift_fe equality. + unfold lift_chunk Spec.chunk_add_pure + apply Subtype.ext + show r0.elements.val.map lift_fe + = (List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((Std.Array.make 16#usize (lhs.elements.val.map lift_fe) + (by simp)).val[i]!) + ((Std.Array.make 16#usize (rhs.elements.val.map lift_fe) + (by simp)).val[i]!)) + -- 3. Show both lists have length 16 and per-index entries match. + have h_r0_len : r0.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length r0 + apply List.ext_getElem + · simp [List.length_map, List.length_range, h_r0_len] + · intro i hi1 hi2 + -- LHS at i: lift_fe (r0.elements.val[i]). + -- RHS at i: add_pure (lift_fe lhs[i]) (lift_fe rhs[i]). + have hi : i < 16 := by + have : i < (r0.elements.val.map lift_fe).length := hi1 + simp [List.length_map, h_r0_len] at this; exact this + -- Rewrite LHS. + rw [List.getElem_map] + -- Rewrite RHS. + rw [List.getElem_map, List.getElem_range] + -- Indexing into Std.Array.make. + show lift_fe r0.elements.val[i] + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((lhs.elements.val.map lift_fe)[i]!) + ((rhs.elements.val.map lift_fe)[i]!) + -- Express `r0.elements.val[i]` as `r0.elements.val[i]!` + -- (same value when i < length). + have h_r0_get_eq : r0.elements.val[i] + = r0.elements.val[i]! := by + have hi_r0 : i < r0.elements.val.length := by rw [h_r0_len]; exact hi + rw [getElem!_pos r0.elements.val i hi_r0] + rw [h_r0_get_eq] + -- Express `(lhs.elements.val.map lift_fe)[i]!` as `lift_fe (lhs.val[i]!)`. + have h_lhs_len : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + have h_rhs_len : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + have h_map_lhs : + (lhs.elements.val.map lift_fe)[i]! = lift_fe (lhs.elements.val[i]!) := by + have hi_lhs : i < lhs.elements.val.length := by rw [h_lhs_len]; exact hi + rw [getElem!_pos (lhs.elements.val.map lift_fe) i (by + simp [List.length_map, h_lhs_len]; exact hi)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val i hi_lhs] + have h_map_rhs : + (rhs.elements.val.map lift_fe)[i]! = lift_fe (rhs.elements.val[i]!) := by + have hi_rhs : i < rhs.elements.val.length := by rw [h_rhs_len]; exact hi + rw [getElem!_pos (rhs.elements.val.map lift_fe) i (by + simp [List.length_map, h_rhs_len]; exact hi)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val i hi_rhs] + rw [h_map_lhs, h_map_rhs] + -- Apply the bridge lemma with the per-element value equation. + obtain ⟨h_val, _h_bnd⟩ := h_per i hi + exact lift_fe_add_pure_eq + (lhs.elements.val[i]!) (rhs.elements.val[i]!) (r0.elements.val[i]!) + h_val + +/-- Canonicity of `lift_fe`: every `feOfZMod`-image is canonical + (its `.val.val < 3329`). Used by `lift_fe_sub_pure_eq` to discharge + `sub_eq_ok`'s canonicity preconditions. Pure-projection side lemma. -/ +theorem Canonical_lift_fe (x : Std.I16) : + libcrux_iot_ml_kem.Spec.Pure.Canonical (lift_fe x) := by + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical + unfold lift_fe i16_to_spec_fe_plain feOfZMod + -- Goal: (BitVec.ofNat 16 ((x.val : Int) : ZMod 3329).val).toNat + -- < parameters.FIELD_MODULUS.val + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] + show (BitVec.ofNat 16 (((x.val : Int) : ZMod 3329)).val).toNat < 3329 + rw [BitVec.toNat_ofNat] + have h_lt : (((x.val : Int) : ZMod 3329)).val < 3329 := ZMod.val_lt _ + have h_le : (((x.val : Int) : ZMod 3329)).val < 2 ^ 16 := + Nat.lt_of_lt_of_le h_lt (by decide) + rw [Nat.mod_eq_of_lt h_le] + exact h_lt + +/-- Every lane of `lift_poly self` is canonical (as a hacspec FE), since each + lane is a `lift_fe` image and `lift_fe` produces canonical FEs via + `feOfZMod`. Used by L6.1 FC close to feed `poly_barrett_reduce_pure_id_of_canonical`. -/ +theorem lift_poly_lanes_canonical + (self : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + ∀ k : Nat, k < 256 → + libcrux_iot_ml_kem.Spec.Pure.Canonical ((lift_poly self).val[k]!) := by + intro k hk + -- (lift_poly self).val = (List.range 256).map (fun j => lift_fe …). + unfold lift_poly + show libcrux_iot_ml_kem.Spec.Pure.Canonical + (((List.range 256).map (fun j => + lift_fe (self.coefficients.val[j / 16]!).elements.val[j % 16]!))[k]!) + have h_len : ((List.range 256).map (fun j => + lift_fe (self.coefficients.val[j / 16]!).elements.val[j % 16]!)).length = 256 := by + simp + rw [getElem!_pos _ k (by rw [h_len]; exact hk)] + rw [List.getElem_map, List.getElem_range] + exact Canonical_lift_fe _ + +/-- Bridge lemma: the `.val.val` of `FieldElement.sub_pure` (under + canonicity of both operands) is the impl's U16 modular-reduced + difference: `(a.val.val + 3329 - b.val.val) % 3329`. Mirrors + `add_pure_val_eq`'s trace through the U32 do-block; the impl's + `sub` body is `(self.val + q - other.val) % q` (`x + q` widens + safely under `b` canonical, then `s - y ≥ 0` since + `s = x + q ≥ q > y`, then `% q`, then narrow U16). -/ +theorem sub_pure_val_eq + (a b : hacspec_ml_kem.parameters.FieldElement) + (ha : libcrux_iot_ml_kem.Spec.Pure.Canonical a) + (hb : libcrux_iot_ml_kem.Spec.Pure.Canonical b) : + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure a b).val.val + = (a.val.val + 3329 - b.val.val) % 3329 := by + have hsub : + hacspec_ml_kem.parameters.FieldElement.sub a b + = .ok (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure a b) := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_eq_ok a b ha hb + have ha' : a.val.val < 3329 := by + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at ha + unfold hacspec_ml_kem.parameters.FIELD_MODULUS at ha; simpa using ha + have hb' : b.val.val < 3329 := by + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at hb + unfold hacspec_ml_kem.parameters.FIELD_MODULUS at hb; simpa using hb + unfold hacspec_ml_kem.parameters.FieldElement.sub at hsub + simp only [Aeneas.Std.lift, Aeneas.Std.bind_tc_ok] at hsub + have hA := a.val.hBounds; have hB := b.val.hBounds + simp [Aeneas.Std.UScalarTy.numBits] at hA hB + set x : Std.U32 := Std.UScalar.cast .U32 a.val + set y : Std.U32 := Std.UScalar.cast .U32 b.val + set q : Std.U32 := Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS + have hxval : x.val = a.val.val := Std.U16.cast_U32_val_eq a.val + have hyval : y.val = b.val.val := Std.U16.cast_U32_val_eq b.val + have hqval : q.val = 3329 := by + show (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS).val = 3329 + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; simp + have hae := Std.UScalar.add_equiv x q + cases hxq : (x + q : Result Std.U32) with + | ok s => + rw [hxq] at hae hsub; simp at hae + obtain ⟨_, hsval, _⟩ := hae + simp only [Aeneas.Std.bind_tc_ok] at hsub + have hae2 := Std.UScalar.sub_equiv s y + cases hsy : (s - y : Result Std.U32) with + | ok u => + rw [hsy] at hae2 hsub; simp at hae2 + -- hae2 : y.val ≤ s.val ∧ s.val = u.val + y.val ∧ u.bv = s.bv - y.bv + obtain ⟨_hyle, hsuy, _⟩ := hae2 + simp only [Aeneas.Std.bind_tc_ok] at hsub + have hq_ne : q.val ≠ 0 := by rw [hqval]; decide + obtain ⟨w, hw_eq, hwval⟩ := uscalar_rem_ok_U32_local u q hq_ne + rw [hw_eq] at hsub; simp only [Aeneas.Std.bind_tc_ok] at hsub + unfold hacspec_ml_kem.parameters.FieldElement.new at hsub + simp at hsub + have hwbnd : w.val < 3329 := by + rw [hwval, hqval]; exact Nat.mod_lt _ (by decide) + have hwcast : (Std.UScalar.cast .U16 w).val = w.val := by + apply Std.UScalar.cast_val_mod_pow_of_inBounds_eq + simp [Aeneas.Std.UScalarTy.numBits]; omega + rw [← hsub] + show (Std.UScalar.cast .U16 w).val = (a.val.val + 3329 - b.val.val) % 3329 + rw [hwcast, hwval, hqval] + -- Goal: u.val % 3329 = (a.val.val + 3329 - b.val.val) % 3329 + -- From hsuy : s.val = u.val + y.val and hsval : s.val = x.val + q.val + -- and hxval, hqval, hyval, hb', we get u.val = a.val.val + 3329 - b.val.val. + have hu_eq : u.val = a.val.val + 3329 - b.val.val := by + have h1 : s.val = u.val + y.val := hsuy + rw [hsval, hxval, hqval, hyval] at h1 + omega + rw [hu_eq] + | fail e => + rw [hsy] at hae2; simp at hae2 + rw [hsval, hxval, hqval, hyval] at hae2 + omega + | div => rw [hsy] at hae2; exact hae2.elim + | fail e => + rw [hxq] at hae; simp [Std.UScalar.inBounds] at hae + rw [hxval, hqval] at hae + omega + | div => rw [hxq] at hae; exact hae.elim + +/-- Bridge lemma: the `.val.val` of `FieldElement.mul_pure` is the + impl's U16 modular-reduced product: `(a.val.val * b.val.val) % 3329`. + Mirrors `add_pure_val_eq`'s trace through the U32 do-block; the + impl's `mul` body is `(self.val * other.val) % q` (`x * y` widens + safely to U32 since `a.val * b.val ≤ (2^16-1)² < 2^32`, then `% q`, + then narrow U16). Unconditional (no canonicity needed). -/ +theorem mul_pure_val_eq + (a b : hacspec_ml_kem.parameters.FieldElement) : + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a b).val.val + = (a.val.val * b.val.val) % 3329 := by + have hmul : + hacspec_ml_kem.parameters.FieldElement.mul a b + = .ok (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure a b) := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_eq_ok a b + unfold hacspec_ml_kem.parameters.FieldElement.mul at hmul + simp only [Aeneas.Std.lift, Aeneas.Std.bind_tc_ok] at hmul + have hA := a.val.hBounds; have hB := b.val.hBounds + simp [Aeneas.Std.UScalarTy.numBits] at hA hB + set x : Std.U32 := Std.UScalar.cast .U32 a.val + set y : Std.U32 := Std.UScalar.cast .U32 b.val + have hxval : x.val = a.val.val := Std.U16.cast_U32_val_eq a.val + have hyval : y.val = b.val.val := Std.U16.cast_U32_val_eq b.val + have hae := Std.UScalar.mul_equiv x y + have heqmul : (x * y : Result Std.U32) = Std.UScalar.mul x y := rfl + cases hxy : (x * y : Result Std.U32) with + | ok z => + rw [hxy] at hmul + rw [heqmul] at hxy; rw [hxy] at hae; simp at hae + obtain ⟨_, hzval, _⟩ := hae + simp only [Aeneas.Std.bind_tc_ok] at hmul + have hmod_val : + (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS).val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; simp + have hmod_ne : + (Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS).val ≠ 0 := by + rw [hmod_val]; decide + set m : Std.U32 := Std.UScalar.cast .U32 hacspec_ml_kem.parameters.FIELD_MODULUS + obtain ⟨w, hw_eq, hwval⟩ := uscalar_rem_ok_U32_local z m hmod_ne + rw [hw_eq] at hmul; simp only [Aeneas.Std.bind_tc_ok] at hmul + unfold hacspec_ml_kem.parameters.FieldElement.new at hmul + simp at hmul + have hwbnd : w.val < 3329 := by + rw [hwval, hmod_val]; exact Nat.mod_lt _ (by decide) + have hwcast : (Std.UScalar.cast .U16 w).val = w.val := by + apply Std.UScalar.cast_val_mod_pow_of_inBounds_eq + simp [Aeneas.Std.UScalarTy.numBits]; omega + rw [← hmul] + show (Std.UScalar.cast .U16 w).val = (a.val.val * b.val.val) % 3329 + rw [hwcast, hwval, hmod_val, hzval, hxval, hyval] + | fail _ => + rw [heqmul] at hxy; rw [hxy] at hae + simp only [Std.UScalar.max, Aeneas.Std.UScalarTy.numBits] at hae + rw [hxval, hyval] at hae + have : a.val.val * b.val.val < 2^32 := by + have h1 : a.val.val * b.val.val ≤ (2^16 - 1) * (2^16 - 1) := by + apply Nat.mul_le_mul <;> omega + have heq : (2^16 - 1) * (2^16 - 1) = 2^32 - 2*2^16 + 1 := by decide + omega + omega + | div => rw [heqmul] at hxy; rw [hxy] at hae; exact hae.elim + +/-- Bridge lemma: under the no-overflow bound on `a.val - b.val` (Int, + |·| ≤ 2^15-1), any `r : Std.I16` carrying that difference lifts to + `FieldElement.sub_pure (lift_fe a) (lift_fe b)`. + + Pure-projection content: both sides reduce to + `feOfZMod ((a.val - b.val : Int) : ZMod 3329)`. The RHS uses + `sub_pure_val_eq` plus canonical round-trip — the result is + canonical by `Canonical_sub_pure`, and the `zmodOfFE`-projection + reduces the inner `(a.val.val + 3329 - b.val.val) % 3329` to + `(a.val - b.val : Int) : ZMod 3329` via `ZMod.natCast_mod` plus + integer reasoning. -/ +theorem lift_fe_sub_pure_eq + (a b r : Std.I16) (hrv : r.val = a.val - b.val) : + lift_fe r + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (lift_fe a) (lift_fe b) := by + have h_lhs : lift_fe r + = feOfZMod (((a.val - b.val : Int)) : ZMod 3329) := by + unfold lift_fe i16_to_spec_fe_plain + rw [hrv] + set s : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (lift_fe a) (lift_fe b) with hs_def + have h_canon : s.val.val < 3329 := by + have h_cs := libcrux_iot_ml_kem.Spec.Pure.Canonical_sub_pure + (lift_fe a) (lift_fe b) (Canonical_lift_fe a) (Canonical_lift_fe b) + show s.val.val < 3329 + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cs + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at h_cs + exact h_cs + have h_round_trip : feOfZMod (zmodOfFE s) = s := + feOfZMod_zmodOfFE_of_canonical s h_canon + have h_zmod_s : zmodOfFE s = (((a.val - b.val : Int)) : ZMod 3329) := by + unfold zmodOfFE + rw [sub_pure_val_eq _ _ (Canonical_lift_fe a) (Canonical_lift_fe b)] + -- Goal: (((lift_fe a).val.val + 3329 - (lift_fe b).val.val) % 3329 : Nat) : ZMod 3329 + -- = ((a.val - b.val : Int) : ZMod 3329) + rw [ZMod.natCast_mod] + -- Step: lift the Nat-subtraction expression through Nat→ZMod cast using + -- a cast-equality. Since (lift_fe b).val.val ≤ (lift_fe a).val.val + 3329 + -- (b canonical), Nat subtraction agrees with Int subtraction. + have hb_lt : (lift_fe b).val.val < 3329 := by + have h_cb := Canonical_lift_fe b + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cb + unfold hacspec_ml_kem.parameters.FIELD_MODULUS at h_cb; simpa using h_cb + have h_le : (lift_fe b).val.val ≤ (lift_fe a).val.val + 3329 := by omega + -- (Nat-cast into ZMod) of the Nat sub equals (Int-cast into ZMod) of the Int sub. + have h_zmod_eq : + (((lift_fe a).val.val + 3329 - (lift_fe b).val.val : Nat) : ZMod 3329) + = ((((lift_fe a).val.val : Int) + 3329 - ((lift_fe b).val.val : Int) : Int) + : ZMod 3329) := by + have h_int_eq : + (((lift_fe a).val.val + 3329 - (lift_fe b).val.val : Nat) : Int) + = ((lift_fe a).val.val : Int) + 3329 - ((lift_fe b).val.val : Int) := by + omega + have h_route : + (((lift_fe a).val.val + 3329 - (lift_fe b).val.val : Nat) : ZMod 3329) + = ((((lift_fe a).val.val + 3329 - (lift_fe b).val.val : Nat) : Int) + : ZMod 3329) := by + rfl + rw [h_route, h_int_eq] + rw [h_zmod_eq] + push_cast + rw [lift_fe_val_val a, lift_fe_val_val b] + rw [ZMod.natCast_zmod_val, ZMod.natCast_zmod_val] + -- Goal after push_cast: ((a.val : Int) : ZMod 3329) + 0 - ((b.val : Int) : ZMod 3329) + -- = ((a.val - b.val : Int) : ZMod 3329) + -- (push_cast collapses `(3329 : ZMod 3329)` to `0` via ZMod.natCast_self.) + ring + rw [h_lhs, ← h_round_trip, h_zmod_s] + +/-- Bridge lemma: under the no-overflow bound on `a.val * b.val` (Int, + |·| ≤ 2^15-1), any `r : Std.I16` carrying that product lifts to + `FieldElement.mul_pure (lift_fe a) (lift_fe b)`. + + Pure-projection content: both sides reduce to + `feOfZMod ((a.val * b.val : Int) : ZMod 3329)`. The RHS uses + `mul_pure_val_eq` plus canonical round-trip — the result is + canonical by `Canonical_mul_pure`. Unconditional in canonicity. -/ +theorem lift_fe_mul_pure_eq + (a b r : Std.I16) (hrv : r.val = a.val * b.val) : + lift_fe r + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe a) (lift_fe b) := by + have h_lhs : lift_fe r + = feOfZMod (((a.val * b.val : Int)) : ZMod 3329) := by + unfold lift_fe i16_to_spec_fe_plain + rw [hrv] + set s : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe a) (lift_fe b) with hs_def + have h_canon : s.val.val < 3329 := by + have h_cs := libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure + (lift_fe a) (lift_fe b) + show s.val.val < 3329 + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cs + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at h_cs + exact h_cs + have h_round_trip : feOfZMod (zmodOfFE s) = s := + feOfZMod_zmodOfFE_of_canonical s h_canon + have h_zmod_s : zmodOfFE s = (((a.val * b.val : Int)) : ZMod 3329) := by + unfold zmodOfFE + rw [mul_pure_val_eq] + rw [ZMod.natCast_mod] + push_cast + rw [lift_fe_val_val a, lift_fe_val_val b] + rw [ZMod.natCast_zmod_val, ZMod.natCast_zmod_val] + rw [h_lhs, ← h_round_trip, h_zmod_s] + +/-- L1.2 — `sub` on 16-lane PortableVector chunks. -/ +@[spec high] +theorem sub_fc + (lhs rhs : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hpre : ∀ i : Nat, i < 16 → + ((lhs.elements.val[i]!).val - (rhs.elements.val[i]!).val : Int).natAbs ≤ 2^15 - 1) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.sub lhs rhs + ⦃ ⇓ r => ⌜ lift_chunk r = Spec.chunk_sub_pure (lift_chunk lhs) (lift_chunk rhs) ⌝ ⦄ := by + -- 1. Extract per-element value-equation from legacy bounds Triple. + have h_legacy := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.sub_spec lhs rhs hpre + obtain ⟨r0, h_eq, h_per⟩ := triple_exists_ok_fc h_legacy + apply triple_of_ok_fc (v := r0) h_eq + -- 2. Reduce array equality to list equality, then to per-index lift_fe equality. + unfold lift_chunk Spec.chunk_sub_pure + apply Subtype.ext + show r0.elements.val.map lift_fe + = (List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + ((Std.Array.make 16#usize (lhs.elements.val.map lift_fe) + (by simp)).val[i]!) + ((Std.Array.make 16#usize (rhs.elements.val.map lift_fe) + (by simp)).val[i]!)) + have h_r0_len : r0.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length r0 + apply List.ext_getElem + · simp [List.length_map, List.length_range, h_r0_len] + · intro i hi1 hi2 + have hi : i < 16 := by + have : i < (r0.elements.val.map lift_fe).length := hi1 + simp [List.length_map, h_r0_len] at this; exact this + rw [List.getElem_map] + rw [List.getElem_map, List.getElem_range] + show lift_fe r0.elements.val[i] + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + ((lhs.elements.val.map lift_fe)[i]!) + ((rhs.elements.val.map lift_fe)[i]!) + have h_r0_get_eq : r0.elements.val[i] + = r0.elements.val[i]! := by + have hi_r0 : i < r0.elements.val.length := by rw [h_r0_len]; exact hi + rw [getElem!_pos r0.elements.val i hi_r0] + rw [h_r0_get_eq] + have h_lhs_len : lhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length lhs + have h_rhs_len : rhs.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length rhs + have h_map_lhs : + (lhs.elements.val.map lift_fe)[i]! = lift_fe (lhs.elements.val[i]!) := by + have hi_lhs : i < lhs.elements.val.length := by rw [h_lhs_len]; exact hi + rw [getElem!_pos (lhs.elements.val.map lift_fe) i (by + simp [List.length_map, h_lhs_len]; exact hi)] + rw [List.getElem_map] + rw [getElem!_pos lhs.elements.val i hi_lhs] + have h_map_rhs : + (rhs.elements.val.map lift_fe)[i]! = lift_fe (rhs.elements.val[i]!) := by + have hi_rhs : i < rhs.elements.val.length := by rw [h_rhs_len]; exact hi + rw [getElem!_pos (rhs.elements.val.map lift_fe) i (by + simp [List.length_map, h_rhs_len]; exact hi)] + rw [List.getElem_map] + rw [getElem!_pos rhs.elements.val i hi_rhs] + rw [h_map_lhs, h_map_rhs] + obtain ⟨h_val, _h_bnd⟩ := h_per i hi + exact lift_fe_sub_pure_eq + (lhs.elements.val[i]!) (rhs.elements.val[i]!) (r0.elements.val[i]!) + h_val + +/-- Per-element bridge for `barrett_reduce_fc`: under `modq_eq r vec 3329`, + the lift of `r` equals the spec-side `Spec.barrett_pure` applied to + the lift of `vec`. Combines `lift_fe_eq_of_modq` with + `barrett_pure_lift_fe` (which collapses the canonical round-trip on + `lift_fe` images to identity). -/ +theorem lift_fe_barrett_pure_eq + (a r : Std.I16) + (h : libcrux_iot_ml_kem.Spec.ModularArith.modq_eq r.val a.val 3329) : + lift_fe r = Spec.barrett_pure (lift_fe a) := by + rw [barrett_pure_lift_fe] + exact lift_fe_eq_of_modq r a h + +/-- L1.3 — `barrett_reduce` on a chunk. -/ +@[spec high] +theorem barrett_reduce_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hpre : ∀ i : Nat, i < 16 → + (vec.elements.val[i]!).val.natAbs ≤ 32767) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce vec + ⦃ ⇓ r => ⌜ (∀ i : Nat, i < 16 → (r.elements.val[i]!).val.natAbs ≤ 3328) + ∧ lift_chunk r = Spec.chunk_barrett_reduce_pure (lift_chunk vec) ⌝ ⦄ := by + -- 1. Extract per-element legacy fact: modq_eq r[i] vec[i] 3329 ∧ |r[i]| ≤ 3328. + have h_legacy := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.barrett_reduce_spec vec hpre + obtain ⟨r0, h_eq, h_per⟩ := triple_exists_ok_fc h_legacy + apply triple_of_ok_fc (v := r0) h_eq + refine ⟨?_, ?_⟩ + · -- Bound conjunct: extract `r[i].natAbs ≤ 3328` from per-element legacy. + intro i hi + exact (h_per i hi).2 + · -- 2. Reduce array equality to list equality, then to per-index lift_fe equality. + unfold lift_chunk Spec.chunk_barrett_reduce_pure + apply Subtype.ext + show r0.elements.val.map lift_fe + = (List.range 16).map (fun i => + Spec.barrett_pure + ((Std.Array.make 16#usize (vec.elements.val.map lift_fe) + (by simp)).val[i]!)) + have h_r0_len : r0.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length r0 + apply List.ext_getElem + · simp [List.length_map, List.length_range, h_r0_len] + · intro i hi1 hi2 + have hi : i < 16 := by + have : i < (r0.elements.val.map lift_fe).length := hi1 + simp [List.length_map, h_r0_len] at this; exact this + rw [List.getElem_map] + rw [List.getElem_map, List.getElem_range] + show lift_fe r0.elements.val[i] + = Spec.barrett_pure ((vec.elements.val.map lift_fe)[i]!) + have h_r0_get_eq : r0.elements.val[i] + = r0.elements.val[i]! := by + have hi_r0 : i < r0.elements.val.length := by rw [h_r0_len]; exact hi + rw [getElem!_pos r0.elements.val i hi_r0] + rw [h_r0_get_eq] + have h_vec_len : vec.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length vec + have h_map_vec : + (vec.elements.val.map lift_fe)[i]! = lift_fe (vec.elements.val[i]!) := by + have hi_vec : i < vec.elements.val.length := by rw [h_vec_len]; exact hi + rw [getElem!_pos (vec.elements.val.map lift_fe) i (by + simp [List.length_map, h_vec_len]; exact hi)] + rw [List.getElem_map] + rw [getElem!_pos vec.elements.val i hi_vec] + rw [h_map_vec] + obtain ⟨h_modq, _h_bnd⟩ := h_per i hi + exact lift_fe_barrett_pure_eq + (vec.elements.val[i]!) (r0.elements.val[i]!) h_modq + +/-- Per-element bridge for `montgomery_multiply_by_constant_fc`: from the + legacy L1.4 congruence `(r * 2^16) ≡ (a * c) (mod 3329)`, derive the + FC equation + `lift_fe_mont r = Spec.montgomery_multiply_fe_by_fer_pure (lift_fe a) (lift_fe_mont c)`. + + Algebra: the goal (after unfolding via `mmfbf_pure_lift_fe_lift_fe_mont` + and `lift_fe_mont`/`i16_to_spec_fe_mont`) is the `ZMod 3329` equation + `r * 169 = a * (c * 169) * 169`. From the legacy hypothesis + `r * 2^16 = a * c` in `ZMod 3329`, multiply both sides by `169 * 169` + and use the Montgomery-inversion identity `2^16 * 169 = 1` in `ZMod 3329` + to collapse one factor on the LHS. -/ +theorem lift_fe_mont_mmfbf_pure_eq + (a c r : Std.I16) + (h : (r.val * (2 ^ 16 : Int)) % 3329 = (a.val * c.val) % 3329) : + lift_fe_mont r + = Spec.montgomery_multiply_fe_by_fer_pure (lift_fe a) (lift_fe_mont c) := by + rw [mmfbf_pure_lift_fe_lift_fe_mont] + unfold lift_fe_mont i16_to_spec_fe_mont + congr 1 + -- Goal: (r.val : ZMod 3329) * 169 = (a.val : ZMod 3329) * ((c.val : ZMod 3329) * 169) * 169 + have h_modq : libcrux_iot_ml_kem.Spec.ModularArith.modq_eq + (r.val * (2 ^ 16 : Int)) (a.val * c.val) 3329 := by + unfold libcrux_iot_ml_kem.Spec.ModularArith.modq_eq + rw [Int.sub_emod, h]; simp + have h_zmod : ((r.val * (2 ^ 16 : Int) : Int) : ZMod 3329) + = ((a.val * c.val : Int) : ZMod 3329) := + modq_eq_cast_zmod _ _ h_modq + push_cast at h_zmod + -- After push_cast: h_zmod : (r.val : ZMod 3329) * 2285 = (a.val : ZMod 3329) * (c.val : ZMod 3329) + -- (Lean reduces `2^16` to its canonical residue `2285` in ZMod 3329.) + -- Goal: (r.val : ZMod 3329) * 169 = (a.val : ZMod 3329) * ((c.val : ZMod 3329) * 169) * 169 + -- Normalize: 2285 * 169 = 386165 = 3329 * 116 + 1, so 2285 * 169 ≡ 1 (mod 3329). + have h_inv : ((2285 : ZMod 3329)) * 169 = 1 := by decide + calc (r.val : ZMod 3329) * 169 + = (r.val : ZMod 3329) * ((2285 : ZMod 3329) * 169) * 169 := by rw [h_inv]; ring + _ = ((r.val : ZMod 3329) * 2285) * 169 * 169 := by ring + _ = ((a.val : ZMod 3329) * (c.val : ZMod 3329)) * 169 * 169 := by rw [h_zmod] + _ = (a.val : ZMod 3329) * ((c.val : ZMod 3329) * 169) * 169 := by ring + +/-- Per-element bridge for the `montgomery_multiply_by_constant`-then-other + chain used by `subtract_reduce_fc`: from the legacy L1.4 congruence + `(r * 2^16) ≡ (a * c) (mod 3329)`, derive the *plain*-domain FC + equation `lift_fe r = mul_pure (lift_fe a) (lift_fe_mont c)`. + + This is the sibling of `lift_fe_mont_mmfbf_pure_eq` for callers that + feed the Mont-multiplied lane into a subsequent `sub`/`negate` (which + consume `lift_fe`, not `lift_fe_mont`). The algebra reduces to the + ZMod q equation `r = a * (c * 169)`, again using the Montgomery + inversion identity `2^16 * 169 ≡ 1 (mod q)`. -/ +theorem lift_fe_mont_mul_pure_eq + (a c r : Std.I16) + (h : (r.val * (2 ^ 16 : Int)) % 3329 = (a.val * c.val) % 3329) : + lift_fe r + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe a) (lift_fe_mont c) := by + -- Both sides equal `feOfZMod ((a.val : ZMod q) * ((c.val : ZMod q) * 169))`. + -- LHS: `lift_fe r = feOfZMod ((r.val : ZMod q))`. + -- RHS: round-trip via Canonical_mul_pure + mul_pure_val_eq. + have h_lhs_zmod : (r.val : ZMod 3329) = (a.val : ZMod 3329) * ((c.val : ZMod 3329) * 169) := by + have h_modq : libcrux_iot_ml_kem.Spec.ModularArith.modq_eq + (r.val * (2 ^ 16 : Int)) (a.val * c.val) 3329 := by + unfold libcrux_iot_ml_kem.Spec.ModularArith.modq_eq + rw [Int.sub_emod, h]; simp + have h_zmod : ((r.val * (2 ^ 16 : Int) : Int) : ZMod 3329) + = ((a.val * c.val : Int) : ZMod 3329) := + modq_eq_cast_zmod _ _ h_modq + push_cast at h_zmod + -- h_zmod : (r.val : ZMod 3329) * 2285 = (a.val : ZMod 3329) * (c.val : ZMod 3329) + have h_inv : ((2285 : ZMod 3329)) * 169 = 1 := by decide + calc (r.val : ZMod 3329) + = (r.val : ZMod 3329) * ((2285 : ZMod 3329) * 169) := by rw [h_inv]; ring + _ = ((r.val : ZMod 3329) * 2285) * 169 := by ring + _ = ((a.val : ZMod 3329) * (c.val : ZMod 3329)) * 169 := by rw [h_zmod] + _ = (a.val : ZMod 3329) * ((c.val : ZMod 3329) * 169) := by ring + have h_lhs : lift_fe r + = feOfZMod ((a.val : ZMod 3329) * ((c.val : ZMod 3329) * 169)) := by + unfold lift_fe i16_to_spec_fe_plain + rw [h_lhs_zmod] + set s : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe a) (lift_fe_mont c) with hs_def + have h_canon : s.val.val < 3329 := by + have h_cs := libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure + (lift_fe a) (lift_fe_mont c) + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cs + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at h_cs + exact h_cs + have h_round_trip : feOfZMod (zmodOfFE s) = s := + feOfZMod_zmodOfFE_of_canonical s h_canon + have h_zmod_s : zmodOfFE s = (a.val : ZMod 3329) * ((c.val : ZMod 3329) * 169) := by + unfold zmodOfFE + rw [mul_pure_val_eq] + rw [ZMod.natCast_mod] + push_cast + -- Goal: ((lift_fe a).val.val : ZMod 3329) * ((lift_fe_mont c).val.val : ZMod 3329) + -- = (a.val : ZMod 3329) * ((c.val : ZMod 3329) * 169) + rw [lift_fe_val_val a] + -- LHS: ((a.val : ZMod 3329)) * ((lift_fe_mont c).val.val : ZMod 3329) = ... + rw [ZMod.natCast_zmod_val] + -- Now need: (a.val : ZMod 3329) * ((lift_fe_mont c).val.val : ZMod 3329) + -- = (a.val : ZMod 3329) * ((c.val : ZMod 3329) * 169) + -- Suffices to show: ((lift_fe_mont c).val.val : ZMod 3329) = (c.val : ZMod 3329) * 169. + have h_mont_val : ((lift_fe_mont c).val.val : ZMod 3329) + = (c.val : ZMod 3329) * 169 := by + unfold lift_fe_mont i16_to_spec_fe_mont + -- (lift_fe_mont c).val.val = (feOfZMod ((c.val : ZMod 3329) * 169)).val.val. + -- Apply ZMod.natCast_zmod_val + zmodOfFE_feOfZMod. + have h_step : zmodOfFE (feOfZMod ((c.val : ZMod 3329) * 169)) + = (c.val : ZMod 3329) * 169 := zmodOfFE_feOfZMod _ + -- zmodOfFE x = (x.val.val : ZMod 3329) (by defn). + have h_unfold : zmodOfFE (feOfZMod ((c.val : ZMod 3329) * 169)) + = ((feOfZMod ((c.val : ZMod 3329) * 169)).val.val : ZMod 3329) := rfl + rw [h_unfold] at h_step + exact h_step + rw [h_mont_val] + rw [h_lhs, ← h_round_trip, h_zmod_s] + +/-- L1.4 — `montgomery_multiply_by_constant` on a chunk. + Each lane: `vec[i] · c / R`. The lift uses `lift_chunk_mont` on + the output (the result is in Mont domain). -/ +@[spec high] +theorem montgomery_multiply_by_constant_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (c : Std.I16) + (hvec : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 32767) + (hc : c.val.natAbs ≤ 1664) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_by_constant vec c + ⦃ ⇓ r => ⌜ lift_chunk_mont r + = Spec.chunk_montgomery_multiply_by_constant_pure + (lift_chunk vec) (lift_fe_mont c) ⌝ ⦄ := by + -- 1. Extract per-element legacy fact: |r[i]| ≤ 3328 ∧ (r[i]*2^16) ≡ vec[i]*c (mod 3329). + have h_legacy := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.montgomery_multiply_by_constant_spec vec c hc + obtain ⟨r0, h_eq, h_per⟩ := triple_exists_ok_fc h_legacy + apply triple_of_ok_fc (v := r0) h_eq + -- 2. Reduce array equality to list equality, then to per-index lift_fe_mont equality. + unfold lift_chunk_mont Spec.chunk_montgomery_multiply_by_constant_pure + apply Subtype.ext + show r0.elements.val.map lift_fe_mont + = (List.range 16).map (fun i => + Spec.montgomery_multiply_fe_by_fer_pure + ((Std.Array.make 16#usize (vec.elements.val.map lift_fe) + (by simp)).val[i]!) + (lift_fe_mont c)) + have h_r0_len : r0.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length r0 + apply List.ext_getElem + · simp [List.length_map, List.length_range, h_r0_len] + · intro i hi1 hi2 + have hi : i < 16 := by + have : i < (r0.elements.val.map lift_fe_mont).length := hi1 + simp [List.length_map, h_r0_len] at this; exact this + rw [List.getElem_map] + rw [List.getElem_map, List.getElem_range] + show lift_fe_mont r0.elements.val[i] + = Spec.montgomery_multiply_fe_by_fer_pure + ((vec.elements.val.map lift_fe)[i]!) (lift_fe_mont c) + have h_r0_get_eq : r0.elements.val[i] + = r0.elements.val[i]! := by + have hi_r0 : i < r0.elements.val.length := by rw [h_r0_len]; exact hi + rw [getElem!_pos r0.elements.val i hi_r0] + rw [h_r0_get_eq] + have h_vec_len : vec.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length vec + have h_map_vec : + (vec.elements.val.map lift_fe)[i]! = lift_fe (vec.elements.val[i]!) := by + have hi_vec : i < vec.elements.val.length := by rw [h_vec_len]; exact hi + rw [getElem!_pos (vec.elements.val.map lift_fe) i (by + simp [List.length_map, h_vec_len]; exact hi)] + rw [List.getElem_map] + rw [getElem!_pos vec.elements.val i hi_vec] + rw [h_map_vec] + obtain ⟨_h_bnd, h_mod⟩ := h_per i hi + exact lift_fe_mont_mmfbf_pure_eq + (vec.elements.val[i]!) c (r0.elements.val[i]!) + h_mod + +/-! ### L1.5 — `cond_subtract_3329` private loop machinery. + + The legacy `libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.cond_subtract_3329_spec` requires + `0 ≤ vec[i] < 2*3329` as a precondition (it's load-bearing for the + OUTER bound `r[i] < 3329`). The FC statement here uses NO + precondition — we only need `lift_chunk r = lift_chunk vec`, i.e. + mod-3329 equivalence per lane. The mod-3329 equivalence holds for + BOTH branches of the conditional WITHOUT any precondition: + + - `vec[i] ≥ 3329` branch: `r[i] = wrapping_sub vec[i] 3329`. + Since `vec[i] ∈ [3329, 32767]` (signed), `vec[i] - 3329 ∈ [0, 29438]` + fits I16, so `r[i].val = vec[i].val - 3329 ≡ vec[i].val (mod 3329)`. + - `vec[i] < 3329` branch: `r[i] = vec[i]`, trivially mod-3329 equivalent. + + Below we reproduce a stripped-down copy of the CondSubtract3329 loop machinery + (private to FCTargets) yielding just the per-element disjunction. The + full proof closely mirrors `Equivalence.CondSubtract3329.cond_step`; comments are + abbreviated since the structure is verbatim. +-/ + +namespace CondSubtract3329FC + +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper Aeneas.Std Std.Do Result ControlFlow + +theorem triple_of_ok_l1 + {α : Type} {x : Result α} {v : α} {P : α → Prop} + (hx : x = .ok v) (hp : P v) : + ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄ := by + subst hx; simp [Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +theorem of_pure_prop_holds_l1 {P : Prop} + (h : (pure P : Result Prop).holds) : P := by + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, + PredTrans.apply] at h + exact h trivial + +theorem pure_prop_holds_l1 {P : Prop} (h : P) : (pure P : Result Prop).holds := by + simp only [Aeneas.Std.Result.holds, Std.Do.Triple, Std.Do.WP.wp, PostCond.noThrow, + PredTrans.apply] + intro _; exact h + +/-- Per-element invariant for `cond_subtract_3329` (FCTargets-local copy + of `Equivalence.CondSubtract3329.cond_inv`; precondition-free). -/ +def cond_inv + (input : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector → + Result Prop := + fun k acc => pure ( + (∀ j : Nat, j < k.val → + (((input.elements.val[j]!).val ≥ 3329 ∧ + (acc.elements.val[j]!) = Std.I16.wrapping_sub (input.elements.val[j]!) 3329#i16) + ∨ ((input.elements.val[j]!).val < 3329 ∧ + acc.elements.val[j]! = input.elements.val[j]!))) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.elements.val[j]! = input.elements.val[j]!)) + +def cond_step_post + (input : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (cond_inv input iter'.start acc').holds + | .done y => (cond_inv input 16#usize y).holds + +set_option maxHeartbeats 8000000 in +theorem cond_step + (input : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (acc : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (cond_inv input k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.cond_subtract_3329_loop.body + { start := k, «end» := 16#usize } acc + ⦃ ⇓ r => ⌜ cond_step_post input k r ⌝ ⦄ := by + obtain ⟨h_acc_done, h_acc_undone⟩ := of_pure_prop_holds_l1 h_inv + have h_acc_len : acc.elements.length = 16 := PortableVector_elements_length acc + have h_16 : (16#usize : Std.Usize).val = 16 := rfl + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.cond_subtract_3329_loop.body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · have hk_16 : k.val < 16 := by rw [h_16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_some_eq k h_lt + have h_idx : + Aeneas.Std.Array.index_usize acc.elements k = .ok (acc.elements.val[k.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq acc.elements k (by rw [h_acc_len]; exact hk_16) + set xk : Std.I16 := acc.elements.val[k.val]! with hxk_def + have h_acc_xk : acc.elements.val[k.val]! = input.elements.val[k.val]! := + h_acc_undone k.val (Nat.le_refl _) hk_16 + by_cases h_ge : xk.val ≥ 3329 + · have h_ge_lit : xk ≥ 3329#i16 := by + change (3329#i16 : Std.I16).val ≤ xk.val + have : (3329#i16 : Std.I16).val = 3329 := by decide + rw [this]; exact h_ge + have h_wsub : + CoreModels.core.num.I16.wrapping_sub xk 3329#i16 + = .ok (Std.I16.wrapping_sub xk 3329#i16) := by + unfold CoreModels.core.num.I16.wrapping_sub + unfold rust_primitives.arithmetic.wrapping_sub_i16 + rfl + have h_upd : + Aeneas.Std.Array.update acc.elements k (Std.I16.wrapping_sub xk 3329#i16) + = .ok (acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16)) := + array_update_ok_eq acc.elements k _ (by rw [h_acc_len]; exact hk_16) + have h_body : + (do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize) + match o with + | core.option.Option.None => + (Result.ok (ControlFlow.done acc) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + | core.option.Option.Some i => + let i1 ← Aeneas.Std.Array.index_usize acc.elements i + let i2 ← libcrux_secrets.traits.Declassify.Blanket.declassify i1 + if i2 >= 3329#i16 + then + let i3 ← CoreModels.core.num.I16.wrapping_sub i1 3329#i16 + let a ← Aeneas.Std.Array.update acc.elements i i3 + ok (cont (iter1, { elements := a })) + else ok (cont (iter1, acc))) + = .ok (cont + (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + { elements := acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16) })) := by + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [bind_tc_ok] + rw [h_idx] + simp only [bind_tc_ok] + rw [show libcrux_secrets.traits.Declassify.Blanket.declassify xk = .ok xk from rfl] + simp only [bind_tc_ok] + rw [if_pos h_ge_lit] + rw [h_wsub] + simp only [bind_tc_ok] + rw [h_upd] + rfl + apply triple_of_ok_l1 h_body + show cond_step_post input k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + { elements := acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16) })) + unfold cond_step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + apply pure_prop_holds_l1 + refine ⟨?_, ?_⟩ + · intro j hj + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne : + (acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16))[j]! + = (acc.elements)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.elements k j _ h_ne + have h_set_eq_val : + (acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16)).val[j]! + = acc.elements.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + have h_old := h_acc_done j hj_lt_k + rcases h_old with ⟨h_in_ge, h_acc_eq⟩ | ⟨h_in_lt, h_acc_eq⟩ + · left; refine ⟨h_in_ge, ?_⟩; rw [h_set_eq_val]; exact h_acc_eq + · right; refine ⟨h_in_lt, ?_⟩; rw [h_set_eq_val]; exact h_acc_eq + · subst hj_eq_k + have h_lt'' : k.val < acc.elements.length := by rw [h_acc_len]; exact hk_16 + have h_set_eq : + (acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16))[k.val]! + = Std.I16.wrapping_sub xk 3329#i16 := + Aeneas.Std.Array.getElem!_Nat_set_eq acc.elements k k.val _ ⟨rfl, h_lt''⟩ + have h_set_eq_val : + (acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16)).val[k.val]! + = Std.I16.wrapping_sub xk 3329#i16 := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_eq + left + refine ⟨?_, ?_⟩ + · rw [← h_acc_xk]; exact h_ge + · rw [h_set_eq_val, ← h_acc_xk] + · intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne : + (acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16))[j]! + = (acc.elements)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.elements k j _ h_ne + have h_set_eq_val : + (acc.elements.set k (Std.I16.wrapping_sub xk 3329#i16)).val[j]! + = acc.elements.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + rw [h_set_eq_val] + exact h_acc_undone j h_ge' hj_lt + · have h_not_ge : ¬ (3329#i16 : Std.I16).val ≤ xk.val := by + have h_eq : (3329#i16 : Std.I16).val = 3329 := by decide + rw [h_eq]; exact h_ge + have h_not_ge' : ¬ (xk ≥ 3329#i16) := h_not_ge + have h_body : + (do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize) + match o with + | core.option.Option.None => + (Result.ok (ControlFlow.done acc) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + | core.option.Option.Some i => + let i1 ← Aeneas.Std.Array.index_usize acc.elements i + let i2 ← libcrux_secrets.traits.Declassify.Blanket.declassify i1 + if i2 >= 3329#i16 + then + let i3 ← CoreModels.core.num.I16.wrapping_sub i1 3329#i16 + let a ← Aeneas.Std.Array.update acc.elements i i3 + ok (cont (iter1, { elements := a })) + else ok (cont (iter1, acc))) + = .ok (cont + (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc)) := by + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [bind_tc_ok] + rw [h_idx] + simp only [bind_tc_ok] + rw [show libcrux_secrets.traits.Declassify.Blanket.declassify xk = .ok xk from rfl] + simp only [bind_tc_ok] + rw [if_neg h_not_ge'] + apply triple_of_ok_l1 h_body + show cond_step_post input k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + acc)) + unfold cond_step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + apply pure_prop_holds_l1 + refine ⟨?_, ?_⟩ + · intro j hj + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · exact h_acc_done j hj_lt_k + · subst hj_eq_k + right + refine ⟨?_, ?_⟩ + · rw [← h_acc_xk]; show xk.val < 3329 + push Not at h_ge; exact h_ge + · exact h_acc_xk + · intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ge' : k.val ≤ j := by omega + exact h_acc_undone j h_ge' hj_lt + · have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h_16] at hk_ge; omega + have h_iter_none := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.iter_next_none_eq k hk_ge + have h_body : + (do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize) + match o with + | core.option.Option.None => + (Result.ok (ControlFlow.done acc) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + | core.option.Option.Some i => + let i1 ← Aeneas.Std.Array.index_usize acc.elements i + let i2 ← libcrux_secrets.traits.Declassify.Blanket.declassify i1 + if i2 >= 3329#i16 + then + let i3 ← CoreModels.core.num.I16.wrapping_sub i1 3329#i16 + let a ← Aeneas.Std.Array.update acc.elements i i3 + ok (cont (iter1, { elements := a })) + else ok (cont (iter1, acc))) + = .ok (done acc) := by + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_l1 h_body + show cond_step_post input k (.done acc) + unfold cond_step_post + apply pure_prop_holds_l1 + refine ⟨?_, ?_⟩ + · intro j hj + apply h_acc_done j + rw [hk_eq]; rw [h_16] at hj; exact hj + · intro j hj_ge hj_lt + apply h_acc_undone j _ hj_lt + rw [hk_eq]; rw [h_16] at hj_ge; exact hj_ge + +end CondSubtract3329FC + +/-- L1.5 — `cond_subtract_3329` on a chunk. + NO HACSPEC EQUIVALENT — the impl conditionally subtracts q to + rebalance ranges; spec-side this is identity in `ZMod 3329`. The + spec target we land against is the identity at the FE-array level. + + No precondition required: the mod-3329 equivalence holds in BOTH + branches of the conditional unconditionally (see `CondSubtract3329FC` namespace + above for the precondition-free invariant). -/ +@[spec high] +theorem cond_subtract_3329_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.cond_subtract_3329 vec + ⦃ ⇓ r => ⌜ lift_chunk r = lift_chunk vec ⌝ ⦄ := by + -- 1. Run the loop-spec machinery to get the per-element disjunction. + have h_disj : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.cond_subtract_3329 vec + ⦃ ⇓ r => ⌜ ∀ j : Nat, j < 16 → + ((vec.elements.val[j]!).val ≥ 3329 ∧ + (r.elements.val[j]!) + = Std.I16.wrapping_sub (vec.elements.val[j]!) 3329#i16) + ∨ ((vec.elements.val[j]!).val < 3329 ∧ + r.elements.val[j]! = vec.elements.val[j]!) ⌝ ⦄ := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.cond_subtract_3329 + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.cond_subtract_3329_loop + have h_field : libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR + = (16#usize : Std.Usize) := by + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR; rfl + rw [h_field] + apply Std.Do.Triple.of_entails_right _ + (libcrux_iot_ml_kem.Util.LoopSpecs.loop_range_spec_usize + (fun (iter1, vec1) => + libcrux_iot_ml_kem.vector.portable.arithmetic.cond_subtract_3329_loop.body + iter1 vec1) + vec 0#usize 16#usize + (CondSubtract3329FC.cond_inv vec) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (CondSubtract3329FC.pure_prop_holds_l1 ⟨ + fun j hj => by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + rw [h0] at hj; exact absurd hj (Nat.not_lt_zero j), + fun _ _ _ => rfl⟩) + ?_) + · rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_done, _h_undone⟩ := CondSubtract3329FC.of_pure_prop_holds_l1 h + intro j hj + exact h_done j (by rw [show (16#usize : Std.Usize).val = 16 from rfl]; exact hj) + · -- Step lemma. + intro acc k h_ge h_le hinv + have h_step := CondSubtract3329FC.cond_step vec acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : CondSubtract3329FC.cond_step_post vec k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [CondSubtract3329FC.cond_step_post] using hP + · have hP : CondSubtract3329FC.cond_step_post vec k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [CondSubtract3329FC.cond_step_post] using hP + -- 2. Apply h_disj and convert per-lane disjunction to lift_fe equality. + obtain ⟨r0, h_eq, h_per⟩ := triple_exists_ok_fc h_disj + apply triple_of_ok_fc (v := r0) h_eq + -- 3. Reduce array equality to list equality, then to per-index lift_fe equality. + unfold lift_chunk + apply Subtype.ext + show r0.elements.val.map lift_fe = vec.elements.val.map lift_fe + have h_r0_len : r0.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length r0 + have h_vec_len : vec.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length vec + apply List.ext_getElem + · simp [List.length_map, h_r0_len, h_vec_len] + · intro i hi1 hi2 + have hi : i < 16 := by + have : i < (r0.elements.val.map lift_fe).length := hi1 + simp [List.length_map, h_r0_len] at this; exact this + rw [List.getElem_map, List.getElem_map] + -- Goal: lift_fe r0.elements.val[i] = lift_fe vec.elements.val[i]. + have h_r0_get : r0.elements.val[i] = r0.elements.val[i]! := by + have hi_r0 : i < r0.elements.val.length := by rw [h_r0_len]; exact hi + rw [getElem!_pos r0.elements.val i hi_r0] + have h_vec_get : vec.elements.val[i] = vec.elements.val[i]! := by + have hi_vec : i < vec.elements.val.length := by rw [h_vec_len]; exact hi + rw [getElem!_pos vec.elements.val i hi_vec] + rw [h_r0_get, h_vec_get] + -- Apply per-lane disjunction. + rcases h_per i hi with ⟨h_ge, h_eq_lane⟩ | ⟨h_lt, h_eq_lane⟩ + · -- ≥ 3329 branch: r0[i] = wrapping_sub vec[i] 3329, derive mod-3329 equality. + rw [h_eq_lane] + -- Goal: lift_fe (wrapping_sub vec[i] 3329) = lift_fe vec[i]. + set xi : Std.I16 := vec.elements.val[i]! with hxi + -- Use modq_eq on (wrapping_sub xi 3329).val vs xi.val. + apply lift_fe_eq_of_modq + -- Need: modq_eq (wrapping_sub xi 3329).val xi.val 3329. + -- (wrapping_sub xi 3329).val = bmod (xi.val - 3329) (2^16). Since + -- xi.val ≥ 3329 and xi.val < 2^15, we have xi.val - 3329 ∈ [0, 2^15 - 3329], + -- which is in I16 range, so bmod = xi.val - 3329. + unfold libcrux_iot_ml_kem.Spec.ModularArith.modq_eq + rw [Std.I16.wrapping_sub_val_eq] + have hxi_ub : xi.val < 2^15 := by + have h := xi.hBounds + simp [Aeneas.Std.IScalarTy.numBits] at h + omega + have h3329 : (3329#i16 : Std.I16).val = 3329 := by decide + rw [h3329] + have hxi_lb_diff : (-(2:Int)^(16-1)) ≤ xi.val - 3329 := by + have h1 : (3329 : Int) ≤ xi.val := h_ge + have h2 : -(2:Int)^(16-1) ≤ 0 := by decide + have h3 : (0 : Int) ≤ xi.val - 3329 := by omega + omega + have hxi_ub_diff : xi.val - 3329 < (2:Int)^(16-1) := by + have h1 : xi.val < (2:Int)^15 := by exact_mod_cast hxi_ub + have h2 : (2:Int)^(16-1) = (2:Int)^15 := by decide + omega + rw [Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + hxi_lb_diff hxi_ub_diff] + -- Goal: (xi.val - 3329 - xi.val) % 3329 = 0. + have : xi.val - 3329 - xi.val = -3329 := by ring + rw [this] + decide + · -- < 3329 branch: r0[i] = vec[i], trivially mod-3329 equivalent. + rw [h_eq_lane] + +/-- Local helper (mirrors `Spec.Pure.uscalar_rem_ok_U16` which is + file-private). Establishes that U16 modular remainder by a non-zero + divisor is always `.ok`, and exposes the underlying value. Needed + by `neg_pure_val_eq`, whose `% q` step is at U16 width (no widening). -/ +theorem uscalar_rem_ok_U16_local (z m : Std.U16) (hm : m.val ≠ 0) : + ∃ w : Std.U16, (z % m : Result Std.U16) = .ok w ∧ w.val = z.val % m.val := by + have heq : (z % m : Result Std.U16) = Std.UScalar.rem z m := rfl + unfold Std.UScalar.rem at heq + simp [hm] at heq + refine ⟨_, heq, ?_⟩ + show (BitVec.umod z.bv m.bv).toNat = z.val % m.val + unfold BitVec.umod + simp only [BitVec.toNat_ofNatLT] + rfl + +/-- Bridge lemma: under canonicity of the operand, the `.val.val` of + `FieldElement.neg_pure` is the impl's U16 modular-reduced negation + `(3329 - a.val.val) % 3329`. Mirrors `sub_pure_val_eq`'s trace, but + the impl's `neg` body is `(q - self.val) % q` operated entirely at + U16 width (NO widening to U32); panic-impossibility of `q - self.val` + is precisely `Canonical a` (i.e. `a.val.val < q`). -/ +theorem neg_pure_val_eq + (a : hacspec_ml_kem.parameters.FieldElement) + (ha : libcrux_iot_ml_kem.Spec.Pure.Canonical a) : + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure a).val.val + = (3329 - a.val.val) % 3329 := by + have hneg : + hacspec_ml_kem.parameters.FieldElement.neg a + = .ok (libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure a) := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_eq_ok a ha + have ha' : a.val.val < 3329 := by + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at ha + unfold hacspec_ml_kem.parameters.FIELD_MODULUS at ha; simpa using ha + unfold hacspec_ml_kem.parameters.FieldElement.neg at hneg + have hA := a.val.hBounds + simp [Aeneas.Std.UScalarTy.numBits] at hA + have hqval : (hacspec_ml_kem.parameters.FIELD_MODULUS : Std.U16).val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; simp + have hae := Std.UScalar.sub_equiv (hacspec_ml_kem.parameters.FIELD_MODULUS : Std.U16) a.val + cases hqa : + ((hacspec_ml_kem.parameters.FIELD_MODULUS : Std.U16) - a.val : Result Std.U16) with + | ok i => + rw [hqa] at hae hneg; simp at hae + obtain ⟨_hale, hival, _⟩ := hae + simp only [Aeneas.Std.bind_tc_ok] at hneg + have hq_ne : (hacspec_ml_kem.parameters.FIELD_MODULUS : Std.U16).val ≠ 0 := by + rw [hqval]; decide + obtain ⟨w, hw_eq, hwval⟩ := + uscalar_rem_ok_U16_local i hacspec_ml_kem.parameters.FIELD_MODULUS hq_ne + rw [hw_eq] at hneg; simp only [Aeneas.Std.bind_tc_ok] at hneg + unfold hacspec_ml_kem.parameters.FieldElement.new at hneg + simp at hneg + rw [← hneg] + -- Goal: w.val = (3329 - a.val.val) % 3329. + rw [hwval, hqval] + -- Goal: i.val % 3329 = (3329 - a.val.val) % 3329. + -- From hival : i.val + a.val.val = 3329, so i.val = 3329 - a.val.val. + have hi_eq : i.val = 3329 - a.val.val := by + rw [hqval] at hival; omega + rw [hi_eq] + | fail e => + rw [hqa] at hae; simp at hae + rw [hqval] at hae + omega + | div => rw [hqa] at hae; exact hae.elim + +/-- Bridge lemma: under the no-overflow bound on `-a.val` (i.e. + `a.val.natAbs ≤ 2^15 - 1`, equivalently `a.val ∈ [-(2^15 - 1), 2^15 - 1]`, + which EXCLUDES the boundary `-2^15`), any `r : Std.I16` carrying that + negation lifts to `FieldElement.neg_pure (lift_fe a)`. + + Pure-projection content: both sides reduce to + `feOfZMod ((-a.val : Int) : ZMod 3329)`. The LHS is direct from + `lift_fe`'s definition. The RHS uses `neg_pure_val_eq` plus canonical + round-trip — the result is canonical by `Canonical_neg_pure` (which + needs `Canonical_lift_fe`), and the `zmodOfFE`-projection reduces + the inner `(3329 - (lift_fe a).val.val) % 3329` to `((-a.val : Int) + : ZMod 3329)` via `ZMod.natCast_mod` plus integer reasoning. + + Boundary excluded: at `a.val = -2^15 = -32768`, both sides would + diverge — `Int.bmod (-a.val) 2^16 = -32768`, but + `(-((-32768 : Int) : ZMod 3329)).val = 2807 ≠ 522`. The `hbnd` + precondition rules out this case. -/ +theorem lift_fe_neg_pure_eq + (a r : Std.I16) + (hbnd : a.val.natAbs ≤ 2^15 - 1) + (hrv : r.val = -a.val) : + lift_fe r + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe a) := by + -- LHS reduction. + have h_lhs : lift_fe r = feOfZMod (((-a.val : Int)) : ZMod 3329) := by + unfold lift_fe i16_to_spec_fe_plain + rw [hrv] + -- RHS reduction. + set s : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure (lift_fe a) with hs_def + have h_canon : s.val.val < 3329 := by + have h_cs := libcrux_iot_ml_kem.Spec.Pure.Canonical_neg_pure + (lift_fe a) (Canonical_lift_fe a) + show s.val.val < 3329 + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cs + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at h_cs + exact h_cs + have h_round_trip : feOfZMod (zmodOfFE s) = s := + feOfZMod_zmodOfFE_of_canonical s h_canon + have h_zmod_s : zmodOfFE s = (((-a.val : Int)) : ZMod 3329) := by + unfold zmodOfFE + rw [neg_pure_val_eq _ (Canonical_lift_fe a)] + -- Goal: ((3329 - (lift_fe a).val.val) % 3329 : Nat) : ZMod 3329 + -- = ((-a.val : Int) : ZMod 3329) + rw [ZMod.natCast_mod] + -- (lift_fe a).val.val < 3329, so 3329 - (lift_fe a).val.val ≤ 3329. + have ha_lt : (lift_fe a).val.val < 3329 := by + have h_ca := Canonical_lift_fe a + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_ca + unfold hacspec_ml_kem.parameters.FIELD_MODULUS at h_ca; simpa using h_ca + -- Cast the Nat-sub through Int-sub: 3329 - (lift_fe a).val.val (Nat) = + -- 3329 - (lift_fe a).val.val (Int) since the former is ≥ 0. + have h_zmod_eq : + (((3329 - (lift_fe a).val.val : Nat)) : ZMod 3329) + = ((((3329 : Int) - ((lift_fe a).val.val : Int)) : Int) : ZMod 3329) := by + have h_int_eq : + (((3329 - (lift_fe a).val.val : Nat)) : Int) + = (3329 : Int) - ((lift_fe a).val.val : Int) := by + omega + have h_route : + (((3329 - (lift_fe a).val.val : Nat)) : ZMod 3329) + = ((((3329 - (lift_fe a).val.val : Nat)) : Int) : ZMod 3329) := by + rfl + rw [h_route, h_int_eq] + rw [h_zmod_eq] + push_cast + rw [lift_fe_val_val a] + rw [ZMod.natCast_zmod_val] + -- After push_cast: 0 - ((a.val : Int) : ZMod 3329) = ((-a.val : Int) : ZMod 3329) + -- (3329 collapses to 0 via ZMod.natCast_self). + ring + rw [h_lhs, ← h_round_trip, h_zmod_s] + +/-- L1.6 — `negate` on a chunk. + + **Precondition** `hpre` mirrors the upstream F* spec `negate_pre` + from `libcrux-ml-kem-proofs/libcrux-ml-kem/src/vector/traits.rs:684` + (`forall i. is_intb (pow2 15 - 1) (v ${vec}[i])`), i.e. every lane + is strictly within `[-(2^15 - 1), 2^15 - 1]` — equivalently the + natAbs is `≤ 2^15 - 1`. + + **Why this is the canonical bound**: the impl's `negate` is + pointwise `core.num.I16.wrapping_neg`, which lowers to + `wrapping_sub 0 vec[i]`. The lane-level value reduces to + `Int.bmod (-vec[i].val) 2^16`. For the FC equation + `(r[i].val : ZMod 3329) = -(vec[i].val : ZMod 3329)` to hold we + need `Int.bmod (-vec[i].val) 2^16 = -vec[i].val` (no boundary flip); + `bmod_pow2_eq_of_inBounds'` requires `-vec[i].val ∈ [-2^15, 2^15)`, + i.e. `vec[i].val ∈ (-2^15, 2^15]`. Combined with the impl's + `vec[i].val ∈ [-2^15, 2^15)` carrier we get `vec[i].val.natAbs + ≤ 2^15 - 1` — exactly `negate_pre`. The excluded value `-2^15` + would yield a real divergence: `2^16 mod 3329 = 2645 ≠ 0`, so + bmod's two-valued identification of `-2^15` and `2^15` does NOT + collapse mod 3329. + + **Callers**: every real caller of `negate` (in + `serialize::compress_then_serialize_*`) feeds inputs that are + barrett-reduced (so `|x| ≤ 1664 < 2^15 - 1`) or subtracted from + barrett-reduced operands (so `|x| ≤ 6656 < 2^15 - 1`); the bound + is trivially satisfied at every call site. -/ +@[spec high] +theorem negate_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hpre : ∀ i : Nat, i < 16 → + (vec.elements.val[i]!).val.natAbs ≤ 2^15 - 1) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.negate vec + ⦃ ⇓ r => ⌜ lift_chunk r = Spec.chunk_neg_pure (lift_chunk vec) ⌝ ⦄ := by + -- 1. Extract per-element BV-equation from legacy `negate_spec`. + have h_legacy := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.negate_spec vec + obtain ⟨r0, h_eq, h_per⟩ := triple_exists_ok_fc h_legacy + apply triple_of_ok_fc (v := r0) h_eq + -- 2. Reduce array equality to list equality, then to per-index lift_fe equality. + unfold lift_chunk Spec.chunk_neg_pure + apply Subtype.ext + show r0.elements.val.map lift_fe + = (List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure + ((Std.Array.make 16#usize (vec.elements.val.map lift_fe) + (by simp)).val[i]!)) + have h_r0_len : r0.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length r0 + apply List.ext_getElem + · simp [List.length_map, List.length_range, h_r0_len] + · intro i hi1 hi2 + have hi : i < 16 := by + have : i < (r0.elements.val.map lift_fe).length := hi1 + simp [List.length_map, h_r0_len] at this; exact this + rw [List.getElem_map] + rw [List.getElem_map, List.getElem_range] + show lift_fe r0.elements.val[i] + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.neg_pure + ((vec.elements.val.map lift_fe)[i]!) + have h_r0_get_eq : r0.elements.val[i] + = r0.elements.val[i]! := by + have hi_r0 : i < r0.elements.val.length := by rw [h_r0_len]; exact hi + rw [getElem!_pos r0.elements.val i hi_r0] + rw [h_r0_get_eq] + have h_vec_len : vec.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length vec + have h_map_vec : + (vec.elements.val.map lift_fe)[i]! = lift_fe (vec.elements.val[i]!) := by + have hi_vec : i < vec.elements.val.length := by rw [h_vec_len]; exact hi + rw [getElem!_pos (vec.elements.val.map lift_fe) i (by + simp [List.length_map, h_vec_len]; exact hi)] + rw [List.getElem_map] + rw [getElem!_pos vec.elements.val i hi_vec] + rw [h_map_vec] + -- 3. Convert per-lane BV-equation to val-equation, then apply bridge. + set xi : Std.I16 := vec.elements.val[i]! with hxi + set ri : Std.I16 := r0.elements.val[i]! with hri + -- From negate_spec: ri.bv = -xi.bv. + have h_bv : ri.bv = -xi.bv := h_per i hi + -- From this BV equality + I16.bv_toInt_eq: ri.val = (-xi.bv).toInt. + -- The cleanest route: bridge through `Std.I16.wrapping_sub 0 xi`, + -- whose `.bv = 0 - xi.bv = -xi.bv` matches `ri.bv`, then use + -- `wrapping_sub_val_eq` to get `Int.bmod (0 - xi.val) (2^16)`. + have h_wsub_bv : (Aeneas.Std.I16.wrapping_sub (0#i16) xi).bv = -xi.bv := by + rw [Aeneas.Std.I16.wrapping_sub_bv_eq] + simp only [show (0#i16 : Std.I16).bv = (0 : BitVec 16) from rfl] + exact BitVec.zero_sub xi.bv + -- Convert BV equality to val equality: ri.val = (-xi.bv).toInt = + -- (Std.I16.wrapping_sub 0 xi).val = Int.bmod (0 - xi.val) (2^16) = -xi.val + -- (last step under hpre via bmod_pow2_eq_of_inBounds'). + have h_ri_val : ri.val = -xi.val := by + have h_step1 : ri.val = (Aeneas.Std.I16.wrapping_sub (0#i16) xi).val := by + have h_toInt : + (ri.bv).toInt = (Aeneas.Std.I16.wrapping_sub (0#i16) xi).bv.toInt := by + rw [h_bv, h_wsub_bv] + have h_lhs : (ri.bv).toInt = ri.val := Aeneas.Std.I16.bv_toInt_eq ri + have h_rhs : (Aeneas.Std.I16.wrapping_sub (0#i16) xi).bv.toInt + = (Aeneas.Std.I16.wrapping_sub (0#i16) xi).val := + Aeneas.Std.I16.bv_toInt_eq _ + rw [h_lhs, h_rhs] at h_toInt + exact h_toInt + rw [h_step1] + rw [Aeneas.Std.I16.wrapping_sub_val_eq] + have h0 : (0#i16 : Std.I16).val = 0 := by decide + rw [h0] + -- Goal: Int.bmod (0 - xi.val) (2^16) = -xi.val. + have h_diff : (0 : Int) - xi.val = -xi.val := by ring + rw [h_diff] + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_abs : xi.val.natAbs ≤ 2^15 - 1 := hpre i hi + have h_pow : -((2 : Int) ^ (16 - 1)) = -(2^15 : Int) := by decide + rw [h_pow] + omega + · have h_abs : xi.val.natAbs ≤ 2^15 - 1 := hpre i hi + have h_pow : ((2 : Int) ^ (16 - 1)) = (2^15 : Int) := by decide + rw [h_pow] + omega + -- 4. Apply the bridge lemma. + have h_abs : xi.val.natAbs ≤ 2^15 - 1 := hpre i hi + exact lift_fe_neg_pure_eq xi ri h_abs h_ri_val + +/-- L1.7 — `multiply_by_constant` (plain) on a chunk. + + **Precondition note**: the legacy `libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.multiply_by_constant_spec` + requires the per-element product bound `|vec[i] * c| ≤ 2^15 - 1`. + The aggregate `|vec[i]| ≤ 32767 ∧ |c| ≤ 1664` does NOT imply that + product bound (it allows `32767 * 1664 ≫ 32767`), so we carry + `hpre_prod` as an additional caller obligation. Callers downstream + in the NTT pipeline reliably satisfy this — Mont-domain inputs are + already `|vec[i]| ≤ 3328 + 1665` after a `montgomery_reduce`, and + the product with `|c| ≤ 1664` is well inside i32 with the per-lane + bound easily verified. The `hvec` / `hc` are kept for API + consistency with `montgomery_multiply_by_constant_fc`. -/ +@[spec high] +theorem multiply_by_constant_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (c : Std.I16) + (hvec : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 32767) + (hc : c.val.natAbs ≤ 1664) + (hpre_prod : ∀ i : Nat, i < 16 → + ((vec.elements.val[i]!).val * c.val : Int).natAbs ≤ 2^15 - 1) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.multiply_by_constant vec c + ⦃ ⇓ r => ⌜ lift_chunk r + = Spec.chunk_multiply_by_constant_pure + (lift_chunk vec) (lift_fe c) ⌝ ⦄ := by + -- 1. Extract per-element value-equation from legacy bounds Triple. + have h_legacy := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.multiply_by_constant_spec vec c hpre_prod + obtain ⟨r0, h_eq, h_per⟩ := triple_exists_ok_fc h_legacy + apply triple_of_ok_fc (v := r0) h_eq + -- 2. Reduce array equality to list equality, then to per-index lift_fe equality. + unfold lift_chunk Spec.chunk_multiply_by_constant_pure + apply Subtype.ext + show r0.elements.val.map lift_fe + = (List.range 16).map (fun i => + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((Std.Array.make 16#usize (vec.elements.val.map lift_fe) + (by simp)).val[i]!) + (lift_fe c)) + have h_r0_len : r0.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length r0 + apply List.ext_getElem + · simp [List.length_map, List.length_range, h_r0_len] + · intro i hi1 hi2 + have hi : i < 16 := by + have : i < (r0.elements.val.map lift_fe).length := hi1 + simp [List.length_map, h_r0_len] at this; exact this + rw [List.getElem_map] + rw [List.getElem_map, List.getElem_range] + show lift_fe r0.elements.val[i] + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((vec.elements.val.map lift_fe)[i]!) (lift_fe c) + have h_r0_get_eq : r0.elements.val[i] + = r0.elements.val[i]! := by + have hi_r0 : i < r0.elements.val.length := by rw [h_r0_len]; exact hi + rw [getElem!_pos r0.elements.val i hi_r0] + rw [h_r0_get_eq] + have h_vec_len : vec.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length vec + have h_map_vec : + (vec.elements.val.map lift_fe)[i]! = lift_fe (vec.elements.val[i]!) := by + have hi_vec : i < vec.elements.val.length := by rw [h_vec_len]; exact hi + rw [getElem!_pos (vec.elements.val.map lift_fe) i (by + simp [List.length_map, h_vec_len]; exact hi)] + rw [List.getElem_map] + rw [getElem!_pos vec.elements.val i hi_vec] + rw [h_map_vec] + obtain ⟨h_val, _h_bnd⟩ := h_per i hi + exact lift_fe_mul_pure_eq + (vec.elements.val[i]!) c (r0.elements.val[i]!) + h_val + +/-- L1.8 — `bitwise_and_with_constant` on a chunk. + + the upstream F* spec + `libcrux-ml-kem-proofs/libcrux-ml-kem/src/vector/traits.rs:720` + (`bitwise_and_with_constant_constant_post`), the canonical FC + statement for this bit-level op is the per-lane BV-equality + `result == map_array (fun x -> x &. c) vec`. This is the + "FE-level lift" formulation is NOT meaningful here because + `lift_fe` projects through `mod 3329`, discarding the bit pattern + that `&.` depends on (the FE equation would require lift_chunk to + preserve bit info, which it cannot without losing the ring + semantics). The canonical FC equation is therefore at the I16-BV + level — equality-form, equality-strong, just at the correct + abstraction layer for a bit-level op. + + The §0.5 `Spec.chunk_bitwise_and_with_constant_pure` stub remains + in place as documentation but is NOT used in the FC equation + below (the canonical post is the upstream F*-aligned BV-equality). -/ +@[spec high] +theorem bitwise_and_with_constant_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (c : Std.I16) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.bitwise_and_with_constant vec c + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + (r.elements.val[i]!).bv = (vec.elements.val[i]!).bv &&& c.bv ⌝ ⦄ := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.bitwise_and_with_constant_spec vec c + +/-- L1.9 — `shift_right` on a chunk. + + the upstream F* spec + `libcrux-ml-kem-proofs/libcrux-ml-kem/src/vector/traits.rs:731` + (`shift_right_post`), the canonical FC statement for this bit-level + op is the per-lane BV-equality + `result == map_array (fun x -> x >>! shift_by) vec` (where `>>!` + is signed right shift on i16). Same reasoning as `bitwise_and_with_constant_fc`: + the I16-BV level is the correct abstraction; lift_fe would discard + the bit pattern that `>>!` depends on. + + The legacy `libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.shift_right_spec` uses `0 ≤ SHIFT_BY.val ∧ + SHIFT_BY.val < 16` (the same range as the upstream F* `requires + SHIFT_BY >= 0 && SHIFT_BY < 16` on the trait). We adopt the same + precondition shape. -/ +@[spec high] +theorem shift_right_fc + (SHIFT_BY : Std.I32) + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hs : 0 ≤ SHIFT_BY.val ∧ SHIFT_BY.val < 16) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.shift_right SHIFT_BY vec + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + (r.elements.val[i]!).bv = + (vec.elements.val[i]!).bv.sshiftRight SHIFT_BY.val.toNat ⌝ ⦄ := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.shift_right_spec SHIFT_BY hs vec + +/-- Per-element bridge for `reducing_from_i32_array_fc`: from the legacy + L1.10 congruence `(r * 2^16) ≡ x (mod 3329)`, derive the FC equation + `lift_fe_mont r = Spec.mont_reduce_pure (lift_fe_int x.val)`. + + Algebra: the goal (after unfolding via `mont_reduce_pure_lift_fe_int` + and `lift_fe_mont`/`i16_to_spec_fe_mont`) is the `ZMod 3329` equation + `r * 169 = x * 169 * 169`. From the legacy hypothesis `r * 2^16 = x` + in `ZMod 3329`, multiply both sides by `169 * 169` and use the + Montgomery-inversion identity `2^16 * 169 ≡ 1 (mod 3329)` + (numerically `2285 * 169 = 1` in `ZMod 3329`) to collapse one factor + on the LHS. -/ +theorem lift_fe_mont_mont_reduce_pure_eq + (x : Std.I32) (r : Std.I16) + (h : libcrux_iot_ml_kem.Spec.ModularArith.modq_eq + (r.val * (2 ^ 16 : Int)) x.val 3329) : + lift_fe_mont r = Spec.mont_reduce_pure (lift_fe_int x.val) := by + rw [mont_reduce_pure_lift_fe_int] + unfold lift_fe_mont i16_to_spec_fe_mont + congr 1 + -- Goal: (r.val : ZMod 3329) * 169 = (x.val : ZMod 3329) * 169 * 169 + have h_zmod : ((r.val * (2 ^ 16 : Int) : Int) : ZMod 3329) + = ((x.val : Int) : ZMod 3329) := + modq_eq_cast_zmod _ _ h + push_cast at h_zmod + -- h_zmod : (r.val : ZMod 3329) * 2285 = (x.val : ZMod 3329) + -- Goal: (r.val : ZMod 3329) * 169 = (x.val : ZMod 3329) * 169 * 169 + have h_inv : ((2285 : ZMod 3329)) * 169 = 1 := by decide + calc (r.val : ZMod 3329) * 169 + = (r.val : ZMod 3329) * ((2285 : ZMod 3329) * 169) * 169 := by rw [h_inv]; ring + _ = ((r.val : ZMod 3329) * 2285) * 169 * 169 := by ring + _ = (x.val : ZMod 3329) * 169 * 169 := by rw [h_zmod] + +/-- L1.10 — `reducing_from_i32_array` on a chunk. + Composes `montgomery_reduce_element` across 16 lanes. + + POST additionally exposes the per-lane I16 bound `|r[i]| ≤ 4993` + (= 3328 + 1665) coming from `libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.reducing_from_i32_array_spec`. + Used by L6.7 to thread a bound through to L7.1 Stage 3, where + `add_standard_error_reduce_fc` consumes `|self[k][ℓ]| ≤ 32767` + via `4993 ≤ 32767`. -/ +@[spec high] +theorem reducing_from_i32_array_fc + (array : Slice Std.I32) + (out : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (hlen : array.length = 16) + (hbound : ∀ i : Nat, i < 16 → + (array.val[i]!).val.natAbs ≤ 2^16 * 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.reducing_from_i32_array array out + ⦃ ⇓ r => ⌜ lift_chunk_mont r = Spec.chunk_reducing_from_i32_array_pure array + ∧ (∀ i : Nat, i < 16 → (r.elements.val[i]!).val.natAbs ≤ 4993) ⌝ ⦄ := by + -- 1. Extract per-element legacy fact: + -- |r[i]| ≤ 3328+1665 ∧ (r[i]*2^16) ≡ array[i] (mod 3329). + have hpre' : ∀ i : Nat, i < 16 → (array.val[i]!).val.natAbs ≤ 3328 * 2 ^ 16 := by + intro i hi + have h := hbound i hi + rwa [show (3328 * 2 ^ 16 : Nat) = 2 ^ 16 * 3328 from by decide] + have hlen' : array.val.length = 16 := hlen + have h_legacy := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element.reducing_from_i32_array_spec array out hlen' hpre' + obtain ⟨r0, h_eq, h_per⟩ := triple_exists_ok_fc h_legacy + apply triple_of_ok_fc (v := r0) h_eq + refine ⟨?_, ?_⟩ + · -- (a) Existing FC equation: lift_chunk_mont r = Spec.chunk_reducing_from_i32_array_pure array. + -- Reduce array equality to list equality, then to per-index lift_fe_mont equality. + unfold lift_chunk_mont Spec.chunk_reducing_from_i32_array_pure + apply Subtype.ext + show r0.elements.val.map lift_fe_mont + = (List.range 16).map (fun i => + Spec.mont_reduce_pure (lift_fe_int (array.val[i]!).val)) + have h_r0_len : r0.elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length r0 + apply List.ext_getElem + · simp [List.length_map, List.length_range, h_r0_len] + · intro i hi1 hi2 + have hi : i < 16 := by + have : i < (r0.elements.val.map lift_fe_mont).length := hi1 + simp [List.length_map, h_r0_len] at this; exact this + rw [List.getElem_map] + rw [List.getElem_map, List.getElem_range] + show lift_fe_mont r0.elements.val[i] + = Spec.mont_reduce_pure (lift_fe_int (array.val[i]!).val) + have h_r0_get_eq : r0.elements.val[i] + = r0.elements.val[i]! := by + have hi_r0 : i < r0.elements.val.length := by rw [h_r0_len]; exact hi + rw [getElem!_pos r0.elements.val i hi_r0] + rw [h_r0_get_eq] + obtain ⟨_h_bnd, h_modq⟩ := h_per i hi + exact lift_fe_mont_mont_reduce_pure_eq + (array.val[i]!) (r0.elements.val[i]!) h_modq + · -- (b) Per-lane I16 bound `|r[i]| ≤ 4993` from the legacy spec's first conjunct. + intro i hi + exact (h_per i hi).1 + + +end libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Arithmetic/LoopHelper.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Arithmetic/LoopHelper.lean new file mode 100644 index 00000000..2333f7da --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Arithmetic/LoopHelper.lean @@ -0,0 +1,1049 @@ +/- + # `Util/PortableVector.lean` — Layer-1 elementwise loop infrastructure + + The L1 family (`barrett_reduce`, `montgomery_reduce`, + `montgomery_multiply_by_constant`, `negate`, …) all share the same + shape: a Rust `for i in 0..16 { ... }` loop that reads + `vec.elements[i]`, runs a per-element field-arithmetic primitive (an + L0.x Triple), and writes back to `acc.elements[i]`. + + This module gives a reusable infrastructure layer: + + - `unary_loop_inv` — canonical 2-conjunct loop invariant. + - `unary_loop_body` — canonical body shape (matches every + `vector.portable.arithmetic._loop.body` from Funs.lean). + - `elementwise_unary_step` — per-iteration step lemma. + - `elementwise_unary_spec` — top-level wrapper that invokes + `loop_range_spec_usize`. Each L1.x unary op closes via + `apply elementwise_unary_spec` + supplying the per-element + `@[spec]` as a `per_elem_spec` hypothesis. + + Binary ops (`add`, `sub`) get analogous `binary_loop_*` lemmas + when L1.1/L1.2 land. + + Proof strategy: turn each component of the body + (`IteratorRange.next`, `Array.index_usize`, `per_elem`, + `Array.update`) into a `Result` equation, compose them into a + single body equation, then close via `triple_of_ok_pv`. This is + the cleanest substitute for `mvcgen` when the surrounding spec is + generic in `per_elem` (so mvcgen has no `@[spec]` to register). +-/ +import LibcruxIotMlKem.Util.LoopSpecs +import LibcruxIotMlKem.Extraction.Funs + +open CoreModels Aeneas Aeneas.Std Result ControlFlow Std.Do + +namespace libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper +open libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs +set_option mvcgen.warning false +set_option linter.unusedVariables false +set_option linter.unusedSectionVars false + +/-! ## `FIELD_ELEMENTS_IN_VECTOR` numerical reduction -/ + +theorem field_elements_in_vector_val : + (libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR : Std.Usize).val = 16 := by + unfold libcrux_iot_ml_kem.vector.traits.FIELD_ELEMENTS_IN_VECTOR; rfl + +/-! ## Length-of-elements bridge -/ + +@[simp] +theorem PortableVector_elements_length + (v : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + v.elements.length = 16 := by + have := v.elements.property + show v.elements.val.length = 16 + exact this + +/-! ## Local helpers — Triple ↔ Result.ok bridges, pure-prop holds. -/ + +section pv_helpers + +private theorem triple_of_ok_pv + {α : Type} {x : Result α} {v : α} {P : α → Prop} + (hx : x = .ok v) (hp : P v) : + ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄ := by + subst hx; simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +private theorem triple_exists_ok_pv + {α : Type} {x : Result α} {P : α → Prop} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) : + ∃ v, x = .ok v ∧ P v := by + match hx : x with + | .ok v => exact ⟨v, rfl, (by subst hx; simpa [Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h)⟩ + | .fail _ => exact absurd h (by simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div => exact absurd h (by simp [Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +private theorem pure_prop_holds_pv {P : Prop} (h : P) : (pure P : Result Prop).holds := by + simp only [Aeneas.Std.Result.holds, Triple, WP.wp]; intro _; exact h + +private theorem of_pure_prop_holds_pv {P : Prop} + (h : (pure P : Result Prop).holds) : P := by + simp only [Aeneas.Std.Result.holds, Triple, WP.wp] at h; exact h trivial + +end pv_helpers + +/-! ## Iterator-next reduction to a `Result` equation. -/ + +/-- `i.val < 16`: `IteratorRange.next` returns `.ok (some i, iter')` with + `iter'.end = 16` and `iter'.start.val = i.val + 1`. We avoid pinning + `iter'.start`'s exact UScalar form by stating the post existentially. -/ +theorem iter_next_some_eq (i : Std.Usize) (h_lt : i.val < (16#usize : Std.Usize).val) : + ∃ s : Std.Usize, s.val = i.val + 1 ∧ + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := i, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize) + = .ok (some i, + ({ start := s, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) := by + have hT := IteratorRange_next_spec_usize i 16#usize + (Q := PostCond.noThrow fun (oi : Option Std.Usize × _) => ⌜ + ∃ s : Std.Usize, s.val = i.val + 1 + ∧ oi = (some i, + ({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝) + (fun _ s hs => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure] + exact ⟨s, hs, rfl⟩) + (fun hge => absurd h_lt (Nat.not_lt.mpr hge)) + obtain ⟨v, hveq, hP⟩ := triple_exists_ok_pv hT + obtain ⟨s, hs_val, hpair⟩ := hP + exact ⟨s, hs_val, by rw [hveq, hpair]⟩ + +/-- `i.val ≥ 16`: `IteratorRange.next` returns `.ok (none, _)`. -/ +theorem iter_next_none_eq (i : Std.Usize) (h_ge : i.val ≥ (16#usize : Std.Usize).val) : + CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := i, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize) + = .ok ((none : Option Std.Usize), + ({ start := i, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) := by + have hT := IteratorRange_next_spec_usize i 16#usize + (Q := PostCond.noThrow fun (oi : Option Std.Usize × _) => ⌜ + oi = ((none : Option Std.Usize), + ({ start := i, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize)) ⌝) + (fun hlt => absurd hlt (Nat.not_lt.mpr h_ge)) + (fun _ => by + dsimp only [PostCond.noThrow, Std.Do.SPred.down_pure]) + obtain ⟨v, hveq, hP⟩ := triple_exists_ok_pv hT + rw [hveq, hP] + +/-! ## Array index/update reduction to `Result` equations. -/ + +theorem array_index_usize_ok_eq + {α : Type u} {n : Std.Usize} [Inhabited α] + (v : Std.Array α n) (i : Std.Usize) (h_bd : i.val < v.length) : + Aeneas.Std.Array.index_usize v i = .ok (v.val[i.val]!) := by + have hT := Aeneas.Std.Array.index_usize_spec v i h_bd + have h_ex := Aeneas.Std.WP.spec_imp_exists hT + obtain ⟨v', hveq, hPv'⟩ := h_ex + rw [hveq, hPv', getElem!_pos] + +theorem array_update_ok_eq + {α : Type u} {n : Std.Usize} + (v : Std.Array α n) (i : Std.Usize) (x : α) (h_bd : i.val < v.length) : + Aeneas.Std.Array.update v i x = .ok (v.set i x) := by + have hT := Aeneas.Std.Array.update_spec v i x h_bd + have h_ex := Aeneas.Std.WP.spec_imp_exists hT + obtain ⟨v', hveq, hPv'⟩ := h_ex + rw [hveq, hPv'] + +/-! ## Unary loop invariant -/ + +/-- 2-conjunct invariant: + - For `j < k`, `acc.elements[j]` equals the per-elem-op output `r` + for input `input.elements[j]` (carrying the per-elem predicate + `P` that the L0.x `@[spec]` produces). + - For `j ≥ k`, `acc.elements[j] = input.elements[j]`. -/ +def unary_loop_inv + (per_elem : Std.I16 → Result Std.I16) + (P : Std.I16 → Std.I16 → Prop) + (input : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector → + Result Prop := + fun k acc => pure ( + (∀ j : Nat, j < k.val → + ∃ r, per_elem (input.elements.val[j]!) = .ok r + ∧ acc.elements.val[j]! = r ∧ P (input.elements.val[j]!) r) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.elements.val[j]! = input.elements.val[j]!)) + +/-! ## Unary loop body (canonical shape from Funs.lean) -/ + +def unary_loop_body + (per_elem : Std.I16 → Result Std.I16) + (iter : CoreModels.core.ops.range.Range Std.Usize) + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) := do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep iter + match o with + | core.option.Option.None => ok (done vec) + | core.option.Option.Some i => + let i1 ← Aeneas.Std.Array.index_usize vec.elements i + let vi ← per_elem i1 + let a ← Aeneas.Std.Array.update vec.elements i vi + ok (cont (iter1, { elements := a })) + +/-! ## Step lemma — reduces the body to a `Result` equation and closes via `triple_of_ok_pv`. + +The step lemma's post is stated via a top-level `def` rather than an inline +`match`. Reason: an inline `match` in two different declarations (the step +lemma here and the `loop_range_spec_usize` step hypothesis at the call site) +generates *distinct* `match_N` auxiliary constants. Even though the matches +have identical bodies, the kernel sees the constants as different and rejects +the unification. A named `def` is referenced by the same canonical constant +from both sites. -/ + +/-- Per-iteration post for `unary_loop_body`. Identical shape to the + `loop_range_spec_usize` step hypothesis. -/ +def unary_step_post + (per_elem : Std.I16 → Result Std.I16) + (P : Std.I16 → Std.I16 → Prop) + (input : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (unary_loop_inv per_elem P input iter'.start acc').holds + | .done y => (unary_loop_inv per_elem P input 16#usize y).holds + +set_option maxHeartbeats 4000000 in +theorem elementwise_unary_step + (per_elem : Std.I16 → Result Std.I16) + (P : Std.I16 → Std.I16 → Prop) + (per_elem_spec : + ∀ (x : Std.I16), + ⦃ ⌜ True ⌝ ⦄ per_elem x ⦃ ⇓ r => ⌜ P x r ⌝ ⦄) + (input : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (acc : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (unary_loop_inv per_elem P input k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + unary_loop_body per_elem { start := k, «end» := 16#usize } acc + ⦃ ⇓ r => ⌜ unary_step_post per_elem P input k r ⌝ ⦄ := by + obtain ⟨h_acc_done, h_acc_undone⟩ := of_pure_prop_holds_pv h_inv + have h_acc_len : acc.elements.length = 16 := PortableVector_elements_length acc + have h_16 : (16#usize : Std.Usize).val = 16 := rfl + unfold unary_loop_body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- Some i = k branch. + have hk_16 : k.val < 16 := by rw [h_16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := iter_next_some_eq k h_lt + have h_idx : + Aeneas.Std.Array.index_usize acc.elements k = .ok (acc.elements.val[k.val]!) := + array_index_usize_ok_eq acc.elements k (by rw [h_acc_len]; exact hk_16) + obtain ⟨r, h_per_eq, h_per_P⟩ := + triple_exists_ok_pv (per_elem_spec (acc.elements.val[k.val]!)) + have h_upd : + Aeneas.Std.Array.update acc.elements k r + = .ok (acc.elements.set k r) := + array_update_ok_eq acc.elements k r (by rw [h_acc_len]; exact hk_16) + have h_body : + (do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize) + match o with + | core.option.Option.None => + (Result.ok (ControlFlow.done acc) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + | core.option.Option.Some i => + let i1 ← Aeneas.Std.Array.index_usize acc.elements i + let vi ← per_elem i1 + let a ← Aeneas.Std.Array.update acc.elements i vi + ok (cont (iter1, { elements := a }))) + = .ok (cont + (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + { elements := acc.elements.set k r })) := by + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [bind_tc_ok] + rw [h_idx] + simp only [bind_tc_ok] + rw [h_per_eq] + simp only [bind_tc_ok] + rw [h_upd] + rfl + apply triple_of_ok_pv h_body + show unary_step_post per_elem P input k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + { elements := acc.elements.set k r })) + unfold unary_step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (unary_loop_inv per_elem P input s + { elements := acc.elements.set k r }).holds + apply pure_prop_holds_pv + refine ⟨?_, ?_⟩ + · intro j hj + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · obtain ⟨r_j, h_per_j, h_acc_j, h_P_j⟩ := h_acc_done j hj_lt_k + refine ⟨r_j, h_per_j, ?_, h_P_j⟩ + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne : (acc.elements.set k r)[j]! = (acc.elements)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.elements k j r h_ne + have : (acc.elements.set k r).val[j]! = acc.elements.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show (acc.elements.set k r).val[j]! = r_j + rw [this]; exact h_acc_j + · subst hj_eq_k + refine ⟨r, ?_, ?_, ?_⟩ + · have h_eq : acc.elements.val[k.val]! = input.elements.val[k.val]! := + h_acc_undone k.val (Nat.le_refl _) hk_16 + rw [← h_eq]; exact h_per_eq + · have h_lt'' : k.val < acc.elements.length := by rw [h_acc_len]; exact hk_16 + have h_set_eq : (acc.elements.set k r)[k.val]! = r := + Aeneas.Std.Array.getElem!_Nat_set_eq acc.elements k k.val r ⟨rfl, h_lt''⟩ + have : (acc.elements.set k r).val[k.val]! = r := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_eq + show (acc.elements.set k r).val[k.val]! = r + exact this + · have h_eq : acc.elements.val[k.val]! = input.elements.val[k.val]! := + h_acc_undone k.val (Nat.le_refl _) hk_16 + rw [← h_eq]; exact h_per_P + · intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne : (acc.elements.set k r)[j]! = (acc.elements)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.elements k j r h_ne + have : (acc.elements.set k r).val[j]! = acc.elements.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show (acc.elements.set k r).val[j]! = input.elements.val[j]! + rw [this] + exact h_acc_undone j h_ge' hj_lt + · -- None branch. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h_16] at hk_ge; omega + have h_iter_none := iter_next_none_eq k hk_ge + have h_body : + (do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize) + match o with + | core.option.Option.None => + (Result.ok (ControlFlow.done acc) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + | core.option.Option.Some i => + let i1 ← Aeneas.Std.Array.index_usize acc.elements i + let vi ← per_elem i1 + let a ← Aeneas.Std.Array.update acc.elements i vi + ok (cont (iter1, { elements := a }))) + = .ok (done acc) := by + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_pv h_body + show unary_step_post per_elem P input k (.done acc) + unfold unary_step_post + show (unary_loop_inv per_elem P input 16#usize acc).holds + apply pure_prop_holds_pv + refine ⟨?_, ?_⟩ + · intro j hj + apply h_acc_done j + rw [hk_eq]; rw [h_16] at hj; exact hj + · intro j hj_ge hj_lt + apply h_acc_undone j _ hj_lt + rw [hk_eq]; rw [h_16] at hj_ge; exact hj_ge + +/-! ## Top-level unary elementwise spec wrapper -/ + +set_option maxHeartbeats 2000000 in +theorem elementwise_unary_spec + (per_elem : Std.I16 → Result Std.I16) + (P : Std.I16 → Std.I16 → Prop) + (per_elem_spec : + ∀ (x : Std.I16), + ⦃ ⌜ True ⌝ ⦄ per_elem x ⦃ ⇓ r => ⌜ P x r ⌝ ⦄) + (input : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + ⦃ ⌜ True ⌝ ⦄ + loop (fun p => unary_loop_body per_elem p.1 p.2) + (({ start := 0#usize, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), input) + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + ∃ ri, per_elem (input.elements.val[i]!) = .ok ri + ∧ r.elements.val[i]! = ri + ∧ P (input.elements.val[i]!) ri ⌝ ⦄ := by + apply Std.Do.Triple.of_entails_right _ + (loop_range_spec_usize + (fun (iter1, vec1) => unary_loop_body per_elem iter1 vec1) + input 0#usize 16#usize + (unary_loop_inv per_elem P input) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (pure_prop_holds_pv ⟨ + fun j hj => by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + rw [h0] at hj; exact absurd hj (Nat.not_lt_zero j), + fun _ _ _ => rfl⟩) + ?_) + · -- PostCond entailment. + rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_done, _h_undone⟩ := of_pure_prop_holds_pv h + intro j hj + apply h_done j + show j < (16#usize : Std.Usize).val + exact hj + · -- Step lemma. We bridge `loop_range_spec_usize`'s inline `match`-based + -- post with `elementwise_unary_step`'s `unary_step_post`-based post via + -- a direct Triple weakening: both are propositionally identical on + -- every result, so a value-level case-split on `r` discharges the + -- entailment. + intro acc k h_ge h_le hinv + have h_step := elementwise_unary_step per_elem P per_elem_spec input acc k h_le hinv + -- Convert via Triple post-equivalence (`Std.Do.Triple.of_entails_right`). + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + -- hh : ⌜ unary_step_post per_elem P input k r ⌝.down + -- Goal: (the lambda's match) r. + rcases r with ⟨iter', acc'⟩ | y + · -- cont branch. + have hP : unary_step_post per_elem P input k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [unary_step_post] using hP + · -- done branch. + have hP : unary_step_post per_elem P input k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [unary_step_post] using hP + +/-! ## Binary loop body / invariant / step / spec. + +Mirror of the unary family but with **two** input vectors. Only `lhs` is +the loop accumulator; `rhs` is captured in the body lambda. The per-element +op now has type `I16 → I16 → Result I16` and reads from both inputs at the +same index `i` before writing back to `acc.elements[i]`. + +The bind chain inside the body has one extra `index_usize` step for `rhs` +compared to `unary_loop_body`, but the structure is otherwise identical. -/ + +/-- Binary loop body: reads `acc.elements[i]` and `rhs.elements[i]`, + applies `per_elem`, writes back to `acc.elements[i]`. -/ +def binary_loop_body + (per_elem : Std.I16 → Std.I16 → Result Std.I16) + (rhs : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (iter : CoreModels.core.ops.range.Range Std.Usize) + (acc : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) := do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep iter + match o with + | core.option.Option.None => ok (done acc) + | core.option.Option.Some i => + let i1 ← Aeneas.Std.Array.index_usize acc.elements i + let i2 ← Aeneas.Std.Array.index_usize rhs.elements i + let vi ← per_elem i1 i2 + let a ← Aeneas.Std.Array.update acc.elements i vi + ok (cont (iter1, { elements := a })) + +/-- 2-conjunct binary invariant: + - For `j < k`, `acc.elements[j]` equals the per-elem-op output `r` + for inputs `input_lhs.elements[j]` and `input_rhs.elements[j]`. + - For `j ≥ k`, `acc.elements[j] = input_lhs.elements[j]` (rhs is + read-only, so its invariant is implicit). -/ +def binary_loop_inv + (per_elem : Std.I16 → Std.I16 → Result Std.I16) + (P : Std.I16 → Std.I16 → Std.I16 → Prop) + (input_lhs input_rhs : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector → + Result Prop := + fun k acc => pure ( + (∀ j : Nat, j < k.val → + ∃ r, per_elem (input_lhs.elements.val[j]!) (input_rhs.elements.val[j]!) = .ok r + ∧ acc.elements.val[j]! = r + ∧ P (input_lhs.elements.val[j]!) (input_rhs.elements.val[j]!) r) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.elements.val[j]! = input_lhs.elements.val[j]!)) + +/-- Per-iteration post for `binary_loop_body`. -/ +def binary_step_post + (per_elem : Std.I16 → Std.I16 → Result Std.I16) + (P : Std.I16 → Std.I16 → Std.I16 → Prop) + (input_lhs input_rhs : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (binary_loop_inv per_elem P input_lhs input_rhs iter'.start acc').holds + | .done y => (binary_loop_inv per_elem P input_lhs input_rhs 16#usize y).holds + +set_option maxHeartbeats 4000000 in +theorem elementwise_binary_step + (per_elem : Std.I16 → Std.I16 → Result Std.I16) + (P : Std.I16 → Std.I16 → Std.I16 → Prop) + (per_elem_spec : + ∀ (x y : Std.I16), + ⦃ ⌜ True ⌝ ⦄ per_elem x y ⦃ ⇓ r => ⌜ P x y r ⌝ ⦄) + (input_lhs input_rhs : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (acc : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (binary_loop_inv per_elem P input_lhs input_rhs k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + binary_loop_body per_elem input_rhs { start := k, «end» := 16#usize } acc + ⦃ ⇓ r => ⌜ binary_step_post per_elem P input_lhs input_rhs k r ⌝ ⦄ := by + obtain ⟨h_acc_done, h_acc_undone⟩ := of_pure_prop_holds_pv h_inv + have h_acc_len : acc.elements.length = 16 := PortableVector_elements_length acc + have h_rhs_len : input_rhs.elements.length = 16 := PortableVector_elements_length input_rhs + have h_16 : (16#usize : Std.Usize).val = 16 := rfl + unfold binary_loop_body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- Some i = k branch. + have hk_16 : k.val < 16 := by rw [h_16] at h_lt; exact h_lt + obtain ⟨s, hs_val, h_iter_some⟩ := iter_next_some_eq k h_lt + have h_idx_lhs : + Aeneas.Std.Array.index_usize acc.elements k = .ok (acc.elements.val[k.val]!) := + array_index_usize_ok_eq acc.elements k (by rw [h_acc_len]; exact hk_16) + have h_idx_rhs : + Aeneas.Std.Array.index_usize input_rhs.elements k + = .ok (input_rhs.elements.val[k.val]!) := + array_index_usize_ok_eq input_rhs.elements k (by rw [h_rhs_len]; exact hk_16) + obtain ⟨r, h_per_eq, h_per_P⟩ := + triple_exists_ok_pv (per_elem_spec (acc.elements.val[k.val]!) + (input_rhs.elements.val[k.val]!)) + have h_upd : + Aeneas.Std.Array.update acc.elements k r + = .ok (acc.elements.set k r) := + array_update_ok_eq acc.elements k r (by rw [h_acc_len]; exact hk_16) + have h_body : + (do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize) + match o with + | core.option.Option.None => + (Result.ok (ControlFlow.done acc) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + | core.option.Option.Some i => + let i1 ← Aeneas.Std.Array.index_usize acc.elements i + let i2 ← Aeneas.Std.Array.index_usize input_rhs.elements i + let vi ← per_elem i1 i2 + let a ← Aeneas.Std.Array.update acc.elements i vi + ok (cont (iter1, { elements := a }))) + = .ok (cont + (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + { elements := acc.elements.set k r })) := by + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [bind_tc_ok] + rw [h_idx_lhs] + simp only [bind_tc_ok] + rw [h_idx_rhs] + simp only [bind_tc_ok] + rw [h_per_eq] + simp only [bind_tc_ok] + rw [h_upd] + rfl + apply triple_of_ok_pv h_body + show binary_step_post per_elem P input_lhs input_rhs k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + { elements := acc.elements.set k r })) + unfold binary_step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (binary_loop_inv per_elem P input_lhs input_rhs s + { elements := acc.elements.set k r }).holds + apply pure_prop_holds_pv + refine ⟨?_, ?_⟩ + · intro j hj + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · obtain ⟨r_j, h_per_j, h_acc_j, h_P_j⟩ := h_acc_done j hj_lt_k + refine ⟨r_j, h_per_j, ?_, h_P_j⟩ + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne : (acc.elements.set k r)[j]! = (acc.elements)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.elements k j r h_ne + have : (acc.elements.set k r).val[j]! = acc.elements.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show (acc.elements.set k r).val[j]! = r_j + rw [this]; exact h_acc_j + · subst hj_eq_k + refine ⟨r, ?_, ?_, ?_⟩ + · have h_eq : acc.elements.val[k.val]! = input_lhs.elements.val[k.val]! := + h_acc_undone k.val (Nat.le_refl _) hk_16 + rw [← h_eq]; exact h_per_eq + · have h_lt'' : k.val < acc.elements.length := by rw [h_acc_len]; exact hk_16 + have h_set_eq : (acc.elements.set k r)[k.val]! = r := + Aeneas.Std.Array.getElem!_Nat_set_eq acc.elements k k.val r ⟨rfl, h_lt''⟩ + have : (acc.elements.set k r).val[k.val]! = r := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_eq + show (acc.elements.set k r).val[k.val]! = r + exact this + · have h_eq : acc.elements.val[k.val]! = input_lhs.elements.val[k.val]! := + h_acc_undone k.val (Nat.le_refl _) hk_16 + rw [← h_eq]; exact h_per_P + · intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne : (acc.elements.set k r)[j]! = (acc.elements)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.elements k j r h_ne + have : (acc.elements.set k r).val[j]! = acc.elements.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show (acc.elements.set k r).val[j]! = input_lhs.elements.val[j]! + rw [this] + exact h_acc_undone j h_ge' hj_lt + · -- None branch. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h_16] at hk_ge; omega + have h_iter_none := iter_next_none_eq k hk_ge + have h_body : + (do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize) + match o with + | core.option.Option.None => + (Result.ok (ControlFlow.done acc) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + | core.option.Option.Some i => + let i1 ← Aeneas.Std.Array.index_usize acc.elements i + let i2 ← Aeneas.Std.Array.index_usize input_rhs.elements i + let vi ← per_elem i1 i2 + let a ← Aeneas.Std.Array.update acc.elements i vi + ok (cont (iter1, { elements := a }))) + = .ok (done acc) := by + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_pv h_body + show binary_step_post per_elem P input_lhs input_rhs k (.done acc) + unfold binary_step_post + show (binary_loop_inv per_elem P input_lhs input_rhs 16#usize acc).holds + apply pure_prop_holds_pv + refine ⟨?_, ?_⟩ + · intro j hj + apply h_acc_done j + rw [hk_eq]; rw [h_16] at hj; exact hj + · intro j hj_ge hj_lt + apply h_acc_undone j _ hj_lt + rw [hk_eq]; rw [h_16] at hj_ge; exact hj_ge + +/-! ## Top-level binary elementwise spec wrapper -/ + +set_option maxHeartbeats 2000000 in +theorem elementwise_binary_spec + (per_elem : Std.I16 → Std.I16 → Result Std.I16) + (P : Std.I16 → Std.I16 → Std.I16 → Prop) + (per_elem_spec : + ∀ (x y : Std.I16), + ⦃ ⌜ True ⌝ ⦄ per_elem x y ⦃ ⇓ r => ⌜ P x y r ⌝ ⦄) + (input_lhs input_rhs : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + ⦃ ⌜ True ⌝ ⦄ + loop (fun p => binary_loop_body per_elem input_rhs p.1 p.2) + (({ start := 0#usize, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), input_lhs) + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + ∃ ri, per_elem (input_lhs.elements.val[i]!) (input_rhs.elements.val[i]!) = .ok ri + ∧ r.elements.val[i]! = ri + ∧ P (input_lhs.elements.val[i]!) (input_rhs.elements.val[i]!) ri ⌝ ⦄ := by + apply Std.Do.Triple.of_entails_right _ + (loop_range_spec_usize + (fun (iter1, vec1) => binary_loop_body per_elem input_rhs iter1 vec1) + input_lhs 0#usize 16#usize + (binary_loop_inv per_elem P input_lhs input_rhs) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (pure_prop_holds_pv ⟨ + fun j hj => by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + rw [h0] at hj; exact absurd hj (Nat.not_lt_zero j), + fun _ _ _ => rfl⟩) + ?_) + · rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_done, _h_undone⟩ := of_pure_prop_holds_pv h + intro j hj + apply h_done j + show j < (16#usize : Std.Usize).val + exact hj + · intro acc k h_ge h_le hinv + have h_step := + elementwise_binary_step per_elem P per_elem_spec input_lhs input_rhs acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : binary_step_post per_elem P input_lhs input_rhs k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [binary_step_post] using hP + · have hP : binary_step_post per_elem P input_lhs input_rhs k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [binary_step_post] using hP + +/-! ## I/O loop body / invariant / step / spec. + +Mirror of the unary family but with **separate input and output types**. +The input is a `Slice Std.I32` (read-only, captured by the body lambda), +and the loop accumulator is a `PortableVector` (Array I16 16). The +per-element op has type `Std.I32 → Result Std.I16` and reads from the +slice at index `i` before writing back to `acc.elements[i]`. + +The slice has no static length, so a precondition +`h_len : 16 ≤ input.val.length` is carried through to discharge the +`Slice.index_usize` bound check. -/ + +/-! ### Slice-index reduction to a `Result` equation. -/ + +/-- `Slice.index_usize` returns `.ok (v.val[i.val]!)` when `i.val < v.length`. -/ +theorem slice_index_usize_ok_eq + {α : Type u} [Inhabited α] + (v : Aeneas.Std.Slice α) (i : Std.Usize) (h_bd : i.val < v.val.length) : + Slice.index_usize v i = .ok (v.val[i.val]!) := by + have h_bd' : i.val < v.length := by + show i.val < v.val.length + exact h_bd + have hT := Slice.index_usize_spec v i h_bd' + have h_ex := Aeneas.Std.WP.spec_imp_exists hT + obtain ⟨v', hveq, hPv'⟩ := h_ex + rw [hveq, hPv', getElem!_pos] + +/-! ### I/O loop body (canonical shape from Funs.lean) -/ + +/-- I/O loop body: reads `input.val[i]!` (a `Slice Std.I32`), applies + `per_elem`, writes back to `acc.elements[i]` (a `PortableVector`). -/ +def io_loop_body + (per_elem : Std.I32 → Result Std.I16) + (input : Aeneas.Std.Slice Std.I32) + (iter : CoreModels.core.ops.range.Range Std.Usize) + (acc : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) := do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep iter + match o with + | core.option.Option.None => ok (done acc) + | core.option.Option.Some i => + let i1 ← Slice.index_usize input i + let i2 ← per_elem i1 + let a ← Aeneas.Std.Array.update acc.elements i i2 + ok (cont (iter1, { elements := a })) + +/-- 2-conjunct I/O loop invariant: + - For `j < k`, `acc.elements[j]` equals the per-elem-op output `r` + for input `input.val[j]!` (carrying the per-elem predicate `P`). + - For `j ≥ k`, no claim is made on `acc.elements[j]` (the original + `out` value is preserved unchanged). -/ +def io_loop_inv + (per_elem : Std.I32 → Result Std.I16) + (P : Std.I32 → Std.I16 → Prop) + (input : Aeneas.Std.Slice Std.I32) + (out : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : + Std.Usize → + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector → + Result Prop := + fun k acc => pure ( + (∀ j : Nat, j < k.val → + ∃ r, per_elem (input.val[j]!) = .ok r + ∧ acc.elements.val[j]! = r ∧ P (input.val[j]!) r) + ∧ (∀ j : Nat, k.val ≤ j → j < 16 → + acc.elements.val[j]! = out.elements.val[j]!)) + +/-- Per-iteration post for `io_loop_body`. -/ +def io_step_post + (per_elem : Std.I32 → Result Std.I16) + (P : Std.I32 → Std.I16 → Prop) + (input : Aeneas.Std.Slice Std.I32) + (out : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (r : ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) : Prop := + match r with + | .cont (iter', acc') => + k.val < (16#usize : Std.Usize).val ∧ iter'.«end» = 16#usize + ∧ iter'.start.val = k.val + 1 + ∧ (io_loop_inv per_elem P input out iter'.start acc').holds + | .done y => (io_loop_inv per_elem P input out 16#usize y).holds + +set_option maxHeartbeats 4000000 in +theorem elementwise_io_step + (per_elem : Std.I32 → Result Std.I16) + (P : Std.I32 → Std.I16 → Prop) + (per_elem_spec : + ∀ (x : Std.I32), + ⦃ ⌜ True ⌝ ⦄ per_elem x ⦃ ⇓ r => ⌜ P x r ⌝ ⦄) + (input : Aeneas.Std.Slice Std.I32) + (h_len : 16 ≤ input.val.length) + (out acc : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Std.Usize) + (h_le : k.val ≤ (16#usize : Std.Usize).val) + (h_inv : (io_loop_inv per_elem P input out k acc).holds) : + ⦃ ⌜ True ⌝ ⦄ + io_loop_body per_elem input { start := k, «end» := 16#usize } acc + ⦃ ⇓ r => ⌜ io_step_post per_elem P input out k r ⌝ ⦄ := by + obtain ⟨h_acc_done, h_acc_undone⟩ := of_pure_prop_holds_pv h_inv + have h_acc_len : acc.elements.length = 16 := PortableVector_elements_length acc + have h_16 : (16#usize : Std.Usize).val = 16 := rfl + unfold io_loop_body + by_cases h_lt : k.val < (16#usize : Std.Usize).val + · -- Some i = k branch. + have hk_16 : k.val < 16 := by rw [h_16] at h_lt; exact h_lt + have hk_input : k.val < input.val.length := by omega + obtain ⟨s, hs_val, h_iter_some⟩ := iter_next_some_eq k h_lt + have h_idx : + Slice.index_usize input k = .ok (input.val[k.val]!) := + slice_index_usize_ok_eq input k hk_input + obtain ⟨r, h_per_eq, h_per_P⟩ := + triple_exists_ok_pv (per_elem_spec (input.val[k.val]!)) + have h_upd : + Aeneas.Std.Array.update acc.elements k r + = .ok (acc.elements.set k r) := + array_update_ok_eq acc.elements k r (by rw [h_acc_len]; exact hk_16) + have h_body : + (do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize) + match o with + | core.option.Option.None => + (Result.ok (ControlFlow.done acc) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + | core.option.Option.Some i => + let i1 ← Slice.index_usize input i + let i2 ← per_elem i1 + let a ← Aeneas.Std.Array.update acc.elements i i2 + ok (cont (iter1, { elements := a }))) + = .ok (cont + (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + { elements := acc.elements.set k r })) := by + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_some] + simp only [bind_tc_ok] + rw [h_idx] + simp only [bind_tc_ok] + rw [h_per_eq] + simp only [bind_tc_ok] + rw [h_upd] + rfl + apply triple_of_ok_pv h_body + show io_step_post per_elem P input out k + (.cont (({ start := s, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), + { elements := acc.elements.set k r })) + unfold io_step_post + refine ⟨h_lt, rfl, hs_val, ?_⟩ + show (io_loop_inv per_elem P input out s + { elements := acc.elements.set k r }).holds + apply pure_prop_holds_pv + refine ⟨?_, ?_⟩ + · intro j hj + rw [hs_val] at hj + rcases Nat.lt_succ_iff_lt_or_eq.mp hj with hj_lt_k | hj_eq_k + · obtain ⟨r_j, h_per_j, h_acc_j, h_P_j⟩ := h_acc_done j hj_lt_k + refine ⟨r_j, h_per_j, ?_, h_P_j⟩ + have h_ne : k.val ≠ j := Nat.ne_of_gt hj_lt_k + have h_set_ne : (acc.elements.set k r)[j]! = (acc.elements)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.elements k j r h_ne + have : (acc.elements.set k r).val[j]! = acc.elements.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show (acc.elements.set k r).val[j]! = r_j + rw [this]; exact h_acc_j + · subst hj_eq_k + refine ⟨r, h_per_eq, ?_, h_per_P⟩ + have h_lt'' : k.val < acc.elements.length := by rw [h_acc_len]; exact hk_16 + have h_set_eq : (acc.elements.set k r)[k.val]! = r := + Aeneas.Std.Array.getElem!_Nat_set_eq acc.elements k k.val r ⟨rfl, h_lt''⟩ + have : (acc.elements.set k r).val[k.val]! = r := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_eq + show (acc.elements.set k r).val[k.val]! = r + exact this + · intro j hj_ge hj_lt + rw [hs_val] at hj_ge + have h_ne : k.val ≠ j := by omega + have h_ge' : k.val ≤ j := by omega + have h_set_ne : (acc.elements.set k r)[j]! = (acc.elements)[j]! := + Aeneas.Std.Array.getElem!_Nat_set_ne acc.elements k j r h_ne + have : (acc.elements.set k r).val[j]! = acc.elements.val[j]! := by + simpa [Aeneas.Std.Array.getElem!_Nat_eq] using h_set_ne + show (acc.elements.set k r).val[j]! = out.elements.val[j]! + rw [this] + exact h_acc_undone j h_ge' hj_lt + · -- None branch. + have hk_ge : k.val ≥ (16#usize : Std.Usize).val := Nat.not_lt.mp h_lt + have hk_eq : k.val = 16 := by rw [h_16] at hk_ge; omega + have h_iter_none := iter_next_none_eq k hk_ge + have h_body : + (do + let (o, iter1) ← + core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize) + match o with + | core.option.Option.None => + (Result.ok (ControlFlow.done acc) : + Result (ControlFlow + ((CoreModels.core.ops.range.Range Std.Usize) + × libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector)) + | core.option.Option.Some i => + let i1 ← Slice.index_usize input i + let i2 ← per_elem i1 + let a ← Aeneas.Std.Array.update acc.elements i i2 + ok (cont (iter1, { elements := a }))) + = .ok (done acc) := by + conv_lhs => + rw [show + (core.ops.range.Range.Insts.CoreIterTraitsIteratorIterator.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + = (CoreModels.core.iter.range.IteratorRange.next + core.Usize.Insts.CoreIterRangeStep + ({ start := k, «end» := 16#usize } : CoreModels.core.ops.range.Range Std.Usize)) + from rfl] + rw [h_iter_none]; rfl + apply triple_of_ok_pv h_body + show io_step_post per_elem P input out k (.done acc) + unfold io_step_post + show (io_loop_inv per_elem P input out 16#usize acc).holds + apply pure_prop_holds_pv + refine ⟨?_, ?_⟩ + · intro j hj + apply h_acc_done j + rw [hk_eq]; rw [h_16] at hj; exact hj + · intro j hj_ge hj_lt + apply h_acc_undone j _ hj_lt + rw [hk_eq]; rw [h_16] at hj_ge; exact hj_ge + +/-! ### Top-level I/O elementwise spec wrapper -/ + +set_option maxHeartbeats 2000000 in +theorem elementwise_io_spec + (per_elem : Std.I32 → Result Std.I16) + (P : Std.I32 → Std.I16 → Prop) + (per_elem_spec : + ∀ (x : Std.I32), + ⦃ ⌜ True ⌝ ⦄ per_elem x ⦃ ⇓ r => ⌜ P x r ⌝ ⦄) + (input : Aeneas.Std.Slice Std.I32) + (out : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (h_len : 16 ≤ input.val.length) : + ⦃ ⌜ True ⌝ ⦄ + loop (fun p => io_loop_body per_elem input p.1 p.2) + (({ start := 0#usize, «end» := 16#usize } + : CoreModels.core.ops.range.Range Std.Usize), out) + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → + ∃ ri, per_elem (input.val[i]!) = .ok ri + ∧ r.elements.val[i]! = ri + ∧ P (input.val[i]!) ri ⌝ ⦄ := by + apply Std.Do.Triple.of_entails_right _ + (loop_range_spec_usize + (fun (iter1, vec1) => io_loop_body per_elem input iter1 vec1) + out 0#usize 16#usize + (io_loop_inv per_elem P input out) + (by decide : (0#usize : Std.Usize).val ≤ (16#usize : Std.Usize).val) + (pure_prop_holds_pv ⟨ + fun j hj => by + have h0 : (0#usize : Std.Usize).val = 0 := rfl + rw [h0] at hj; exact absurd hj (Nat.not_lt_zero j), + fun _ _ _ => rfl⟩) + ?_) + · rw [PostCond.entails_noThrow] + intro r h + obtain ⟨h_done, _h_undone⟩ := of_pure_prop_holds_pv h + intro j hj + apply h_done j + show j < (16#usize : Std.Usize).val + exact hj + · intro acc k h_ge h_le hinv + have h_step := + elementwise_io_step per_elem P per_elem_spec input h_len out acc k h_le hinv + apply Std.Do.Triple.of_entails_right _ h_step + rw [PostCond.entails_noThrow] + intro r hh + rcases r with ⟨iter', acc'⟩ | y + · have hP : io_step_post per_elem P input out k (.cont (iter', acc')) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [io_step_post] using hP + · have hP : io_step_post per_elem P input out k (.done y) := by + simpa [Std.Do.SPred.down_pure] using hh + simpa [io_step_post] using hP + +end libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Arithmetic/PerElement.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Arithmetic/PerElement.lean new file mode 100644 index 00000000..b16b2aac --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Arithmetic/PerElement.lean @@ -0,0 +1,1661 @@ +/- + # `Equivalence/L0_FieldArith.lean` — Layer 0 field-arithmetic Triples + + `@[spec]` Triples for the leaf field-arithmetic primitives from + `vector/portable/arithmetic.rs`: + + - L0.1 `get_n_least_significant_bits_spec` + - L0.2 `barrett_reduce_element_spec` + - L0.3 `montgomery_reduce_element_spec` — signed Montgomery reduction. + - L0.4 `montgomery_multiply_fe_by_fer_spec` (trivial corollary of L0.3) +-/ +import LibcruxIotMlKem.Extraction.Funs +import LibcruxIotMlKem.Spec.Montgomery +import LibcruxIotMlKem.Vector.Portable.Arithmetic.BvMasks +import LibcruxIotMlKem.Spec.Lift + +set_option mvcgen.warning false +set_option linter.unusedVariables false +set_option linter.unusedSectionVars false + +namespace libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks + +/-! ## Local primitive helpers + + Two specs missing from upstream Aeneas Std at the pinned rev: a + BV-level spec for `IScalar >>> UScalar` and the post-unfold value + bridge for `IScalar.wrapping_mul`. Both are PR-ready upstream + candidates (SKILL §Tier 2); kept local pending bump. +-/ + +/-- The Triple `⦃True⦄ x ⦃⇓ r => ⌜P r⌝⦄` closer for `x = .ok v`. + Lifts a pure-Prop fact about the value into a Triple post. -/ +private theorem triple_of_ok_l0 {α : Type} {x : Result α} {v : α} + {P : α → Prop} (hx : x = .ok v) (hp : P v) : + ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄ := by + subst hx; simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +/-- Extract the `.ok` witness from a true-pre Triple — mirror of the + SKILL §13.5 helper, scoped to this file. Used by L0.4 to consume + L0.3's `@[spec]` without reaching into L0.3's privates. -/ +private theorem triple_exists_ok_l0 {α : Type} {x : Result α} {P : α → Prop} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) : + ∃ v, x = .ok v ∧ P v := by + match hx : x with + | .ok v => exact ⟨v, rfl, (by subst hx; simpa [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h)⟩ + | .fail _ => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +/-- BV-level spec for `IScalar.shiftRight_UScalar` — the + arithmetic-shift-right operation on signed integers. The `bv` + representation is `BitVec.sshiftRight`; on `Int` this is + floor-division by `2^s.val` (matches `Int.shiftRight`). -/ +theorem IScalar.shiftRight_UScalar_bv_eq + {ty : Aeneas.Std.IScalarTy} {tys : Aeneas.Std.UScalarTy} + (x : Aeneas.Std.IScalar ty) (s : Aeneas.Std.UScalar tys) + (hs : s.val < ty.numBits) : + Aeneas.Std.IScalar.shiftRight_UScalar x s = .ok ⟨x.bv.sshiftRight s.val⟩ := by + simp only [Aeneas.Std.IScalar.shiftRight_UScalar, Aeneas.Std.IScalar.shiftRight] + rw [if_pos hs] + +-- `modq_R_to_169` (old↔new Montgomery modq form bridge) moved to +-- `LibcruxIotMlKem.Spec.Montgomery`; referenced unqualified below +-- via the `open libcrux_iot_ml_kem.Util` declaration at the top of +-- this file. + +/-! ## L0.1 — `get_n_least_significant_bits_spec` + + Implements the upstream + `Vector.Portable.Arithmetic.get_n_least_significant_bits` correctness + `value & ((1 <<< n) - 1)`; the postcondition asserts the resulting + Nat is in `[0, 2^n)` and equals `value.val % 2^n.val`. +-/ + +/-- The `do`-block reduces to `Result.ok ⟨value.bv &&& ((1#32 <<< n.val) - 1#32)⟩` + under the precondition `n.val ≤ 16` (which implies `n.val < 32`). -/ +private theorem get_n_least_significant_bits_eq_ok + (n : Std.U8) (value : Std.U32) (hn : n.val ≤ 16) : + libcrux_iot_ml_kem.vector.portable.arithmetic.get_n_least_significant_bits n value + = .ok ⟨value.bv &&& ((1#32 <<< n.val) - 1#32)⟩ := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.get_n_least_significant_bits + -- n.val < 32 since n.val ≤ 16 < 32. + have hn_lt : n.val < Aeneas.Std.UScalarTy.U32.numBits := by + have h_red : (Aeneas.Std.UScalarTy.U32.numBits : Nat) = 32 := by decide + rw [h_red]; omega + -- Unfold the shift-left and the bind. + simp only [HShiftLeft.hShiftLeft, Aeneas.Std.UScalar.shiftLeft_UScalar, + Aeneas.Std.UScalar.shiftLeft, hn_lt, reduceIte, + CoreModels.core.num.U32.wrapping_sub, + rust_primitives.arithmetic.wrapping_sub_u32, + Aeneas.Std.bind_tc_ok] + rfl + +@[spec] +theorem get_n_least_significant_bits_spec + (n : Std.U8) (value : Std.U32) (hn : n.val ≤ 16) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.get_n_least_significant_bits n value + ⦃ ⇓ r => ⌜ r.val < 2 ^ n.val ∧ r.val = value.val % (2 ^ n.val) ⌝ ⦄ := by + apply triple_of_ok_l0 (v := ⟨value.bv &&& ((1#32 <<< n.val) - 1#32)⟩) + (get_n_least_significant_bits_eq_ok n value hn) + -- Two conjuncts: bound and modulo identity. Reduce both to Nat-level claims + -- about `(value.bv &&& (1 <<< n - 1)).toNat`. + have hn_lt : n.val < 32 := by omega + have h_pow_pos : 0 < (2 : Nat) ^ n.val := Nat.two_pow_pos _ + -- The mask `(1#32 <<< n.val) - 1#32 : BitVec 32` has `.toNat = 2^n.val - 1`. + -- Discharge by case analysis on n.val ∈ {0, …, 16} — each case is a concrete BV decide. + have h_mask_toNat : ((1#32 <<< n.val) - 1#32).toNat = 2 ^ n.val - 1 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks.mask_pow2_minus_one_toNat n.val hn + -- r.val = (value.bv &&& mask_bv).toNat = value.val &&& (2^n.val - 1) + have h_r_val : (⟨value.bv &&& (1#32 <<< n.val - 1#32)⟩ : Std.U32).val + = value.val &&& (2 ^ n.val - 1) := by + show (value.bv &&& (1#32 <<< n.val - 1#32)).toNat = _ + rw [BitVec.toNat_and, h_mask_toNat]; rfl + refine ⟨?_, ?_⟩ + · -- Bound: r.val < 2^n.val. + rw [h_r_val] + have h_and_le : value.val &&& (2 ^ n.val - 1) ≤ 2 ^ n.val - 1 := Nat.and_le_right + -- `2^n.val ≥ 1`, so `2^n.val - 1 < 2^n.val`, and the `&&&` is ≤ the mask. + scalar_tac + · -- Mod identity: r.val = value.val % 2^n.val. + rw [h_r_val] + exact Nat.and_two_pow_sub_one_eq_mod value.val n.val + +/-! ## L0.2 — `barrett_reduce_element_spec` + + Implements the upstream `Vector.Portable.Arithmetic.barrett_reduce_element` + a Barrett-style quotient `q = (value * 20159 + 2^25) >>> 26` (in i32), + then returns `value - q * 3329` (in i16). The post asserts the result + is congruent to `value` mod 3329 and bounded by 3328 in absolute value. +-/ + +/-- Closed-form `Int` evaluation of the Barrett quotient (before + casting to i16 and multiplying by 3329). + + `barrett_q v = (v * 20159 + 2^25) / 2^26`. + + Used as the pivot between the BV-level extraction and the pure-Int + arithmetic bound. -/ +private def barrett_q (v : Int) : Int := + (v * 20159 + (2^25 : Int)) / (2^26 : Int) + +/-- **Pure `Int`-level core of Barrett reduction.** + + Given `|value| ≤ 32767`, the residual `value - barrett_q value * 3329` + is congruent to `value` mod 3329 (trivially, since the difference is + a multiple of 3329) and has absolute value at most 3328. -/ +private theorem barrett_reduce_core + (v : Int) (h_v : v.natAbs ≤ 32767) : + let q := barrett_q v + let r := v - q * 3329 + modq_eq r v 3329 ∧ r.natAbs ≤ 3328 := by + -- |v| ≤ 32767 as an Int. + have h_v_abs : |v| ≤ (32767 : Int) := by + rw [Int.abs_eq_natAbs]; exact_mod_cast h_v + have h_v_lb : -(32767 : Int) ≤ v := (abs_le.mp h_v_abs).1 + have h_v_ub : v ≤ (32767 : Int) := (abs_le.mp h_v_abs).2 + -- Closed-form of barrett_q: (v * 20159 + 2^25) / 2^26. + set s : Int := v * 20159 + (2^25 : Int) with hs_def + set q : Int := s / (2^26 : Int) with hq_def + set r : Int := v - q * 3329 with hr_def + refine ⟨?_, ?_⟩ + · -- modq_eq r v 3329 = (r - v) % 3329 = 0 = (- (q * 3329)) % 3329 = 0. + show (r - v) % 3329 = 0 + have h_eq : r - v = -(q * 3329) := by show v - q * 3329 - v = _; ring + rw [h_eq] + rw [show -(q * 3329) = (-q) * 3329 by ring] + exact Int.mul_emod_left _ _ + · -- Bound: |r| ≤ 3328. Strategy: use the Barrett keystone 20159 * 3329 = 2^26 + 447 to + -- express r * 2^26 = (ρ - 2^25) * 3329 - v * 447 where ρ = s % 2^26 ∈ [0, 2^26). + -- Then bound both terms and conclude |r| ≤ 3328 (the actual bound is ≤ 1665). + have h_keystone : (20159 * 3329 : Int) = (2^26 + 447 : Int) := by decide + have h_rho_lb : (0 : Int) ≤ s % (2^26 : Int) := Int.emod_nonneg s (by decide) + have h_rho_ub : s % (2^26 : Int) < (2^26 : Int) := Int.emod_lt_of_pos s (by decide) + have h_s_decomp : s = q * (2^26 : Int) + s % (2^26 : Int) := by + have h := Int.emod_add_mul_ediv s (2^26 : Int) + -- h : s % 2^26 + 2^26 * (s / 2^26) = s + show s = s / (2^26 : Int) * (2^26 : Int) + s % (2^26 : Int) + have h_eq : (2^26 : Int) * (s / (2^26 : Int)) + = s / (2^26 : Int) * (2^26 : Int) := by ring + omega + set ρ : Int := s % (2^26 : Int) with hρ_def + -- Key identity: r * 2^26 = (ρ - 2^25) * 3329 - v * 447. + have h_r_mul : r * (2^26 : Int) = (ρ - 2^25) * 3329 - v * 447 := by + have h1 : v * 20159 + (2^25 : Int) = q * (2^26 : Int) + ρ := by + rw [← hs_def]; exact h_s_decomp + have h2 : v * 20159 = q * (2^26 : Int) + ρ - 2^25 := by + have : v * 20159 + (2^25 : Int) - 2^25 = q * (2^26 : Int) + ρ - 2^25 := by + rw [h1] + have h_simp : v * 20159 + (2^25 : Int) - 2^25 = v * 20159 := by ring + rw [h_simp] at this; exact this + -- Multiply h2 by 3329 and apply keystone. + have h3 : v * (20159 * 3329) = (q * (2^26 : Int) + ρ - 2^25) * 3329 := by + have h_lhs : v * 20159 * 3329 = v * (20159 * 3329) := by ring + calc v * (20159 * 3329) + = v * 20159 * 3329 := by ring + _ = (q * (2^26 : Int) + ρ - 2^25) * 3329 := by rw [h2] + rw [h_keystone] at h3 + -- h3 : v * (2^26 + 447) = (q * 2^26 + ρ - 2^25) * 3329 + -- Rearrange: r * 2^26 = (ρ - 2^25) * 3329 - v * 447. + have h4 : (v - q * 3329) * (2^26 : Int) + v * 447 + = (ρ - 2^25) * 3329 := by + have h_rhs : v * (2^26 + 447 : Int) = v * 2^26 + v * 447 := by ring + rw [h_rhs] at h3 + have h_expand : (q * (2^26 : Int) + ρ - 2^25) * 3329 + = q * 3329 * 2^26 + (ρ - 2^25) * 3329 := by ring + rw [h_expand] at h3 + -- h3 : v * 2^26 + v * 447 = q * 3329 * 2^26 + (ρ - 2^25) * 3329 + -- We want: (v - q*3329) * 2^26 + v * 447 = (ρ - 2^25) * 3329 + -- i.e., v * 2^26 - q * 3329 * 2^26 + v * 447 = (ρ - 2^25) * 3329, which follows. + have h_lhs : (v - q * 3329) * (2^26 : Int) + v * 447 + = v * 2^26 + v * 447 - q * 3329 * 2^26 := by ring + rw [h_lhs] + omega + -- Rearrange h4: r * 2^26 = (ρ - 2^25) * 3329 - v * 447. + have : r * (2^26 : Int) = (v - q * 3329) * (2^26 : Int) := by + show (v - q * 3329) * _ = _; rfl + rw [this] + omega + -- Bounds on r * 2^26 from h_r_mul: + -- (ρ - 2^25) * 3329 ∈ [-(2^25 * 3329), 2^25 * 3329 - 3329] (since ρ ∈ [0, 2^26) i.e. ρ-2^25 ∈ [-2^25, 2^25-1]) + -- v * 447 ∈ [-(32767 * 447), 32767 * 447] + have h_rho_diff_lb : (-(2^25) : Int) ≤ ρ - (2^25 : Int) := by + have : (0 : Int) ≤ ρ := h_rho_lb + omega + have h_rho_diff_ub : ρ - (2^25 : Int) ≤ ((2^25) - 1 : Int) := by + have : ρ < (2^26 : Int) := h_rho_ub + omega + have h_term1_lb : (-(2^25 * 3329) : Int) ≤ (ρ - 2^25) * 3329 := by + have h := mul_le_mul_of_nonneg_right h_rho_diff_lb (by decide : (0 : Int) ≤ 3329) + have h_rearr : -(2^25 : Int) * 3329 = -(2^25 * 3329) := by ring + rw [h_rearr] at h; exact h + have h_term1_ub : (ρ - 2^25) * 3329 ≤ ((2^25 - 1) * 3329 : Int) := by + exact mul_le_mul_of_nonneg_right h_rho_diff_ub (by decide : (0 : Int) ≤ 3329) + have h_term2_lb : (-(32767 * 447) : Int) ≤ v * 447 := by + have h := mul_le_mul_of_nonneg_right h_v_lb (by decide : (0 : Int) ≤ 447) + have h_rearr : -(32767 : Int) * 447 = -(32767 * 447) := by ring + rw [h_rearr] at h; exact h + have h_term2_ub : v * 447 ≤ ((32767 * 447) : Int) := + mul_le_mul_of_nonneg_right h_v_ub (by decide : (0 : Int) ≤ 447) + -- Derive numerical bounds on r * 2^26. + have h_r_mul_lb : (-(2^25 * 3329 + 32767 * 447 : Int)) ≤ r * (2^26 : Int) := by + rw [h_r_mul] + have h_t1 : -(2^25 * 3329 : Int) ≤ (ρ - 2^25) * 3329 := h_term1_lb + have h_t2 : v * 447 ≤ ((32767 * 447) : Int) := h_term2_ub + omega + have h_r_mul_ub : r * (2^26 : Int) ≤ (((2^25 - 1) * 3329 + 32767 * 447 : Int)) := by + rw [h_r_mul] + have h_t1 : (ρ - 2^25) * 3329 ≤ ((2^25 - 1) * 3329 : Int) := h_term1_ub + have h_t2 : -(32767 * 447 : Int) ≤ v * 447 := h_term2_lb + omega + -- Conclude |r| ≤ 3328 by contradiction (numerical chase). + have h_pow_pos : (0 : Int) < 2^26 := by decide + have h_r_lb : (-3328 : Int) ≤ r := by + by_contra h_neg + push Not at h_neg + have h_r_le : r ≤ -3329 := by omega + have h_mul_le : r * (2^26 : Int) ≤ (-3329) * (2^26 : Int) := by + have h_neg3329_le : (-3329 : Int) * (2^26 : Int) ≥ r * (2^26 : Int) := by + have := mul_le_mul_of_nonneg_right h_r_le (le_of_lt h_pow_pos) + exact this + omega + have h_const : ((-3329) * (2^26 : Int)) < -(2^25 * 3329 + 32767 * 447 : Int) := by + decide + omega + have h_r_ub : r ≤ (3328 : Int) := by + by_contra h_pos + push Not at h_pos + have h_r_ge : (3329 : Int) ≤ r := by omega + have h_mul_ge : ((3329) * (2^26 : Int)) ≤ r * (2^26 : Int) := + mul_le_mul_of_nonneg_right h_r_ge (le_of_lt h_pow_pos) + have h_const : (((2^25 - 1) * 3329 + 32767 * 447 : Int)) < (3329 * (2^26 : Int)) := by + decide + omega + have h_abs_le : |r| ≤ (3328 : Int) := abs_le.mpr ⟨h_r_lb, h_r_ub⟩ + have h_abs_natAbs : |r| = (r.natAbs : Int) := Int.abs_eq_natAbs r + rw [h_abs_natAbs] at h_abs_le + exact_mod_cast h_abs_le + +/-- Closed-form value computed by the Barrett-reduction impl, as an `IScalar.I16`. + + Stages the BV-level result of unfolding `barrett_reduce_element` so the + Triple proof can apply `triple_of_ok_l0` against it. Mirrors L0.3's + `mont_reduce_impl_value`. Exposed (non-private) for L1.3 totality use. -/ +def barrett_reduce_impl_value (value : Std.I16) : Std.I16 := + let i : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 value + let i1 : Std.I32 := Aeneas.Std.I32.wrapping_mul i (20159#i32) + let i3 : Std.I32 := ⟨(1#i32 : Std.I32).bv.shiftLeft 26 |>.sshiftRight 1⟩ + let t : Std.I32 := Aeneas.Std.I32.wrapping_add i1 i3 + let i5 : Std.I32 := ⟨t.bv.sshiftRight 26⟩ + let quotient : Std.I16 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 i5 + let i6 : Std.I16 := Aeneas.Std.I16.wrapping_mul quotient (3329#i16) + Aeneas.Std.I16.wrapping_sub value i6 + +/-- The `do`-block reduces to `Result.ok (barrett_reduce_impl_value value)`. + + Exposed (non-private) so that L1.3 `barrett_reduce_spec` can establish + totality of `barrett_reduce_element` independent of the per-element + bound precondition. -/ +theorem barrett_reduce_element_eq_ok (value : Std.I16) : + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce_element value + = .ok (barrett_reduce_impl_value value) := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce_element + unfold barrett_reduce_impl_value + -- Unfold the Barrett constants. + have h_mult : libcrux_iot_ml_kem.vector.portable.arithmetic.BARRETT_MULTIPLIER = 20159#i32 := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.BARRETT_MULTIPLIER; rfl + have h_q : libcrux_iot_ml_kem.vector.traits.FIELD_MODULUS = 3329#i16 := by + unfold libcrux_iot_ml_kem.vector.traits.FIELD_MODULUS; rfl + have h_shift : libcrux_iot_ml_kem.vector.traits.BARRETT_SHIFT = 26#i32 := by + unfold libcrux_iot_ml_kem.vector.traits.BARRETT_SHIFT; rfl + have h_R : libcrux_iot_ml_kem.vector.traits.BARRETT_R + = .ok (⟨(1#i32 : Std.I32).bv.shiftLeft 26⟩ : Std.I32) := by + unfold libcrux_iot_ml_kem.vector.traits.BARRETT_R + rw [h_shift] + show (1#i32 : Std.I32) <<< (26#i32 : Std.I32) = _ + show Aeneas.Std.IScalar.shiftLeft_IScalar (1#i32) (26#i32) = _ + unfold Aeneas.Std.IScalar.shiftLeft_IScalar + rw [if_pos (by decide : (26#i32 : Std.I32).val ≥ 0)] + unfold Aeneas.Std.IScalar.shiftLeft + have h_lt : (26#i32 : Std.I32).toNat < Aeneas.Std.IScalarTy.I32.numBits := by decide + rw [if_pos h_lt] + rfl + -- (i2 >>> 1#i32) is `IScalar.shiftRight_IScalar`. + have h_one_pos : (1#i32 : Std.I32).val ≥ 0 := by decide + have h_one_lt : (1#i32 : Std.I32).toNat < Aeneas.Std.IScalarTy.I32.numBits := by decide + -- (t >>> i4) where i4 = 26#u32 is `IScalar.shiftRight_UScalar`. + have h_i4_val : (Aeneas.Std.IScalar.hcast Aeneas.Std.UScalarTy.U32 (26#i32 : Std.I32)).val = 26 := by + decide + have h_i4_lt : (Aeneas.Std.IScalar.hcast Aeneas.Std.UScalarTy.U32 (26#i32 : Std.I32)).val + < Aeneas.Std.IScalarTy.I32.numBits := by + rw [h_i4_val]; decide + simp only [libcrux_secrets.traits.Classify.Blanket.classify, + libcrux_secrets.traits.Declassify.Blanket.declassify, + libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32, + libcrux_secrets.I32.Insts.Libcrux_secretsIntCastOps.as_i16, + Aeneas.Std.bind_tc_ok, Aeneas.Std.lift, + CoreModels.core.num.I32.wrapping_mul, + CoreModels.core.num.I32.wrapping_add, + CoreModels.core.num.I16.wrapping_mul, + CoreModels.core.num.I16.wrapping_sub, + rust_primitives.arithmetic.wrapping_mul_i32, + rust_primitives.arithmetic.wrapping_add_i32, + rust_primitives.arithmetic.wrapping_mul_i16, + rust_primitives.arithmetic.wrapping_sub_i16, + h_mult, h_q, h_shift, h_R] + -- Reduce the >>> by 1#i32 and by i4=26#u32. + simp only [HShiftRight.hShiftRight, + Aeneas.Std.IScalar.shiftRight_IScalar, + Aeneas.Std.IScalar.shiftRight_UScalar, + Aeneas.Std.IScalar.shiftRight, + h_one_pos, h_one_lt, h_i4_val, reduceIte] + rfl + +/-- Bridge: `(barrett_reduce_impl_value value).val = value.val - barrett_q value.val * 3329` + (as `Int`), under `|value.val| ≤ 32767`. -/ +private theorem barrett_reduce_impl_value_val + (value : Std.I16) (hb : value.val.natAbs ≤ 32767) : + (barrett_reduce_impl_value value).val + = value.val - barrett_q value.val * 3329 := by + unfold barrett_reduce_impl_value barrett_q + -- Set up the input value and key bounds. + set v : Int := value.val with hv_def + have h_v_abs : |v| ≤ (32767 : Int) := by + rw [Int.abs_eq_natAbs]; exact_mod_cast hb + have h_v_lb : -(32767 : Int) ≤ v := (abs_le.mp h_v_abs).1 + have h_v_ub : v ≤ (32767 : Int) := (abs_le.mp h_v_abs).2 + -- (cast .I32 value).val = v (since |v| ≤ 32767 < 2^31). + have h_i_val : (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 value).val = v := by + apply Aeneas.Std.IScalar.val_mod_pow_inBounds + · -- -2^31 ≤ v + have h_red : (Aeneas.Std.IScalarTy.I32.numBits - 1) = 31 := by decide + rw [h_red] + have h_const : -(2 : Int)^31 ≤ -(32767 : Int) := by decide + have : -(32767 : Int) ≤ v := h_v_lb + omega + · -- v < 2^31 + have h_red : (Aeneas.Std.IScalarTy.I32.numBits - 1) = 31 := by decide + rw [h_red] + have h_const : (32767 : Int) < (2 : Int)^31 := by decide + have : v ≤ (32767 : Int) := h_v_ub + omega + -- (20159#i32 : I32).val = 20159. + have h_20159 : (20159#i32 : Std.I32).val = 20159 := by decide + -- i1 = wrapping_mul i 20159. i1.val = bmod (v * 20159) (2^32) = v * 20159 (since |v * 20159| < 2^31). + set i : Std.I32 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 value + set i1 : Std.I32 := Aeneas.Std.I32.wrapping_mul i (20159#i32) + have h_v20_abs : |v * 20159| ≤ (32767 * 20159 : Int) := by + rw [abs_mul, show |(20159 : Int)| = 20159 from by decide] + have h_v_abs' : |v| ≤ (32767 : Int) := h_v_abs + have h_nn : (0 : Int) ≤ 20159 := by decide + exact mul_le_mul_of_nonneg_right h_v_abs' h_nn + have h_v20_lb : -(2 : Int)^31 ≤ v * 20159 := by + have h_const : -(2 : Int)^31 ≤ -(32767 * 20159 : Int) := by decide + have h_le : -(32767 * 20159 : Int) ≤ v * 20159 := (abs_le.mp h_v20_abs).1 + omega + have h_v20_ub : v * 20159 < (2 : Int)^31 := by + have h_const : (32767 * 20159 : Int) < (2 : Int)^31 := by decide + have h_le : v * 20159 ≤ (32767 * 20159 : Int) := (abs_le.mp h_v20_abs).2 + omega + have h_i1_val : i1.val = v * 20159 := by + show (Aeneas.Std.I32.wrapping_mul _ _).val = _ + rw [Aeneas.Std.I32.wrapping_mul_val_eq, h_i_val, h_20159] + apply Arith.Int.bmod_pow2_eq_of_inBounds' 32 _ (by decide) + · -- -2^31 ≤ v * 20159 + have h_red : ((2 : Int)^(32-1)) = (2 : Int)^31 := by decide + rw [h_red]; exact h_v20_lb + · -- v * 20159 < 2^31 + have h_red : ((2 : Int)^(32-1)) = (2 : Int)^31 := by decide + rw [h_red]; exact h_v20_ub + -- i3 := ((1 <<< 26) sshiftRight 1).toInt = 2^25. + have h_i3_val : ((⟨(1#i32 : Std.I32).bv.shiftLeft 26 |>.sshiftRight 1⟩ : Std.I32).val) + = (2^25 : Int) := by decide + -- t = wrapping_add i1 i3. |i1.val + i3.val| = |v * 20159 + 2^25| < 2^31. + set i3 : Std.I32 := ⟨(1#i32 : Std.I32).bv.shiftLeft 26 |>.sshiftRight 1⟩ + set t : Std.I32 := Aeneas.Std.I32.wrapping_add i1 i3 + have h_sum_lb : -(2 : Int)^31 ≤ v * 20159 + 2^25 := by + have h_const : -(2 : Int)^31 ≤ -(32767 * 20159 : Int) + 2^25 := by decide + have h_le : -(32767 * 20159 : Int) ≤ v * 20159 := (abs_le.mp h_v20_abs).1 + omega + have h_sum_ub : v * 20159 + 2^25 < (2 : Int)^31 := by + have h_const : (32767 * 20159 : Int) + 2^25 < (2 : Int)^31 := by decide + have h_le : v * 20159 ≤ (32767 * 20159 : Int) := (abs_le.mp h_v20_abs).2 + omega + have h_t_val : t.val = v * 20159 + 2^25 := by + show (Aeneas.Std.I32.wrapping_add _ _).val = _ + rw [Aeneas.Std.I32.wrapping_add_val_eq, h_i1_val, h_i3_val] + apply Arith.Int.bmod_pow2_eq_of_inBounds' 32 _ (by decide) + · have h_red : ((2 : Int)^(32-1)) = (2 : Int)^31 := by decide + rw [h_red]; exact h_sum_lb + · have h_red : ((2 : Int)^(32-1)) = (2 : Int)^31 := by decide + rw [h_red]; exact h_sum_ub + -- i5 = ⟨t.bv.sshiftRight 26⟩. i5.val = t.val / 2^26. + set i5 : Std.I32 := ⟨t.bv.sshiftRight 26⟩ + have h_i5_val : i5.val = t.val / (2^26 : Int) := by + show (t.bv.sshiftRight 26).toInt = _ + rw [BitVec.toInt_sshiftRight, Int.shiftRight_eq_div_pow] + have h_pow_nat : ((2^26 : Nat) : Int) = ((2 : Int)^26) := by push_cast + rw [h_pow_nat] + show t.bv.toInt / _ = t.val / _ + rfl + -- Bounds on i5.val: i5.val = t.val / 2^26 ∈ [-10, 10] (since t.val ∈ [-32767*20159+2^25, 32767*20159+2^25]). + have h_i5_bounds : -(2^15 : Int) ≤ i5.val ∧ i5.val < (2^15 : Int) := by + rw [h_i5_val, h_t_val] + -- -32767*20159 + 2^25 ≤ t.val ≤ 32767*20159 + 2^25 + have h_t_lb : -(32767 * 20159 : Int) + 2^25 ≤ v * 20159 + 2^25 := by + have h_le : -(32767 * 20159 : Int) ≤ v * 20159 := (abs_le.mp h_v20_abs).1 + omega + have h_t_ub : v * 20159 + 2^25 ≤ (32767 * 20159 : Int) + 2^25 := by + have h_le : v * 20159 ≤ (32767 * 20159 : Int) := (abs_le.mp h_v20_abs).2 + omega + refine ⟨?_, ?_⟩ + · have h := Int.ediv_le_ediv (a := -(32767 * 20159 : Int) + 2^25) + (b := v * 20159 + 2^25) (c := (2^26 : Int)) (by decide) h_t_lb + have h_const : (-(32767 * 20159 : Int) + 2^25) / (2^26 : Int) = -10 := by decide + rw [h_const] at h + have h_2_15 : -(2 : Int)^15 ≤ -10 := by decide + omega + · have h := Int.ediv_le_ediv (a := v * 20159 + 2^25) + (b := (32767 * 20159 : Int) + 2^25) (c := (2^26 : Int)) (by decide) h_t_ub + have h_const : ((32767 * 20159 : Int) + 2^25) / (2^26 : Int) = 10 := by decide + rw [h_const] at h + have h_2_15 : (10 : Int) < (2 : Int)^15 := by decide + omega + -- quotient = cast .I16 i5. quotient.val = i5.val. + set quotient : Std.I16 := Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 i5 + have h_quotient_val : quotient.val = i5.val := by + apply Aeneas.Std.IScalar.val_mod_pow_inBounds + · have h_red : (Aeneas.Std.IScalarTy.I16.numBits - 1) = 15 := by decide + rw [h_red]; exact h_i5_bounds.1 + · have h_red : (Aeneas.Std.IScalarTy.I16.numBits - 1) = 15 := by decide + rw [h_red]; exact h_i5_bounds.2 + -- 3329#i16 .val = 3329 + have h_3329 : (3329#i16 : Std.I16).val = 3329 := by decide + -- i6 = wrapping_mul quotient 3329. i6.val = Int.bmod (i5.val * 3329) (2^16). + -- NOTE: with K=32767, |i5.val| ≤ 10 and 10*3329 = 33290 > 2^15, so the multiplication + -- may wrap. We use the bmod form directly rather than claiming i6.val = i5.val * 3329. + set i6 : Std.I16 := Aeneas.Std.I16.wrapping_mul quotient (3329#i16) + have h_i6_val_bmod : i6.val = Int.bmod (i5.val * 3329) (2^16) := by + show (Aeneas.Std.I16.wrapping_mul _ _).val = _ + rw [Aeneas.Std.I16.wrapping_mul_val_eq, h_quotient_val, h_3329] + -- The barrett_q closed form match: i5.val = (v * 20159 + 2^25) / 2^26 = barrett_q v. + have h_i5_eq_q : i5.val = (v * 20159 + (2^25 : Int)) / (2^26 : Int) := by + rw [h_i5_val, h_t_val] + -- Final: result = wrapping_sub value i6. + -- result.val = Int.bmod (v - i6.val) (2^16) + -- = Int.bmod (v - Int.bmod (i5.val * 3329) (2^16)) (2^16) + -- = Int.bmod (v - i5.val * 3329) (2^16) [bmod congruence] + -- = v - q * 3329 [since |v - q*3329| ≤ 3328 < 2^15]. + have h_core := barrett_reduce_core v hb + set q : Int := barrett_q v with hq_def + have h_q_eq_i5 : q = i5.val := by + unfold barrett_q at hq_def + rw [hq_def, h_i5_eq_q] + have h_core_bound : (v - q * 3329).natAbs ≤ 3328 := h_core.2 + have h_core_bound_int : |v - q * 3329| ≤ (3328 : Int) := by + have h_abs : |v - q * 3329| = ((v - q * 3329).natAbs : Int) := Int.abs_eq_natAbs _ + rw [h_abs]; exact_mod_cast h_core_bound + show (Aeneas.Std.I16.wrapping_sub value i6).val = _ + rw [Aeneas.Std.I16.wrapping_sub_val_eq, h_i6_val_bmod] + -- Goal: Int.bmod (v - Int.bmod (i5.val * 3329) (2^16)) (2^16) = v - q * 3329. + -- Step 1: eliminate inner bmod using congruence (a - bmod(b)(n) ≡ a - b mod n). + have h_bmod_elim : Int.bmod (v - Int.bmod (i5.val * 3329) (2^16)) (2^16) + = Int.bmod (v - i5.val * 3329) (2^16) := + Int.sub_bmod_bmod + rw [h_bmod_elim] + -- Rewrite RHS divisor to ↑i5 then ↑i5 to q so both sides involve `q * 3329`. + rw [show (v * 20159 + (2^25 : Int)) / (2^26 : Int) = i5.val from h_i5_eq_q.symm] + rw [← h_q_eq_i5] + -- Step 2: Int.bmod (v - q * 3329) (2^16) = v - q * 3329 since |v - q*3329| ≤ 3328 < 2^15. + apply Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · -- -2^15 ≤ v - q * 3329 from |v - q*3329| ≤ 3328. + have h_red : ((2 : Int)^(16-1)) = (2 : Int)^15 := by decide + rw [h_red] + have h_lb : -(3328 : Int) ≤ v - q * 3329 := (abs_le.mp h_core_bound_int).1 + have h_const : -(2 : Int)^15 ≤ -(3328 : Int) := by decide + omega + · -- v - q * 3329 < 2^15 from |v - q*3329| ≤ 3328. + have h_red : ((2 : Int)^(16-1)) = (2 : Int)^15 := by decide + rw [h_red] + have h_ub : v - q * 3329 ≤ (3328 : Int) := (abs_le.mp h_core_bound_int).2 + have h_const : (3328 : Int) < (2 : Int)^15 := by decide + omega + +/-! ### L0.2 Triple. -/ + +@[spec] +theorem barrett_reduce_element_spec + (value : Std.I16) (hb : value.val.natAbs ≤ 32767) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce_element value + ⦃ ⇓ r => ⌜ modq_eq r.val value.val 3329 + ∧ r.val.natAbs ≤ 3328 ⌝ ⦄ := by + apply triple_of_ok_l0 (v := barrett_reduce_impl_value value) + (barrett_reduce_element_eq_ok value) + -- Two conjuncts: congruence and bound. + rw [barrett_reduce_impl_value_val value hb] + -- Goal: modq_eq (value.val - barrett_q value.val * 3329) value.val 3329 + -- ∧ (value.val - barrett_q value.val * 3329).natAbs ≤ 3328. + exact barrett_reduce_core value.val hb + +/-! ## L0.3 — `montgomery_reduce_element_spec` + + Implements the upstream `Vector.Portable.Arithmetic.montgomery_reduce_element` +-/ + +/-! ### Auxiliary `Int`-level Montgomery reduction (the L0.3 mathematical core) + + The Triple proof below threads the impl through `IScalar.cast` / + `wrapping_mul` / `>>>` and discharges the resulting `Int`-equation + via this single helper. Keeps the Triple body short. -/ + +/-- The closed integer formula that the impl computes for the + Montgomery-reduced value, expressed in terms of the input `v` + and the truncated multiplier `v16 := Int.bmod v (2^16)`. + + Used internally by the L0.3 Triple proof. The bound and the + congruence are the two halves of the L0.3 postcondition. -/ +private theorem mont_reduce_core + (v : Int) (h_v : v.natAbs ≤ 2^16 * 3328) : + let v16 := Int.bmod v (2^16) + let k16 := Int.bmod (v16 * 62209) (2^16) + let km := k16 * 3329 + let res := v / (2^16 : Int) - km / (2^16 : Int) + modq_eq (res * (2^16 : Int)) v 3329 ∧ res.natAbs ≤ 3328 + 1665 := by + -- Standard bmod bounds for power-of-two: + -- |bmod x (2^16)| ≤ 2^15, more precisely `-2^15 ≤ x ≤ 2^15 - 1`. + have h_v16_lb : -(2^15 : Int) ≤ Int.bmod v (2^16) := by + have := (Arith.Int.bmod_pow2_bounds 16 v).1; simpa using this + have h_v16_ub : Int.bmod v (2^16) < (2^15 : Int) := by + have := (Arith.Int.bmod_pow2_bounds 16 v).2; simpa using this + have h_k16_lb : -(2^15 : Int) ≤ Int.bmod (Int.bmod v (2^16) * 62209) (2^16) := by + have := (Arith.Int.bmod_pow2_bounds 16 (Int.bmod v (2^16) * 62209)).1 + simpa using this + have h_k16_ub : Int.bmod (Int.bmod v (2^16) * 62209) (2^16) < (2^15 : Int) := by + have := (Arith.Int.bmod_pow2_bounds 16 (Int.bmod v (2^16) * 62209)).2 + simpa using this + -- |v| ≤ 3328 * 2^16 + have h_v_abs : -((2^16 : Int) * 3328) ≤ v ∧ v ≤ (2^16 : Int) * 3328 := by + have h_nat : (v.natAbs : Int) ≤ ((2^16 * 3328 : Nat) : Int) := by exact_mod_cast h_v + -- |v| = v.natAbs (as Int) + have h_abs : |v| = (v.natAbs : Int) := Int.abs_eq_natAbs v + have h_v_lt_abs : -|v| ≤ v ∧ v ≤ |v| := ⟨neg_abs_le v, le_abs_self v⟩ + refine ⟨?_, ?_⟩ + · have h1 : -(v.natAbs : Int) ≤ v := by rw [← h_abs]; exact h_v_lt_abs.1 + have h2 : ((2^16 * 3328 : Nat) : Int) = (2^16 : Int) * 3328 := by norm_cast + scalar_tac + · have h1 : v ≤ (v.natAbs : Int) := by rw [← h_abs]; exact h_v_lt_abs.2 + have h2 : ((2^16 * 3328 : Nat) : Int) = (2^16 : Int) * 3328 := by norm_cast + scalar_tac + set v16 := Int.bmod v (2^16) + set k16 := Int.bmod (v16 * 62209) (2^16) + set km := k16 * 3329 + -- Now derive (v - km) % 2^16 = 0: + -- km = k16 * 3329; k16 ≡ v16 * 62209 (mod 2^16); + -- so km ≡ v16 * 62209 * 3329 ≡ v16 (mod 2^16) (via 62209*3329 ≡ 1 mod 2^16). + have h_km_mod : (v - km) % (2^16 : Int) = 0 := by + -- km % R = k16 * 3329 % R = (v16 * 62209) * 3329 % R = v16 % R (via keystone). + have h_keystone_int : (62209 * 3329 : Int) % (2^16) = 1 := by decide + -- bmod_emod : Int.bmod x m % m = x % m + have h_k16_emod : k16 % (2^16 : Int) = (v16 * 62209) % (2^16 : Int) := by + change (Int.bmod (v16 * 62209) (2^16)) % (2^16 : Int) = (v16 * 62209) % (2^16 : Int) + exact_mod_cast Int.bmod_emod + have h_step1 : km % (2^16 : Int) = (v16 * (62209 * 3329)) % (2^16 : Int) := by + change (k16 * 3329) % _ = _ + rw [Int.mul_emod, h_k16_emod, ← Int.mul_emod] + congr 1; ring + have h_step2 : km % (2^16 : Int) = v16 % (2^16 : Int) := by + rw [h_step1, Int.mul_emod, h_keystone_int, mul_one, Int.emod_emod_of_dvd _ (dvd_refl _)] + -- v % R = v16 % R via bmod_emod. + have h_v_emod : v % (2^16 : Int) = v16 % (2^16 : Int) := by + change v % (2^16 : Int) = (Int.bmod v (2^16)) % (2^16 : Int) + exact_mod_cast Int.bmod_emod.symm + rw [Int.sub_emod, h_v_emod, ← h_step2]; simp + -- Apply libcrux_iot_ml_kem.Spec.Montgomery.sub_div_of_emod_eq_zero + have h_div_split : v / (2^16 : Int) - km / (2^16 : Int) = (v - km) / (2^16 : Int) := by + exact libcrux_iot_ml_kem.Spec.Montgomery.sub_div_of_emod_eq_zero v km (2^16) (by decide) h_km_mod + refine ⟨?_, ?_⟩ + · -- modq_eq ((v/R - km/R) * R) v 3329, i.e. ((v/R - km/R) * R - v) % 3329 = 0. + show ((v / (2^16 : Int) - km / (2^16 : Int)) * (2^16 : Int) - v) % 3329 = 0 + rw [h_div_split] + -- Since R ∣ (v - km), (v - km)/R * R = v - km. + have h_R_dvd : (2^16 : Int) ∣ (v - km) := Int.dvd_of_emod_eq_zero h_km_mod + obtain ⟨q, hq⟩ := h_R_dvd + have h_vm_div : (v - km) / (2^16 : Int) = q := by + rw [hq]; exact Int.mul_ediv_cancel_left q (by decide) + rw [h_vm_div] + -- v - km = 2^16 * q, so q * 2^16 - v = -km = -(k16 * 3329). + have h_q_eq : q * (2^16 : Int) - v = -km := by + have h1 : v - km = (2 : Int) ^ 16 * q := hq + have h2 : q * (2^16 : Int) - v = -(v - km - (2^16 : Int) * q) + (-km) := by ring + rw [h2, h1]; ring + rw [h_q_eq] + show -(k16 * 3329) % 3329 = 0 + rw [show -(k16 * 3329) = (-k16) * 3329 by ring] + exact Int.mul_emod_left _ _ + · -- res.natAbs ≤ 3328 + 1665. + have h_v_div_bounds : -3328 ≤ v / (2^16 : Int) ∧ v / (2^16 : Int) ≤ 3328 := by + obtain ⟨hl, hu⟩ := h_v_abs + refine ⟨?_, ?_⟩ + · have h := Int.ediv_le_ediv (a := -((2^16 : Int) * 3328)) (b := v) + (c := (2^16 : Int)) (by decide) hl + have h_const : (-(2^16 * 3328) : Int) / (2^16 : Int) = -3328 := by decide + scalar_tac + · have h := Int.ediv_le_ediv (a := v) (b := (2^16 : Int) * 3328) + (c := (2^16 : Int)) (by decide) hu + have h_const : ((2^16 * 3328 : Int)) / (2^16 : Int) = 3328 := by decide + scalar_tac + have h_km_bounds : -(2^15 * 3329 : Int) ≤ km ∧ km ≤ ((2^15 - 1) * 3329 : Int) := by + refine ⟨?_, ?_⟩ + · -- -(2^15) ≤ k16, so -(2^15) * 3329 ≤ k16 * 3329 = km + have h := mul_le_mul_of_nonneg_right h_k16_lb (by decide : (0 : Int) ≤ 3329) + have h_eq : (-(2^15 : Int)) * 3329 = -(2^15 * 3329) := by ring + rw [h_eq] at h; exact h + · -- k16 ≤ 2^15 - 1, so km = k16 * 3329 ≤ (2^15 - 1) * 3329 + have h_k16_le : k16 ≤ 2^15 - 1 := by omega + have h := mul_le_mul_of_nonneg_right h_k16_le (by decide : (0 : Int) ≤ 3329) + exact h + have h_km_div_bounds : -1665 ≤ km / (2^16 : Int) ∧ km / (2^16 : Int) ≤ 1664 := by + obtain ⟨hl, hu⟩ := h_km_bounds + refine ⟨?_, ?_⟩ + · have h := Int.ediv_le_ediv (a := -(2^15 * 3329 : Int)) (b := km) + (c := (2^16 : Int)) (by decide) hl + have h_const : -(2^15 * 3329 : Int) / (2^16 : Int) = -1665 := by decide + scalar_tac + · have h := Int.ediv_le_ediv (a := km) (b := ((2^15 - 1) * 3329 : Int)) + (c := (2^16 : Int)) (by decide) hu + have h_const : (((2^15 - 1) * 3329 : Int)) / (2^16 : Int) = 1664 := by decide + scalar_tac + obtain ⟨h_vl, h_vu⟩ := h_v_div_bounds + obtain ⟨h_kml, h_kmu⟩ := h_km_div_bounds + have h_res_l : -(3328 + 1665 : Int) ≤ v / (2^16 : Int) - km / (2^16 : Int) := by + have := add_le_add h_vl (neg_le_neg h_kmu) + have h_simp : (-3328 : Int) + -1664 = -(3328 + 1665) + 1 := by ring + have h_simp2 : v / (2^16 : Int) + -(km / (2^16 : Int)) = v / (2^16 : Int) - km / (2^16 : Int) := + by ring + have h_chain : -(3328 + 1665 : Int) ≤ -3328 + -1664 := by decide + have := le_trans h_chain this + rw [h_simp2] at this; exact this + have h_res_u : v / (2^16 : Int) - km / (2^16 : Int) ≤ (3328 + 1665 : Int) := by + have := add_le_add h_vu (neg_le_neg h_kml) + have h_simp2 : v / (2^16 : Int) + -(km / (2^16 : Int)) = v / (2^16 : Int) - km / (2^16 : Int) := + by ring + have h_chain : (3328 : Int) + -(-1665) ≤ (3328 + 1665) := by decide + have := le_trans this h_chain + rw [h_simp2] at this; exact this + -- Bridge to natAbs via the |.|-natAbs identity. + have h_abs_eq : (((v / (2^16 : Int) - km / (2^16 : Int)).natAbs : Int)) + = |v / (2^16 : Int) - km / (2^16 : Int)| := by + rw [Int.abs_eq_natAbs] + have h_abs_le : |v / (2^16 : Int) - km / (2^16 : Int)| ≤ (3328 + 1665 : Int) := by + rw [abs_le]; exact ⟨h_res_l, h_res_u⟩ + have h_int_le : ((v / (2^16 : Int) - km / (2^16 : Int)).natAbs : Int) ≤ (3328 + 1665 : Int) := by + rw [h_abs_eq]; exact h_abs_le + scalar_tac + +/-- Closed-form value computed by the impl, as an `IScalar.I16`. + + Exposed (non-private) so that L1.10 `reducing_from_i32_array_spec` + can establish totality of `montgomery_reduce_element` independent + of the per-element bound precondition (mirrors L1.3's use of + `barrett_reduce_impl_value`). -/ +def mont_reduce_impl_value (value : Std.I32) : Std.I16 := + let k := Aeneas.Std.I32.wrapping_mul + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 value)) + (Aeneas.Std.UScalar.hcast Aeneas.Std.IScalarTy.I32 + (62209#u32)) + let km := Aeneas.Std.I32.wrapping_mul + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 k)) + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 + (3329#i16)) + let i9 := (⟨km.bv.sshiftRight 16⟩ : Std.I32) + let i11 := (⟨value.bv.sshiftRight 16⟩ : Std.I32) + Aeneas.Std.I16.wrapping_sub + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 i11) + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 i9) + +/-- The `do`-block reduces to `Result.ok (mont_reduce_impl_value value)`. + + Exposed (non-private) so that L1.10 `reducing_from_i32_array_spec` + can establish totality of `montgomery_reduce_element` independent + of the per-element bound precondition. -/ +theorem mont_reduce_element_eq_ok (value : Std.I32) : + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_reduce_element value + = .ok (mont_reduce_impl_value value) := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_reduce_element + unfold mont_reduce_impl_value + -- Unfold the constants: + have h_inv : libcrux_iot_ml_kem.vector.traits.INVERSE_OF_MODULUS_MOD_MONTGOMERY_R = 62209#u32 := by + unfold libcrux_iot_ml_kem.vector.traits.INVERSE_OF_MODULUS_MOD_MONTGOMERY_R; rfl + have h_q : libcrux_iot_ml_kem.vector.traits.FIELD_MODULUS = 3329#i16 := by + unfold libcrux_iot_ml_kem.vector.traits.FIELD_MODULUS; rfl + have h_shift : libcrux_iot_ml_kem.vector.portable.arithmetic.MONTGOMERY_SHIFT = 16#u8 := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.MONTGOMERY_SHIFT; rfl + -- The shift amount as a U32, after cast, has val = 16 < 32. + have h_shift_val : (Aeneas.Std.UScalar.cast Aeneas.Std.UScalarTy.U32 (16#u8 : U8)).val = 16 := by + decide + have h_shift_lt : (Aeneas.Std.UScalar.cast Aeneas.Std.UScalarTy.U32 (16#u8 : U8)).val + < Aeneas.Std.IScalarTy.I32.numBits := by + rw [h_shift_val]; decide + simp only [libcrux_secrets.traits.Classify.Blanket.classify, + libcrux_secrets.traits.Declassify.Blanket.declassify, + libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32, + libcrux_secrets.I32.Insts.Libcrux_secretsIntCastOps.as_i16, + libcrux_secrets.U32.Insts.Libcrux_secretsIntCastOps.as_i32, + Aeneas.Std.bind_tc_ok, Aeneas.Std.lift, + CoreModels.core.num.I32.wrapping_mul, + CoreModels.core.num.I16.wrapping_sub, + rust_primitives.arithmetic.wrapping_mul_i32, + rust_primitives.arithmetic.wrapping_sub_i16, + h_inv, h_q, h_shift] + -- After substitutions the goal should be a do-block of two `>>>` calls + -- followed by `ok`; unfold the >>> instance + the shiftRight definition. + simp only [HShiftRight.hShiftRight, Aeneas.Std.IScalar.shiftRight_UScalar, + Aeneas.Std.IScalar.shiftRight, h_shift_val] + rfl + +/-! ### `.val` of the closed-form impl value, in `Int` terms. + + Used by the Triple proof to bridge BitVec/cast/shift to the + `mont_reduce_core` helper. -/ + +private theorem mont_reduce_impl_value_val + (value : Std.I32) (hb : value.val.natAbs ≤ 2^16 * 3328) : + (mont_reduce_impl_value value).val + = let v16 := Int.bmod value.val (2^16) + let k16 := Int.bmod (v16 * 62209) (2^16) + let km := k16 * 3329 + value.val / (2^16 : Int) - km / (2^16 : Int) := by + unfold mont_reduce_impl_value + -- |value.val| ≤ 3328 · 2^16, so all intermediate I32 operations fit (no wrap). + set v : Int := value.val + set v16 : Int := Int.bmod v (2^16) + -- Bound v + have h_v_abs_int : |v| ≤ (2^16 * 3328 : Int) := by + rw [Int.abs_eq_natAbs] + have : (v.natAbs : Int) ≤ ((2^16 * 3328 : Nat) : Int) := by exact_mod_cast hb + have h2 : ((2^16 * 3328 : Nat) : Int) = (2^16 * 3328 : Int) := by norm_cast + scalar_tac + -- (cast .I16 value).val = bmod v (2^16) = v16 + have h_v16_eq : (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 value).val = v16 := by + rw [Aeneas.Std.IScalar.cast_val_eq]; rfl + -- v16 bounds + have h_v16_bounds : -(2^15 : Int) ≤ v16 ∧ v16 < (2^15 : Int) := by + refine ⟨?_, ?_⟩ + · have := (Arith.Int.bmod_pow2_bounds 16 v).1; simpa using this + · have := (Arith.Int.bmod_pow2_bounds 16 v).2; simpa using this + -- (cast .I32 (cast .I16 value)).val = v16 since |v16| < 2^15 < 2^31 + have h_v16_in_i32 : (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 value)).val = v16 := by + have h_red : (Aeneas.Std.IScalarTy.I32.numBits - 1) = 31 := by decide + have h_lb : -((2 : Int)^(Aeneas.Std.IScalarTy.I32.numBits - 1)) + ≤ (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 value).val := by + rw [h_red, h_v16_eq] + have h_v16_lb : -(2^15 : Int) ≤ v16 := h_v16_bounds.1 + have h_const : -((2 : Int)^31) ≤ -((2 : Int)^15) := by decide + scalar_tac + have h_ub : (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 value).val + < ((2 : Int)^(Aeneas.Std.IScalarTy.I32.numBits - 1)) := by + rw [h_red, h_v16_eq] + have h_v16_ub : v16 < (2^15 : Int) := h_v16_bounds.2 + have h_const : (2 : Int)^15 ≤ (2 : Int)^31 := by decide + scalar_tac + rw [Aeneas.Std.IScalar.val_mod_pow_inBounds _ _ h_lb h_ub] + exact h_v16_eq + -- (UScalar.hcast .I32 (62209#u32)).val = 62209 + have h_62209 : (Aeneas.Std.UScalar.hcast Aeneas.Std.IScalarTy.I32 (62209#u32 : U32)).val + = 62209 := by decide + -- k = wrapping_mul (v16 as I32) (62209 as I32). |v16 * 62209| ≤ 2^15 * 62209 < 2^31. + set k : Std.I32 := Aeneas.Std.I32.wrapping_mul + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 value)) + (Aeneas.Std.UScalar.hcast Aeneas.Std.IScalarTy.I32 (62209#u32)) + -- k.val = v16 * 62209 (no wrap): + -- Using BitVec.toInt_mul: (a*b).toInt = bmod (a.toInt * b.toInt) (2^32); + -- |v16 * 62209| < 2^31 so the bmod is identity. + have h_k_val : k.val = v16 * 62209 := by + show (Aeneas.Std.I32.wrapping_mul _ _).val = _ + -- wrapping_mul_bv_eq : (wrapping_mul x y).bv = x.bv * y.bv + have h_bv : k.bv = (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 value)).bv + * (Aeneas.Std.UScalar.hcast Aeneas.Std.IScalarTy.I32 (62209#u32)).bv := by + show (Aeneas.Std.I32.wrapping_mul _ _).bv = _ + simp only [Aeneas.Std.I32.wrapping_mul, Aeneas.Std.IScalar.wrapping_mul] + -- k.val = k.bv.toInt = bmod (a.bv.toInt * b.bv.toInt) (2^32); + -- and a.bv.toInt = (cast .I32 ...).val = v16, b.bv.toInt = 62209. + show k.bv.toInt = v16 * 62209 + rw [h_bv, BitVec.toInt_mul] + -- IScalarTy.I32.numBits = 32, so bmod (v16 * 62209) (2^32). + show Int.bmod _ _ = _ + have h_a_int : (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 value)).bv.toInt = v16 := by + show (Aeneas.Std.IScalar.cast _ _).val = _ + exact h_v16_in_i32 + have h_b_int : (Aeneas.Std.UScalar.hcast Aeneas.Std.IScalarTy.I32 (62209#u32 : U32)).bv.toInt + = 62209 := by + show (Aeneas.Std.UScalar.hcast _ _).val = _ + exact h_62209 + rw [h_a_int, h_b_int] + -- bmod (v16 * 62209) (2^32) = v16 * 62209 since |v16 * 62209| < 2^31 + apply Arith.Int.bmod_pow2_eq_of_inBounds' 32 _ (by decide) + · -- -(2^(32-1)) ≤ v16 * 62209 + have h_lb : (-(2^15 : Int)) * 62209 ≤ v16 * 62209 := + mul_le_mul_of_nonneg_right h_v16_bounds.1 (by decide : (0 : Int) ≤ 62209) + have h_const : -((2 : Int)^(32-1)) ≤ -((2 : Int)^15) * 62209 := by decide + scalar_tac + · -- v16 * 62209 < 2^(32-1). + have h_ub : v16 * 62209 < (2^15 : Int) * 62209 := + mul_lt_mul_of_pos_right h_v16_bounds.2 (by decide : (0 : Int) < 62209) + have h_const : (2^15 : Int) * 62209 ≤ (2 : Int)^(32-1) := by decide + scalar_tac + -- Now (cast .I16 k).val = bmod k.val (2^16) = bmod (v16 * 62209) (2^16) = k16 + set k16 := Int.bmod (v16 * 62209) (2^16) + have h_k16_eq : (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 k).val = k16 := by + rw [Aeneas.Std.IScalar.cast_val_eq, h_k_val]; rfl + have h_k16_bounds : -(2^15 : Int) ≤ k16 ∧ k16 < (2^15 : Int) := by + refine ⟨?_, ?_⟩ + · have := (Arith.Int.bmod_pow2_bounds 16 (v16 * 62209)).1; simpa using this + · have := (Arith.Int.bmod_pow2_bounds 16 (v16 * 62209)).2; simpa using this + -- (cast .I32 (cast .I16 k)).val = k16 + have h_k16_in_i32 : (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 k)).val = k16 := by + have h_red : (Aeneas.Std.IScalarTy.I32.numBits - 1) = 31 := by decide + have h_lb : -((2 : Int)^(Aeneas.Std.IScalarTy.I32.numBits - 1)) + ≤ (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 k).val := by + rw [h_red, h_k16_eq] + have h_k16_lb : -(2^15 : Int) ≤ k16 := h_k16_bounds.1 + have h_const : -((2 : Int)^31) ≤ -((2 : Int)^15) := by decide + scalar_tac + have h_ub : (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 k).val + < ((2 : Int)^(Aeneas.Std.IScalarTy.I32.numBits - 1)) := by + rw [h_red, h_k16_eq] + have h_k16_ub : k16 < (2^15 : Int) := h_k16_bounds.2 + have h_const : (2 : Int)^15 ≤ (2 : Int)^31 := by decide + scalar_tac + rw [Aeneas.Std.IScalar.val_mod_pow_inBounds _ _ h_lb h_ub] + exact h_k16_eq + -- (cast .I32 (3329#i16)).val = 3329 + have h_3329 : (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 (3329#i16 : I16)).val + = 3329 := by decide + -- km = wrapping_mul (k16 as I32) (3329 as I32). |k16 * 3329| < 2^15 * 3329 < 2^27 < 2^31. + set km_aeneas : Std.I32 := Aeneas.Std.I32.wrapping_mul + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 k)) + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 (3329#i16)) + have h_km_val : km_aeneas.val = k16 * 3329 := by + have h_bv : km_aeneas.bv = (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 k)).bv + * (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 (3329#i16)).bv := by + show (Aeneas.Std.I32.wrapping_mul _ _).bv = _ + simp only [Aeneas.Std.I32.wrapping_mul, Aeneas.Std.IScalar.wrapping_mul] + show km_aeneas.bv.toInt = _ + rw [h_bv, BitVec.toInt_mul] + show Int.bmod _ _ = _ + rw [show (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 k)).bv.toInt = k16 from h_k16_in_i32, + show (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 (3329#i16 : I16)).bv.toInt = 3329 + from h_3329] + apply Arith.Int.bmod_pow2_eq_of_inBounds' 32 _ (by decide) + · have h_lb := mul_le_mul_of_nonneg_right h_k16_bounds.1 (by decide : (0 : Int) ≤ 3329) + have h_const : -((2 : Int)^(32-1)) ≤ -((2 : Int)^15) * 3329 := by decide + scalar_tac + · have h_ub := mul_lt_mul_of_pos_right h_k16_bounds.2 (by decide : (0 : Int) < 3329) + have h_const : (2^15 : Int) * 3329 ≤ (2 : Int)^(32-1) := by decide + scalar_tac + -- The two arithmetic shifts: i9 = km >> 16, i11 = value >> 16. + -- |km.val| < 2^15 * 3329 < 2^27, so |i9.val| < 2^11 < 2^15 (fits in i16). + -- |value.val| ≤ 3328 * 2^16, so |i11.val| ≤ 3328 < 2^15 (fits in i16). + set i9 : Std.I32 := ⟨km_aeneas.bv.sshiftRight 16⟩ + set i11 : Std.I32 := ⟨value.bv.sshiftRight 16⟩ + -- i9.val = km.val / 2^16 + have h_i9_val : i9.val = km_aeneas.val / (2^16 : Int) := by + show (km_aeneas.bv.sshiftRight 16).toInt = _ + rw [BitVec.toInt_sshiftRight, Int.shiftRight_eq_div_pow] + norm_cast + have h_i11_val : i11.val = value.val / (2^16 : Int) := by + show (value.bv.sshiftRight 16).toInt = _ + rw [BitVec.toInt_sshiftRight, Int.shiftRight_eq_div_pow] + norm_cast + -- Bound i9 and i11 to fit in I16: + -- i9.val = km.val / 2^16, |km.val| ≤ 2^15 * 3329, so |i9.val| ≤ 2^15 * 3329 / 2^16. + have h_i9_bounds : -(2^15 : Int) ≤ i9.val ∧ i9.val < (2^15 : Int) := by + rw [h_i9_val, h_km_val] + have hl : -(2^15 * 3329 : Int) ≤ k16 * 3329 := by + have h_lb := mul_le_mul_of_nonneg_right h_k16_bounds.1 (by decide : (0 : Int) ≤ 3329) + have : (-(2^15 : Int)) * 3329 = -(2^15 * 3329 : Int) := by ring + scalar_tac +nonLin + have hu : k16 * 3329 ≤ ((2^15 - 1) * 3329 : Int) := by + have h_le : k16 ≤ 2^15 - 1 := by omega + exact mul_le_mul_of_nonneg_right h_le (by decide) + refine ⟨?_, ?_⟩ + · have h := Int.ediv_le_ediv (a := -(2^15 * 3329 : Int)) (b := k16 * 3329) + (c := (2^16 : Int)) (by decide) hl + have h_const : -(2^15 * 3329 : Int) / (2^16 : Int) = -1665 := by decide + have : (-1665 : Int) ≥ -(2^15) := by decide + scalar_tac + · have h := Int.ediv_le_ediv (a := k16 * 3329) (b := ((2^15 - 1) * 3329 : Int)) + (c := (2^16 : Int)) (by decide) hu + have h_const : (((2^15 - 1) * 3329 : Int)) / (2^16 : Int) = 1664 := by decide + have : (1664 : Int) < 2^15 := by decide + scalar_tac + have h_i11_bounds : -(2^15 : Int) ≤ i11.val ∧ i11.val < (2^15 : Int) := by + rw [h_i11_val] + have hl : -((2^16 : Int) * 3328) ≤ v := by + have h_nat : (v.natAbs : Int) ≤ ((2^16 * 3328 : Nat) : Int) := by exact_mod_cast hb + have h_abs : |v| = (v.natAbs : Int) := Int.abs_eq_natAbs v + have h_v_lt_abs : -|v| ≤ v := neg_abs_le v + have h2 : ((2^16 * 3328 : Nat) : Int) = (2^16 : Int) * 3328 := by norm_cast + scalar_tac + have hu : v ≤ (2^16 : Int) * 3328 := by + have h_nat : (v.natAbs : Int) ≤ ((2^16 * 3328 : Nat) : Int) := by exact_mod_cast hb + have h_abs : |v| = (v.natAbs : Int) := Int.abs_eq_natAbs v + have h_v_lt_abs : v ≤ |v| := le_abs_self v + have h2 : ((2^16 * 3328 : Nat) : Int) = (2^16 : Int) * 3328 := by norm_cast + scalar_tac + refine ⟨?_, ?_⟩ + · have h := Int.ediv_le_ediv (a := -((2^16 : Int) * 3328)) (b := v) + (c := (2^16 : Int)) (by decide) hl + have h_const : (-((2^16 : Int) * 3328)) / (2^16 : Int) = -3328 := by decide + have : (-3328 : Int) ≥ -(2^15) := by decide + scalar_tac + · have h := Int.ediv_le_ediv (a := v) (b := (2^16 : Int) * 3328) + (c := (2^16 : Int)) (by decide) hu + have h_const : ((2^16 : Int) * 3328) / (2^16 : Int) = 3328 := by decide + have : (3328 : Int) < 2^15 := by decide + scalar_tac + -- (cast .I16 i9).val = i9.val (since |i9.val| < 2^15) + have h_c_val : (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 i9).val = i9.val := by + have h_lb : -((2 : Int)^(Aeneas.Std.IScalarTy.I16.numBits - 1)) ≤ i9.val := by + have h_red : (Aeneas.Std.IScalarTy.I16.numBits - 1) = 15 := by decide + rw [h_red]; exact h_i9_bounds.1 + have h_ub : i9.val < ((2 : Int)^(Aeneas.Std.IScalarTy.I16.numBits - 1)) := by + have h_red : (Aeneas.Std.IScalarTy.I16.numBits - 1) = 15 := by decide + rw [h_red]; exact h_i9_bounds.2 + rw [Aeneas.Std.IScalar.val_mod_pow_inBounds _ _ h_lb h_ub] + -- (cast .I16 i11).val = i11.val + have h_vh_val : (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 i11).val = i11.val := by + have h_lb : -((2 : Int)^(Aeneas.Std.IScalarTy.I16.numBits - 1)) ≤ i11.val := by + have h_red : (Aeneas.Std.IScalarTy.I16.numBits - 1) = 15 := by decide + rw [h_red]; exact h_i11_bounds.1 + have h_ub : i11.val < ((2 : Int)^(Aeneas.Std.IScalarTy.I16.numBits - 1)) := by + have h_red : (Aeneas.Std.IScalarTy.I16.numBits - 1) = 15 := by decide + rw [h_red]; exact h_i11_bounds.2 + rw [Aeneas.Std.IScalar.val_mod_pow_inBounds _ _ h_lb h_ub] + -- Wrapping_sub on I16: result.val = bmod (vh.val - c.val) (2^16). + -- We have |vh - c| ≤ 3328 + 1665 < 2^15, so no wrap. + show (Aeneas.Std.I16.wrapping_sub _ _).val = _ + show (Aeneas.Std.I16.wrapping_sub + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 i11) + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 i9)).bv.toInt = _ + rw [show (Aeneas.Std.I16.wrapping_sub + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 i11) + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 i9)).bv + = (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 i11).bv + - (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 i9).bv from rfl, + BitVec.toInt_sub] + rw [show (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 i11).bv.toInt = i11.val + from h_vh_val, + show (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I16 i9).bv.toInt = i9.val + from h_c_val] + -- Goal: Int.bmod (i11.val - i9.val) (2^16) = v/R - km/R. + -- Substitute h_i9_val, h_i11_val, h_km_val: + rw [h_i11_val, h_i9_val, h_km_val] + -- Need: bmod (v/R - k16*3329/R) (2^16) = v/R - k16*3329/R, i.e., the diff fits in [-2^15, 2^15). + apply Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · -- -2^15 ≤ v/R - k16*3329/R: bounds give us [-4992, 4993], well within [-2^15, 2^15). + -- We need |v/R - km/R| ≤ 3328 + 1665 = 4993 < 2^15 = 32768. + have h_v_div : -3328 ≤ v / (2^16 : Int) ∧ v / (2^16 : Int) ≤ 3328 := by + refine ⟨?_, ?_⟩ + · have := h_i11_bounds.1; rw [h_i11_val] at this + have h_const : -3328 ≥ -(2^15 : Int) := by decide + -- Stronger bound — re-derive directly. + have hl : -((2^16 : Int) * 3328) ≤ v := by + have h_nat : (v.natAbs : Int) ≤ ((2^16 * 3328 : Nat) : Int) := by exact_mod_cast hb + have h_abs : |v| = (v.natAbs : Int) := Int.abs_eq_natAbs v + have h_v_lt_abs : -|v| ≤ v := neg_abs_le v + have h2 : ((2^16 * 3328 : Nat) : Int) = (2^16 : Int) * 3328 := by norm_cast + scalar_tac + have h := Int.ediv_le_ediv (a := -((2^16 : Int) * 3328)) (b := v) + (c := (2^16 : Int)) (by decide) hl + have h_const2 : (-((2^16 : Int) * 3328)) / (2^16 : Int) = -3328 := by decide + scalar_tac + · have hu : v ≤ (2^16 : Int) * 3328 := by + have h_nat : (v.natAbs : Int) ≤ ((2^16 * 3328 : Nat) : Int) := by exact_mod_cast hb + have h_abs : |v| = (v.natAbs : Int) := Int.abs_eq_natAbs v + have h_v_lt_abs : v ≤ |v| := le_abs_self v + have h2 : ((2^16 * 3328 : Nat) : Int) = (2^16 : Int) * 3328 := by norm_cast + scalar_tac + have h := Int.ediv_le_ediv (a := v) (b := (2^16 : Int) * 3328) + (c := (2^16 : Int)) (by decide) hu + have h_const : ((2^16 : Int) * 3328) / (2^16 : Int) = 3328 := by decide + scalar_tac + have h_km_div : -1665 ≤ k16 * 3329 / (2^16 : Int) ∧ k16 * 3329 / (2^16 : Int) ≤ 1664 := by + refine ⟨?_, ?_⟩ + · have hl : -(2^15 * 3329 : Int) ≤ k16 * 3329 := by + have h_lb := mul_le_mul_of_nonneg_right h_k16_bounds.1 (by decide : (0 : Int) ≤ 3329) + have : (-(2^15 : Int)) * 3329 = -(2^15 * 3329 : Int) := by ring + scalar_tac +nonLin + have h := Int.ediv_le_ediv (a := -(2^15 * 3329 : Int)) (b := k16 * 3329) + (c := (2^16 : Int)) (by decide) hl + have h_const : -(2^15 * 3329 : Int) / (2^16 : Int) = -1665 := by decide + scalar_tac + · have hu : k16 * 3329 ≤ ((2^15 - 1) * 3329 : Int) := by + have h_le : k16 ≤ 2^15 - 1 := by omega + exact mul_le_mul_of_nonneg_right h_le (by decide) + have h := Int.ediv_le_ediv (a := k16 * 3329) (b := ((2^15 - 1) * 3329 : Int)) + (c := (2^16 : Int)) (by decide) hu + have h_const : (((2^15 - 1) * 3329 : Int)) / (2^16 : Int) = 1664 := by decide + scalar_tac + -- Goal: `-(2^(16-1)) ≤ ↑value / 2^16 - k16 * 3329 / 2^16`. + -- Substitute `v = ↑value` so the named bounds in scope discharge it. + show -((2 : Int)^(16-1)) ≤ v / 2^16 - k16 * 3329 / 2^16 + have h_2_15 : ((2 : Int)^(16-1)) = ((2 : Int)^15) := by decide + rw [h_2_15] + have h_15_le : (-(2^15) : Int) ≤ -4993 := by decide + have hl1 : -3328 ≤ v / (2^16 : Int) := h_v_div.1 + have hl2 : k16 * 3329 / (2^16 : Int) ≤ 1664 := h_km_div.2 + -- Combine: v/R - km/R ≥ -3328 - 1664 = -4992 ≥ -4993 ≥ -2^15. + have h_sum : -3328 - 1664 ≤ v / 2^16 - k16 * 3329 / 2^16 := by + have := add_le_add hl1 (neg_le_neg hl2) + have h_simp : (-3328 : Int) + (-1664) = -3328 - 1664 := by ring + have h_simp2 : v / (2^16 : Int) + -(k16 * 3329 / (2^16 : Int)) + = v / (2^16 : Int) - k16 * 3329 / (2^16 : Int) := by ring + rw [h_simp] at this + rw [h_simp2] at this + exact this + have h_chain : (-(2^15) : Int) ≤ -3328 - 1664 := by decide + exact le_trans h_chain h_sum + · have h_v_div : -3328 ≤ v / (2^16 : Int) ∧ v / (2^16 : Int) ≤ 3328 := by + refine ⟨?_, ?_⟩ + · have hl : -((2^16 : Int) * 3328) ≤ v := by + have h_nat : (v.natAbs : Int) ≤ ((2^16 * 3328 : Nat) : Int) := by exact_mod_cast hb + have h_abs : |v| = (v.natAbs : Int) := Int.abs_eq_natAbs v + have h_v_lt_abs : -|v| ≤ v := neg_abs_le v + have h2 : ((2^16 * 3328 : Nat) : Int) = (2^16 : Int) * 3328 := by norm_cast + scalar_tac + have h := Int.ediv_le_ediv (a := -((2^16 : Int) * 3328)) (b := v) + (c := (2^16 : Int)) (by decide) hl + have h_const : (-((2^16 : Int) * 3328)) / (2^16 : Int) = -3328 := by decide + scalar_tac + · have hu : v ≤ (2^16 : Int) * 3328 := by + have h_nat : (v.natAbs : Int) ≤ ((2^16 * 3328 : Nat) : Int) := by exact_mod_cast hb + have h_abs : |v| = (v.natAbs : Int) := Int.abs_eq_natAbs v + have h_v_lt_abs : v ≤ |v| := le_abs_self v + have h2 : ((2^16 * 3328 : Nat) : Int) = (2^16 : Int) * 3328 := by norm_cast + scalar_tac + have h := Int.ediv_le_ediv (a := v) (b := (2^16 : Int) * 3328) + (c := (2^16 : Int)) (by decide) hu + have h_const : ((2^16 : Int) * 3328) / (2^16 : Int) = 3328 := by decide + scalar_tac + have h_km_div : -1665 ≤ k16 * 3329 / (2^16 : Int) ∧ k16 * 3329 / (2^16 : Int) ≤ 1664 := by + refine ⟨?_, ?_⟩ + · have hl : -(2^15 * 3329 : Int) ≤ k16 * 3329 := by + have h_lb := mul_le_mul_of_nonneg_right h_k16_bounds.1 (by decide : (0 : Int) ≤ 3329) + have : (-(2^15 : Int)) * 3329 = -(2^15 * 3329 : Int) := by ring + scalar_tac +nonLin + have h := Int.ediv_le_ediv (a := -(2^15 * 3329 : Int)) (b := k16 * 3329) + (c := (2^16 : Int)) (by decide) hl + have h_const : -(2^15 * 3329 : Int) / (2^16 : Int) = -1665 := by decide + scalar_tac + · have hu : k16 * 3329 ≤ ((2^15 - 1) * 3329 : Int) := by + have h_le : k16 ≤ 2^15 - 1 := by omega + exact mul_le_mul_of_nonneg_right h_le (by decide) + have h := Int.ediv_le_ediv (a := k16 * 3329) (b := ((2^15 - 1) * 3329 : Int)) + (c := (2^16 : Int)) (by decide) hu + have h_const : (((2^15 - 1) * 3329 : Int)) / (2^16 : Int) = 1664 := by decide + scalar_tac + -- Goal: `↑value / 2^16 - k16 * 3329 / 2^16 < 2^(16-1)`. + show v / 2^16 - k16 * 3329 / 2^16 < (2 : Int)^(16-1) + have h_2_15 : ((2 : Int)^(16-1)) = ((2 : Int)^15) := by decide + rw [h_2_15] + have hu1 : v / (2^16 : Int) ≤ 3328 := h_v_div.2 + have hl2 : -1665 ≤ k16 * 3329 / (2^16 : Int) := h_km_div.1 + have h_bound : (3328 + 1665 : Int) < 2^15 := by decide + -- v/R - km/R ≤ 3328 - (-1665) = 4993 < 2^15. + have h_sum : v / 2^16 - k16 * 3329 / 2^16 ≤ 3328 + 1665 := by + have := add_le_add hu1 (neg_le_neg hl2) + have h_simp : (3328 : Int) + (-(-1665)) = 3328 + 1665 := by ring + have h_simp2 : v / (2^16 : Int) + -(k16 * 3329 / (2^16 : Int)) + = v / (2^16 : Int) - k16 * 3329 / (2^16 : Int) := by ring + rw [h_simp] at this + rw [h_simp2] at this + exact this + exact lt_of_le_of_lt h_sum h_bound + +/-- **Tight bound for the conditional half of L0.3.** + + When `|value| ≤ 2^15 * 3328`, `(mont_reduce_impl_value value).val.natAbs ≤ 3328`. + + Triangle-inequality argument: since `v ≡ km (mod R)` (the + `mont_reduce_core` `h_km_mod` keystone), `res * R = v - km` *exactly*. + Hence `|res| * R = |v - km| ≤ |v| + |km| ≤ 2^15·3328 + 2^15·3329 = + 2^15·6657`, so `|res| ≤ 6657/2 = 3328` (Int division). -/ +private theorem mont_reduce_tight_3328 + (v : Std.I32) (h_v : v.val.natAbs ≤ 2^15 * 3328) : + (mont_reduce_impl_value v).val.natAbs ≤ 3328 := by + -- Loosen the precondition for `mont_reduce_impl_value_val`. + have h_loose : v.val.natAbs ≤ 2^16 * 3328 := + le_trans h_v (by decide : (2^15 * 3328 : Nat) ≤ 2^16 * 3328) + rw [mont_reduce_impl_value_val v h_loose] + set vi : Int := v.val with hvi_def + set v16 : Int := Int.bmod vi (2^16) with hv16_def + set k16 : Int := Int.bmod (v16 * 62209) (2^16) with hk16_def + set km : Int := k16 * 3329 with hkm_def + set res : Int := vi / (2^16 : Int) - km / (2^16 : Int) with hres_def + -- Bound on |vi| as an Int. + have h_vi_abs : |vi| ≤ ((2^15 : Int) * 3328) := by + rw [Int.abs_eq_natAbs] + have h_nat : (vi.natAbs : Int) ≤ ((2^15 * 3328 : Nat) : Int) := by exact_mod_cast h_v + have h_nat_int : ((2^15 * 3328 : Nat) : Int) = (2^15 : Int) * 3328 := by norm_cast + rw [← h_nat_int]; exact h_nat + -- Bound on |k16| as an Int (from bmod 2^16 bounds). + have h_k16_lb : -(2^15 : Int) ≤ k16 := (Arith.Int.bmod_pow2_bounds 16 (v16 * 62209)).1 + have h_k16_ub : k16 < (2^15 : Int) := (Arith.Int.bmod_pow2_bounds 16 (v16 * 62209)).2 + have h_k16_abs : |k16| ≤ (2^15 : Int) := abs_le.mpr ⟨h_k16_lb, le_of_lt h_k16_ub⟩ + -- Re-derive the keystone `(vi - km) % R = 0`. + have h_km_mod : (vi - km) % (2^16 : Int) = 0 := by + have h_keystone_int : (62209 * 3329 : Int) % (2^16) = 1 := by decide + have h_k16_emod : k16 % (2^16 : Int) = (v16 * 62209) % (2^16 : Int) := by + change (Int.bmod (v16 * 62209) (2^16)) % (2^16 : Int) = (v16 * 62209) % (2^16 : Int) + exact_mod_cast Int.bmod_emod + have h_step1 : km % (2^16 : Int) = (v16 * (62209 * 3329)) % (2^16 : Int) := by + change (k16 * 3329) % _ = _ + rw [Int.mul_emod, h_k16_emod, ← Int.mul_emod] + congr 1; ring + have h_step2 : km % (2^16 : Int) = v16 % (2^16 : Int) := by + rw [h_step1, Int.mul_emod, h_keystone_int, mul_one, Int.emod_emod_of_dvd _ (dvd_refl _)] + have h_v_emod : vi % (2^16 : Int) = v16 % (2^16 : Int) := by + change vi % (2^16 : Int) = (Int.bmod vi (2^16)) % (2^16 : Int) + exact_mod_cast Int.bmod_emod.symm + rw [Int.sub_emod, h_v_emod, ← h_step2]; simp + -- Key identity: `res * R = vi - km` exactly. + have h_res_R : res * (2^16 : Int) = vi - km := by + have h_R_dvd : (2^16 : Int) ∣ (vi - km) := Int.dvd_of_emod_eq_zero h_km_mod + obtain ⟨q, hq⟩ := h_R_dvd + have h_vm_div : (vi - km) / (2^16 : Int) = q := by + rw [hq]; exact Int.mul_ediv_cancel_left q (by decide) + have h_div_split : vi / (2^16 : Int) - km / (2^16 : Int) = (vi - km) / (2^16 : Int) := + libcrux_iot_ml_kem.Spec.Montgomery.sub_div_of_emod_eq_zero vi km (2^16) (by decide) h_km_mod + show (vi / (2^16 : Int) - km / (2^16 : Int)) * (2^16 : Int) = vi - km + rw [h_div_split, h_vm_div, hq]; ring + -- Triangle inequality + bounds: |vi - km| ≤ 2^15 * 6657. + have h_km_abs : |km| ≤ (2^15 : Int) * 3329 := by + show |k16 * 3329| ≤ _ + rw [abs_mul] + have h_3329 : |(3329 : Int)| = 3329 := by decide + rw [h_3329] + exact mul_le_mul_of_nonneg_right h_k16_abs (by decide) + have h_diff_abs : |vi - km| ≤ (2^15 : Int) * (3328 + 3329) := by + have h_tri : |vi - km| ≤ |vi| + |km| := abs_sub vi km + have h_sum : |vi| + |km| ≤ (2^15 : Int) * 3328 + (2^15 : Int) * 3329 := + add_le_add h_vi_abs h_km_abs + have h_factor : (2^15 : Int) * 3328 + (2^15 : Int) * 3329 = (2^15 : Int) * (3328 + 3329) := by + ring + rw [h_factor] at h_sum + exact le_trans h_tri h_sum + -- |res| * R ≤ 2^15 * 6657, hence |res| ≤ 3328. + have h_res_R_factor : |res * (2^16 : Int)| = |res| * (2^16 : Int) := by + rw [abs_mul, show |(2^16 : Int)| = (2^16 : Int) from by decide] + have h_res_R_ge : |res| * (2^16 : Int) ≤ (2^15 : Int) * (3328 + 3329) := by + rw [← h_res_R_factor, h_res_R]; exact h_diff_abs + have h_res_le_3328 : |res| ≤ (3328 : Int) := by + -- Suppose |res| ≥ 3329; then |res| * 2^16 ≥ 3329 * 2^16 > 2^15 * 6657. Contradiction. + by_contra h_gt + push Not at h_gt + have h_ge : (3329 : Int) ≤ |res| := h_gt + have h_mul : (3329 : Int) * (2^16 : Int) ≤ |res| * (2^16 : Int) := + mul_le_mul_of_nonneg_right h_ge (by decide) + have h_chain : (3329 : Int) * (2^16 : Int) ≤ (2^15 : Int) * (3328 + 3329) := + le_trans h_mul h_res_R_ge + have h_eval : ((2^15 : Int) * (3328 + 3329) < (3329 : Int) * (2^16 : Int)) := by decide + exact absurd (lt_of_le_of_lt h_chain h_eval) (lt_irrefl _) + have h_abs_eq : |res| = (res.natAbs : Int) := Int.abs_eq_natAbs res + rw [h_abs_eq] at h_res_le_3328 + exact_mod_cast h_res_le_3328 + +/-! ### L0.3 Triple. -/ + +@[spec] +theorem montgomery_reduce_element_spec + (value : Std.I32) (hb : value.val.natAbs ≤ 3328 * 2^16) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_reduce_element value + ⦃ ⇓ r => ⌜ r.val.natAbs ≤ 3328 + 1665 + ∧ (value.val.natAbs ≤ 3328 * 2^15 → r.val.natAbs ≤ 3328) + ∧ modq_eq r.val (value.val * 169) 3329 ⌝ ⦄ := by + -- Normalise to the form `mont_reduce_core` / `mont_reduce_impl_value_val` use. + have hb' : value.val.natAbs ≤ 2^16 * 3328 := by + have h_eq : (3328 * 2^16 : Nat) = 2^16 * 3328 := by decide + rw [← h_eq]; exact hb + apply triple_of_ok_l0 (v := mont_reduce_impl_value value) + (mont_reduce_element_eq_ok value) + -- Three conjuncts: weak bound, conditional tight bound, modq new form. + refine ⟨?_, ?_, ?_⟩ + · -- Weak bound: `(mont_reduce_impl_value value).val.natAbs ≤ 4993`. + rw [mont_reduce_impl_value_val value hb'] + exact (mont_reduce_core value.val hb').2 + · -- Conditional tight bound: derived from the same `mont_reduce_impl_value_val` + -- but under the stronger precondition `|value| ≤ 3328 * 2^15`. + intro h_tight + have h_tight' : value.val.natAbs ≤ 2^15 * 3328 := by + have h_eq : (3328 * 2^15 : Nat) = 2^15 * 3328 := by decide + rw [← h_eq]; exact h_tight + exact mont_reduce_tight_3328 value h_tight' + · -- New modq form: derived from the old one via `modq_R_to_169`. + rw [mont_reduce_impl_value_val value hb'] + exact modq_R_to_169 _ _ (mont_reduce_core value.val hb').1 + +/-! ## L0.4 — `montgomery_multiply_fe_by_fer_spec` + + Trivial corollary of L0.3: the impl is `montgomery_reduce_element + (I32.wrapping_mul (cast .I32 fe) (cast .I32 fer))`. The wrap-mul + is exact since `|fe·fer| ≤ 2^15·1664 < 2^31`. Reuses the L0.3 + privates `mont_reduce_impl_value` / `mont_reduce_impl_value_val` + / `mont_reduce_element_eq_ok` (same-file access) to derive the + **tight** `|r| ≤ 832 + 1665 = 2497 ≤ 3328` bound that L0.3's + public `@[spec]` (`≤ 4993`) does not expose directly. -/ + + +/-- Closed form of the do-block at the I32 product level: the impl + reduces to `mont_reduce_element` of the exact (non-wrapped) product. -/ +private theorem mmfbf_eq_ok (fe fer : Std.I16) : + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_fe_by_fer fe fer + = libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_reduce_element + (Aeneas.Std.I32.wrapping_mul + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 fe) + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 fer)) := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_fe_by_fer + simp only [libcrux_secrets.traits.Classify.Blanket.classify, + libcrux_secrets.traits.Declassify.Blanket.declassify, + libcrux_secrets.I16.Insts.Libcrux_secretsIntCastOps.as_i32, + CoreModels.core.num.I32.wrapping_mul, + rust_primitives.arithmetic.wrapping_mul_i32, + Aeneas.Std.bind_tc_ok, Aeneas.Std.lift] + +/-- Under `|fer| ≤ 1664`, the I32 product is exact (no wrap): its + `.val` is `fe.val * fer.val` (in `Int`). -/ +private theorem mmfbf_product_val + (fe fer : Std.I16) (hfer : fer.val.natAbs ≤ 1664) : + (Aeneas.Std.I32.wrapping_mul + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 fe) + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 fer)).val + = fe.val * fer.val := by + -- I16 bounds: |fe.val| < 2^15, |fer.val| < 2^15. + have h_fe_bounds := fe.hBounds + have h_fer_bounds := fer.hBounds + have h_red16 : (Aeneas.Std.IScalarTy.I16.numBits - 1) = 15 := by decide + rw [h_red16] at h_fe_bounds h_fer_bounds + -- (cast .I32 x).val = x.val (since |x.val| < 2^15 ≤ 2^31). + have h_fe_cast : (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 fe).val = fe.val := by + apply Aeneas.Std.IScalar.val_mod_pow_inBounds + · have h_step : -((2 : Int) ^ (Aeneas.Std.IScalarTy.I32.numBits - 1)) + ≤ -((2 : Int) ^ 15) := by decide + exact le_trans h_step h_fe_bounds.1 + · have h_step : ((2 : Int) ^ 15) ≤ (2 : Int) ^ (Aeneas.Std.IScalarTy.I32.numBits - 1) := by + decide + exact lt_of_lt_of_le h_fe_bounds.2 h_step + have h_fer_cast : (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 fer).val = fer.val := by + apply Aeneas.Std.IScalar.val_mod_pow_inBounds + · have h_step : -((2 : Int) ^ (Aeneas.Std.IScalarTy.I32.numBits - 1)) + ≤ -((2 : Int) ^ 15) := by decide + exact le_trans h_step h_fer_bounds.1 + · have h_step : ((2 : Int) ^ 15) ≤ (2 : Int) ^ (Aeneas.Std.IScalarTy.I32.numBits - 1) := by + decide + exact lt_of_lt_of_le h_fer_bounds.2 h_step + -- (wrapping_mul a b).val = bmod (a.val * b.val) (2^32) = a.val * b.val + -- (since |a.val * b.val| ≤ 2^15 * 1664 < 2^31). + rw [Aeneas.Std.I32.wrapping_mul_val_eq, h_fe_cast, h_fer_cast] + -- |fe.val| < 2^15, |fer.val| ≤ 1664, so |fe.val * fer.val| ≤ 2^15 * 1664 < 2^31. + have h_fer_abs : |fer.val| ≤ (1664 : Int) := by + rw [Int.abs_eq_natAbs]; exact_mod_cast hfer + have h_fe_abs : |fe.val| ≤ ((2 : Int)^15) := by + rw [Int.abs_eq_natAbs] + have h_natAbs_int : (fe.val.natAbs : Int) ≤ ((2 : Int)^15) := by + rw [← Int.abs_eq_natAbs]; exact abs_le.mpr ⟨h_fe_bounds.1, le_of_lt h_fe_bounds.2⟩ + exact h_natAbs_int + have h_prod_abs : |fe.val * fer.val| ≤ ((2 : Int)^15) * 1664 := by + rw [abs_mul] + have h_nn : (0 : Int) ≤ |fer.val| := abs_nonneg _ + have h_pos : (0 : Int) ≤ ((2 : Int)^15) := by decide + calc |fe.val| * |fer.val| + ≤ ((2 : Int)^15) * |fer.val| := mul_le_mul_of_nonneg_right h_fe_abs h_nn + _ ≤ ((2 : Int)^15) * 1664 := mul_le_mul_of_nonneg_left h_fer_abs h_pos + have h_prod_lb : -((2 : Int)^15 * 1664) ≤ fe.val * fer.val := (abs_le.mp h_prod_abs).1 + have h_prod_ub : fe.val * fer.val ≤ ((2 : Int)^15 * 1664) := (abs_le.mp h_prod_abs).2 + apply Arith.Int.bmod_pow2_eq_of_inBounds' 32 _ (by decide) + · have h_const : -((2 : Int)^(32-1)) ≤ -((2 : Int)^15 * 1664) := by decide + exact le_trans h_const h_prod_lb + · have h_const : ((2 : Int)^15 * 1664) < ((2 : Int)^(32-1)) := by decide + exact lt_of_le_of_lt h_prod_ub h_const + +@[spec] +theorem montgomery_multiply_fe_by_fer_spec + (fe : Std.I16) (fer : Std.I16) (hfer : fer.val.natAbs ≤ 1664) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_fe_by_fer fe fer + ⦃ ⇓ r => ⌜ r.val.natAbs ≤ 3328 + ∧ modq_eq r.val (fe.val * fer.val * 169) 3329 ⌝ ⦄ := by + -- Reduce L0.4's impl to a single `montgomery_reduce_element` call on the exact product. + set product : Std.I32 := + Aeneas.Std.I32.wrapping_mul + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 fe) + (Aeneas.Std.IScalar.cast Aeneas.Std.IScalarTy.I32 fer) + have h_product_val : product.val = fe.val * fer.val := mmfbf_product_val fe fer hfer + -- Bound the product: |fe·fer| ≤ 2^15 · 1664. + have h_product_natAbs : product.val.natAbs ≤ 2^15 * 1664 := by + rw [h_product_val] + have h_fe_bounds := fe.hBounds + have h_red : (Aeneas.Std.IScalarTy.I16.numBits - 1) = 15 := by decide + have h_fe_lb : -((2 : Int)^15) ≤ fe.val := by + have := h_fe_bounds.1; rw [h_red] at this; exact this + have h_fe_ub : fe.val < ((2 : Int)^15) := by + have := h_fe_bounds.2; rw [h_red] at this; exact this + have h_fe_abs : (fe.val.natAbs : Int) ≤ ((2 : Int)^15) := by + rw [← Int.abs_eq_natAbs]; exact abs_le.mpr ⟨h_fe_lb, le_of_lt h_fe_ub⟩ + rw [Int.natAbs_mul] + have h_nat_fe : fe.val.natAbs ≤ 2^15 := by exact_mod_cast h_fe_abs + exact Nat.mul_le_mul h_nat_fe hfer + -- Preconditions for L0.3: + -- weak: product.val.natAbs ≤ 3328 * 2^16 (always true here). + -- tight: product.val.natAbs ≤ 3328 * 2^15 (used to extract the |r| ≤ 3328 bound). + have h_pre_weak : product.val.natAbs ≤ 3328 * 2^16 := by + have h_step : (2^15 * 1664 : Nat) ≤ (3328 * 2^16 : Nat) := by decide + exact le_trans h_product_natAbs h_step + have h_pre_tight : product.val.natAbs ≤ 3328 * 2^15 := by + have h_step : (2^15 * 1664 : Nat) ≤ (3328 * 2^15 : Nat) := by decide + exact le_trans h_product_natAbs h_step + -- Extract L0.3's conclusion via `triple_exists_ok_l0` (depending only on L0.3's + -- public `@[spec]`, never reaching into L0.3 privates). + obtain ⟨r0, h_eq_ok, _h_weak, h_cond, h_modq_new⟩ := + triple_exists_ok_l0 (montgomery_reduce_element_spec product h_pre_weak) + -- The L0.4 impl reduces to .ok r0; close via triple_of_ok_l0. + apply triple_of_ok_l0 (v := r0) (by rw [mmfbf_eq_ok]; exact h_eq_ok) + refine ⟨?_, ?_⟩ + · -- Tight bound: feed the antecedent to L0.3's conditional post. + exact h_cond h_pre_tight + · -- modq_new: rewrite product.val to fe.val * fer.val. + rw [← h_product_val]; exact h_modq_new + +end libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement +/-! ### Extracted from FCTargets.lean (§vector_arith_lo). -/ + +namespace libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement +open libcrux_iot_ml_kem.Spec.Lift +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec + +/-! ## §L0 — FE scalar primitives (4 theorems). + + Each post pairs the existing bounds conjunct (load-bearing for + callers) with the FC equation against the spec-level pure op. -/ + +/-- The Triple `⦃True⦄ x ⦃⇓ r => ⌜P r⌝⦄` closer for `x = .ok v`. + Lifts a pure-Prop fact about the value into a Triple post. + Mirror of SKILL §13.5 helper, scoped to this file. -/ +theorem triple_of_ok_fc {α : Type} {x : Result α} {v : α} + {P : α → Prop} (hx : x = .ok v) (hp : P v) : + ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄ := by + subst hx; simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +/-- Extract the `.ok` witness from a true-pre Triple. + Mirror of SKILL §13.5 helper, scoped to this file. -/ +theorem triple_exists_ok_fc {α : Type} {x : Result α} {P : α → Prop} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) : + ∃ v, x = .ok v ∧ P v := by + match hx : x with + | .ok v => exact ⟨v, rfl, (by subst hx; simpa [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h)⟩ + | .fail _ => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +/-- `.val`-preserving `Std.Usize` add helper, scoped to this file. + Mirrors `libcrux_iot_ml_kem.Polynomial.NttDrivers.usize_add_ok_eq` + (private to L3_NTTDrivers). -/ +theorem usize_add_ok_eq_fc (x y : Std.Usize) + (h_max : x.val + y.val ≤ Std.Usize.max) : + ∃ z : Std.Usize, (x + y : Result Std.Usize) = .ok z ∧ z.val = x.val + y.val := by + have hT := Std.Usize.add_spec h_max + obtain ⟨z, h_eq, h_v⟩ := Std.WP.spec_imp_exists hT + exact ⟨z, h_eq, h_v⟩ + +/-- `.val`-preserving `Std.Usize` mul helper. -/ +theorem usize_mul_ok_eq_fc (x y : Std.Usize) + (h_max : x.val * y.val ≤ Std.Usize.max) : + ∃ z : Std.Usize, (x * y : Result Std.Usize) = .ok z ∧ z.val = x.val * y.val := by + have hT := Std.Usize.mul_spec h_max + obtain ⟨z, h_eq, h_v⟩ := Std.WP.spec_imp_exists hT + exact ⟨z, h_eq, h_v⟩ + +/-! ### L0.1 — `get_n_least_significant_bits`. + Impl computes `value & ((1 <<< n) - 1)`; the spec + `Spec.get_n_least_significant_bits_pure` is precisely that BV-mask + expression (see §0.5 above). The post-shape is `bounds ∧ r = spec`. + + Proof sketch: + 1. Pure-projection side lemma `get_n_least_significant_bits_eq_ok_fc` + reduces the impl `do`-block to `.ok (Spec.<…>_pure n value)` by + `unfold ; simp only [shift_left_lemmas, wrapping_sub_u32, bind_tc_ok] ; rfl`. + The precondition `n.val ≤ 16` discharges the `n < 32` shift bound. + 2. Apply `triple_of_ok_fc` with the side lemma to discharge the + monadic shell. + 3. The FC equality is `rfl` (the spec body IS the mask expression). + 4. The bound `r.val < 2^n.val` reduces to + `(value.bv &&& mask).toNat < 2^n.val` via `BitVec.toNat_and` + + `libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks.mask_pow2_minus_one_toNat` + `Nat.and_le_right` + `omega`. -/ + +/-- Pure-projection side lemma for `get_n_least_significant_bits`. + Pins the impl's `.ok` value to `Spec.get_n_least_significant_bits_pure`. -/ +theorem get_n_least_significant_bits_eq_ok_fc + (n : Std.U8) (value : Std.U32) (hn : n.val ≤ 16) : + libcrux_iot_ml_kem.vector.portable.arithmetic.get_n_least_significant_bits n value + = .ok (Spec.get_n_least_significant_bits_pure n value) := by + unfold libcrux_iot_ml_kem.vector.portable.arithmetic.get_n_least_significant_bits + unfold Spec.get_n_least_significant_bits_pure + have hn_lt : n.val < Aeneas.Std.UScalarTy.U32.numBits := by + have h_red : (Aeneas.Std.UScalarTy.U32.numBits : Nat) = 32 := by decide + rw [h_red]; omega + simp only [HShiftLeft.hShiftLeft, Aeneas.Std.UScalar.shiftLeft_UScalar, + Aeneas.Std.UScalar.shiftLeft, hn_lt, reduceIte, + CoreModels.core.num.U32.wrapping_sub, + rust_primitives.arithmetic.wrapping_sub_u32, + Aeneas.Std.bind_tc_ok] + rfl + +/-- L0.1 — `get_n_least_significant_bits`. + Spec: bitwise AND with `(1 << n) - 1`. -/ +@[spec high] +theorem get_n_least_significant_bits_fc + (n : Std.U8) (value : Std.U32) (hn : n.val ≤ 16) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.get_n_least_significant_bits n value + ⦃ ⇓ r => ⌜ r.val < 2 ^ n.val + ∧ r = Spec.get_n_least_significant_bits_pure n value ⌝ ⦄ := by + apply triple_of_ok_fc (v := Spec.get_n_least_significant_bits_pure n value) + (get_n_least_significant_bits_eq_ok_fc n value hn) + refine ⟨?_, rfl⟩ + unfold Spec.get_n_least_significant_bits_pure + show (value.bv &&& ((1#32 <<< n.val) - 1#32)).toNat < 2 ^ n.val + rw [BitVec.toNat_and] + have h_mask_toNat : ((1#32 <<< n.val) - 1#32).toNat = 2 ^ n.val - 1 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks.mask_pow2_minus_one_toNat n.val hn + rw [h_mask_toNat] + have h_and_le : value.bv.toNat &&& (2 ^ n.val - 1) ≤ 2 ^ n.val - 1 := Nat.and_le_right + have h_pos : 0 < (2 : Nat) ^ n.val := Nat.two_pow_pos _ + omega + +/-! ### L0.2 — `barrett_reduce_element`. + + Proof sketch: + 1. `Spec.barrett_pure` is defined as the canonical round-trip + `feOfZMod ∘ zmodOfFE`. Helper `barrett_pure_lift_fe` shows that on + `lift_fe`-image FEs (which are canonical by construction) this is + the identity, so `Spec.barrett_pure (lift_fe value) = lift_fe value`. + 2. The legacy `libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement.barrett_reduce_element_spec` (bounds-only) + gives `modq_eq r.val value.val 3329 ∧ r.val.natAbs ≤ 3328`. We + consume it via `triple_exists_ok_fc`; we only need its content, + not its `@[spec]` registration. + 3. `modq_eq_cast_zmod` translates `modq_eq r.val value.val 3329` to + `(r.val : ZMod 3329) = (value.val : ZMod 3329)` via + `ZMod.intCast_zmod_eq_zero_iff_dvd`. + 4. Conclude `lift_fe r = lift_fe value` by `congr 1`. -/ + +/-- The canonical round-trip is the identity on lift_fe images. -/ +theorem barrett_pure_lift_fe (x : Std.I16) : + Spec.barrett_pure (lift_fe x) = lift_fe x := by + unfold Spec.barrett_pure lift_fe + congr 1 + exact zmodOfFE_feOfZMod _ + +/-- Cast `modq_eq` into a `ZMod 3329` equality. The barrier-side + `libcrux_iot_ml_kem.Spec.ModularArith.modq_eq` unfolds to `(a - b) % 3329 = 0`; via + `ZMod.intCast_zmod_eq_zero_iff_dvd` and `push_cast` this becomes + `(a : ZMod 3329) - (b : ZMod 3329) = 0`. -/ +theorem modq_eq_cast_zmod (a b : Int) + (h : libcrux_iot_ml_kem.Spec.ModularArith.modq_eq a b 3329) : + (a : ZMod 3329) = (b : ZMod 3329) := by + unfold libcrux_iot_ml_kem.Spec.ModularArith.modq_eq at h + have hdvd : (3329 : Int) ∣ (a - b) := Int.dvd_of_emod_eq_zero h + have hzero : ((a - b : Int) : ZMod 3329) = 0 := + (ZMod.intCast_zmod_eq_zero_iff_dvd (a - b) 3329).mpr (by exact_mod_cast hdvd) + push_cast at hzero + exact sub_eq_zero.mp hzero + +/-- Bridge lemma: `lift_fe a = lift_fe b` from `modq_eq a.val b.val 3329`. + Since `lift_fe x = feOfZMod ((x.val : Int) : ZMod 3329)`, the equality + reduces (via `congr 1`) to the `ZMod 3329` cast equality delivered by + `modq_eq_cast_zmod`. Pure-projection side lemma. -/ +theorem lift_fe_eq_of_modq (a b : Std.I16) + (h : libcrux_iot_ml_kem.Spec.ModularArith.modq_eq a.val b.val 3329) : + lift_fe a = lift_fe b := by + unfold lift_fe i16_to_spec_fe_plain + congr 1 + exact modq_eq_cast_zmod _ _ h + +/-- L0.2 — `barrett_reduce_element`. + Spec: canonical residue mod q via `FieldElement.new (x % q)`. -/ +@[spec high] +theorem barrett_reduce_element_fc + (value : Std.I16) (hb : value.val.natAbs ≤ 32767) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.barrett_reduce_element value + ⦃ ⇓ r => ⌜ r.val.natAbs ≤ 3328 + ∧ lift_fe r = Spec.barrett_pure (lift_fe value) ⌝ ⦄ := by + have h_legacy := libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement.barrett_reduce_element_spec value hb + obtain ⟨r0, h_eq, h_modq, h_bnd⟩ := triple_exists_ok_fc h_legacy + apply triple_of_ok_fc (v := r0) h_eq + refine ⟨h_bnd, ?_⟩ + rw [barrett_pure_lift_fe] + unfold lift_fe + congr 1 + show (r0.val : ZMod 3329) = (value.val : ZMod 3329) + exact modq_eq_cast_zmod _ _ h_modq + +/-! ### L0.3 — `montgomery_reduce_element`. + + Proof sketch: + 1. `Spec.mont_reduce_pure x := feOfZMod (zmodOfFE x · 169 · 169)`. + Helper `mont_reduce_pure_lift_fe_int` unfolds this composed with + `lift_fe_int v` to `feOfZMod ((v : ZMod 3329) · 169 · 169)`. + 2. Legacy `libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement.montgomery_reduce_element_spec` gives + `r.val.natAbs ≤ 3328 + 1665 ∧ (tight-bound conditional) + ∧ modq_eq r.val (value.val * 169) 3329`. We extract via + `triple_exists_ok_fc` and drop the tight-bound conditional clause. + 3. Translate `modq_eq r.val (value.val * 169) 3329` to a ZMod equality + `(r.val : ZMod 3329) = (value.val * 169 : ZMod 3329)` via + `modq_eq_cast_zmod`. + 4. Unfold `lift_fe_mont` and `i16_to_spec_fe_mont`, then `congr 1` + reduces the goal to a ZMod equation closed by the step-3 hypothesis + plus `push_cast`. -/ + +/-- Helper: `mont_reduce_pure` composed with `lift_fe_int` simplifies. -/ +theorem mont_reduce_pure_lift_fe_int (v : Int) : + Spec.mont_reduce_pure (lift_fe_int v) = feOfZMod ((v : ZMod 3329) * 169 * 169) := by + unfold Spec.mont_reduce_pure lift_fe_int + rw [zmodOfFE_feOfZMod] + +/-- L0.3 — `montgomery_reduce_element`. + Spec: strip TWO Mont factors (the impl's R⁻¹ + the `lift_fe_mont` + R-stripping). See the `Spec.mont_reduce_pure` docstring for the + derivation of `· 169²`. -/ +@[spec high] +theorem montgomery_reduce_element_fc + (value : Std.I32) (hv : value.val.natAbs ≤ 2^16 * 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_reduce_element value + ⦃ ⇓ r => ⌜ r.val.natAbs ≤ 3328 + 1665 + ∧ lift_fe_mont r = Spec.mont_reduce_pure (lift_fe_int value.val) ⌝ ⦄ := by + have hv' : value.val.natAbs ≤ 3328 * 2^16 := by + have h_eq : (3328 * 2^16 : Nat) = 2^16 * 3328 := by decide + rw [h_eq]; exact hv + have h_legacy := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement.montgomery_reduce_element_spec value hv' + obtain ⟨r0, h_eq, h_bnd, _h_tight, h_modq⟩ := triple_exists_ok_fc h_legacy + apply triple_of_ok_fc (v := r0) h_eq + refine ⟨h_bnd, ?_⟩ + rw [mont_reduce_pure_lift_fe_int] + unfold lift_fe_mont i16_to_spec_fe_mont + congr 1 + have h_zmod : ((r0.val : Int) : ZMod 3329) = ((value.val * 169 : Int) : ZMod 3329) := + modq_eq_cast_zmod _ _ h_modq + push_cast at h_zmod + rw [h_zmod] + +/-! ### L0.4 — `montgomery_multiply_fe_by_fer`. + + Proof sketch: + 1. Helper `mmfbf_pure_lift_fe_lift_fe_mont` unfolds + `Spec.montgomery_multiply_fe_by_fer_pure (lift_fe fe) (lift_fe_mont fer)` + to `feOfZMod ((fe.val : ZMod 3329) * ((fer.val : ZMod 3329) * 169) * 169)` + via `zmodOfFE_feOfZMod` (applied twice). + 2. Legacy `libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement.montgomery_multiply_fe_by_fer_spec` gives + `r.val.natAbs ≤ 3328 ∧ modq_eq r.val (fe.val * fer.val * 169) 3329`. + Note the legacy bound is TIGHTER than our locked post (3328 vs + 3328 + 1665), so the bound conjunct closes by transitivity: + `exact le_trans h_bnd_tight (by decide)`. + 3. Translate `modq_eq` to a ZMod equation via `modq_eq_cast_zmod`. + 4. Unfold `lift_fe_mont`/`i16_to_spec_fe_mont`, `congr 1` reduces to a + ZMod equation closed by the modq cast + `ring`. -/ + +/-- Helper: `Spec.montgomery_multiply_fe_by_fer_pure` composed with the + lifts simplifies via `zmodOfFE_feOfZMod`. -/ +theorem mmfbf_pure_lift_fe_lift_fe_mont (fe fer : Std.I16) : + Spec.montgomery_multiply_fe_by_fer_pure (lift_fe fe) (lift_fe_mont fer) + = feOfZMod ((fe.val : ZMod 3329) * ((fer.val : ZMod 3329) * 169) * 169) := by + unfold Spec.montgomery_multiply_fe_by_fer_pure lift_fe lift_fe_mont + i16_to_spec_fe_plain i16_to_spec_fe_mont + rw [zmodOfFE_feOfZMod, zmodOfFE_feOfZMod] + +/-- L0.4 — `montgomery_multiply_fe_by_fer`. + Spec: `fe · c` (where `fer = c · R`), encoded via `· R⁻¹` in canonical. -/ +@[spec high] +theorem montgomery_multiply_fe_by_fer_fc + (fe fer : Std.I16) + (hfe : fe.val.natAbs ≤ 32767) (hfer : fer.val.natAbs ≤ 1664) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.arithmetic.montgomery_multiply_fe_by_fer fe fer + ⦃ ⇓ r => ⌜ r.val.natAbs ≤ 3328 + 1665 + ∧ lift_fe_mont r + = Spec.montgomery_multiply_fe_by_fer_pure + (lift_fe fe) (lift_fe_mont fer) ⌝ ⦄ := by + have h_legacy := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement.montgomery_multiply_fe_by_fer_spec fe fer hfer + obtain ⟨r0, h_eq, h_bnd_tight, h_modq⟩ := triple_exists_ok_fc h_legacy + apply triple_of_ok_fc (v := r0) h_eq + refine ⟨?_, ?_⟩ + · -- Weaken legacy ≤ 3328 to locked-post ≤ 3328 + 1665. + exact le_trans h_bnd_tight (by decide) + · rw [mmfbf_pure_lift_fe_lift_fe_mont] + unfold lift_fe_mont i16_to_spec_fe_mont + congr 1 + have h_zmod : ((r0.val : Int) : ZMod 3329) + = ((fe.val * fer.val * 169 : Int) : ZMod 3329) := + modq_eq_cast_zmod _ _ h_modq + push_cast at h_zmod + rw [h_zmod] + ring + + +end libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement \ No newline at end of file diff --git a/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Ntt.lean b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Ntt.lean new file mode 100644 index 00000000..aa36c78d --- /dev/null +++ b/libcrux-iot/ml-kem/proofs/aeneas-lean/LibcruxIotMlKem/Vector/Portable/Ntt.lean @@ -0,0 +1,5701 @@ +/- + # `Equivalence/L2_NTTSteps.lean` — Layer 2 intra-PortableVector NTT step Triples. + + L2.x Triples for the per-PortableVector NTT butterflies in + `vector/portable/ntt.rs`. L2.1 `ntt_step_spec` is the single + Cooley-Tukey butterfly inside a PortableVector at indices `(i, j)` + with ζ; post is bound-only (unchanged lanes + 3·3328 → 4·3328 + propagation). The modular congruence content (`r[i] ≡ a + ζb`, + `r[j] ≡ a - ζb` mod 3329) lives in `BitMlKem.Commute`. +-/ +import LibcruxIotMlKem.Extraction.Funs +import LibcruxIotMlKem.Vector.Portable.Arithmetic.LoopHelper +import LibcruxIotMlKem.Vector.Portable.Arithmetic.PerElement +import LibcruxIotMlKem.Spec.Lift +import LibcruxIotMlKem.Vector.Portable.Arithmetic.Element + +set_option mvcgen.warning false +set_option linter.unusedVariables false +set_option linter.unusedSectionVars false + +namespace libcrux_iot_ml_kem.Vector.Portable.Ntt +open libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec.ModularArith libcrux_iot_ml_kem.Spec.Montgomery libcrux_iot_ml_kem.Spec.NumericKeystones libcrux_iot_ml_kem.Util.CreateI libcrux_iot_ml_kem.Util.LoopSpecs libcrux_iot_ml_kem.Util.SliceSpecs libcrux_iot_ml_kem.Vector.Portable.Arithmetic.BvMasks libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper + +/-! ## Local helpers — Triple ↔ Result.ok bridges. -/ + +/-- The Triple `⦃True⦄ x ⦃⇓ r => ⌜P r⌝⦄` closer for `x = .ok v`. -/ +private theorem triple_of_ok_l2 {α : Type} {x : Result α} {v : α} + {P : α → Prop} (hx : x = .ok v) (hp : P v) : + ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄ := by + subst hx; simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply, hp] + +/-- Extract the `.ok` witness from a true-pre Triple. Used by L2.1 to + consume L0.4's `@[spec]`. -/ +private theorem triple_exists_ok_l2 {α : Type} {x : Result α} {P : α → Prop} + (h : ⦃ ⌜ True ⌝ ⦄ x ⦃ ⇓ r => ⌜ P r ⌝ ⦄) : + ∃ v, x = .ok v ∧ P v := by + match hx : x with + | .ok v => exact ⟨v, rfl, (by subst hx; simpa [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply] using h)⟩ + | .fail _ => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + | .div => exact absurd h (by simp [Std.Do.Triple, WP.wp, PostCond.noThrow, PredTrans.apply]) + +/-! ## L2.1 — `ntt_step_spec` + + The impl is a straight chain of 8 binds: + 1. read `vec[j]` + 2. classify ζ (identity) + 3. L0.4 montgomery_multiply: `t = mont_mult(vec[j], ζ)` + 4. read `vec[i]` + 5. `a_minus_t = vec[i] - t` (wrapping_sub) + 6. `a_plus_t = vec[i] + t` (wrapping_add) + 7. write `vec[j] := a_minus_t` + 8. write `vec[i] := a_plus_t` + + Under the L0.4 bound (`|t.val| ≤ 3328`), and `|vec[i]|, |vec[j]| + ≤ 3·3328`, the two sums `vec[i] ± t` satisfy `| · | ≤ 4·3328 = + 13312 < 2^15 = 32768`, so the wrappings are the identity and + `(vec[i] ± t).val = vec[i].val ± t.val`. -/ + +/-- Reduction of `core.num.I16.wrapping_sub` to the underlying + Aeneas `Std.I16.wrapping_sub`. -/ +private theorem cm_wrapping_sub_ok_eq (x y : Std.I16) : + CoreModels.core.num.I16.wrapping_sub x y = .ok (Std.I16.wrapping_sub x y) := by + unfold CoreModels.core.num.I16.wrapping_sub + unfold rust_primitives.arithmetic.wrapping_sub_i16 + rfl + +/-- Reduction of `core.num.I16.wrapping_add` to the underlying + Aeneas `Std.I16.wrapping_add`. -/ +private theorem cm_wrapping_add_ok_eq (x y : Std.I16) : + CoreModels.core.num.I16.wrapping_add x y = .ok (Std.I16.wrapping_add x y) := by + unfold CoreModels.core.num.I16.wrapping_add + unfold rust_primitives.arithmetic.wrapping_add_i16 + rfl + +/-- Reduction of `classify` to identity. -/ +private theorem classify_ok_eq {T : Type} (x : T) : + libcrux_secrets.traits.Classify.Blanket.classify x = .ok x := rfl + +/-- Under `|a.val| ≤ B·3328`, `|t.val| ≤ 3328`, and `B ≤ 8`, the I16-wrapped + sum `a + t` has `.val = a.val + t.val` and `.val.natAbs ≤ (B+1)·3328`. -/ +private theorem add_no_overflow_value_B (a t : Std.I16) (B : Nat) + (h_a : a.val.natAbs ≤ B * 3328) (h_t : t.val.natAbs ≤ 3328) (h_B : B ≤ 8) : + (Std.I16.wrapping_add a t).val = a.val + t.val + ∧ (Std.I16.wrapping_add a t).val.natAbs ≤ (B + 1) * 3328 := by + -- |a + t| ≤ |a| + |t| ≤ B·3328 + 3328 = (B+1)·3328 ≤ 9·3328 = 29952 < 2^15 = 32768. + have h_sum_abs : ((a.val + t.val : Int)).natAbs ≤ (B + 1) * 3328 := by + have h_tri : (a.val + t.val).natAbs ≤ a.val.natAbs + t.val.natAbs := Int.natAbs_add_le _ _ + omega + -- No-overflow ⇒ bmod is identity. + have h_lb : -(2 ^ 15 : Int) ≤ a.val + t.val := by + have h_natAbs : ((a.val + t.val : Int)).natAbs ≤ (B + 1) * 3328 := h_sum_abs + have h_bound : (B + 1) * 3328 ≤ 9 * 3328 := by + have : B + 1 ≤ 9 := by omega + exact Nat.mul_le_mul_right _ this + omega + have h_ub : a.val + t.val < (2 ^ 15 : Int) := by + have h_natAbs : ((a.val + t.val : Int)).natAbs ≤ (B + 1) * 3328 := h_sum_abs + have h_bound : (B + 1) * 3328 ≤ 9 * 3328 := by + have : B + 1 ≤ 9 := by omega + exact Nat.mul_le_mul_right _ this + omega + have h_bmod : Int.bmod (a.val + t.val) (2 ^ 16) = a.val + t.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + have h_val := Std.I16.wrapping_add_val_eq a t + refine ⟨?_, ?_⟩ + · rw [h_val, h_bmod] + · rw [h_val, h_bmod]; exact h_sum_abs + +/-- Under `|a.val| ≤ B·3328`, `|t.val| ≤ 3328`, and `B ≤ 8`, the I16-wrapped + diff `a - t` has `.val = a.val - t.val` and `.val.natAbs ≤ (B+1)·3328`. -/ +private theorem sub_no_overflow_value_B (a t : Std.I16) (B : Nat) + (h_a : a.val.natAbs ≤ B * 3328) (h_t : t.val.natAbs ≤ 3328) (h_B : B ≤ 8) : + (Std.I16.wrapping_sub a t).val = a.val - t.val + ∧ (Std.I16.wrapping_sub a t).val.natAbs ≤ (B + 1) * 3328 := by + have h_diff_abs : ((a.val - t.val : Int)).natAbs ≤ (B + 1) * 3328 := by + have h_neg_natAbs : (-t.val).natAbs = t.val.natAbs := Int.natAbs_neg _ + have h_eq : a.val - t.val = a.val + (-t.val) := by ring + rw [h_eq] + have h_tri : (a.val + (-t.val)).natAbs ≤ a.val.natAbs + (-t.val).natAbs := + Int.natAbs_add_le _ _ + rw [h_neg_natAbs] at h_tri + omega + have h_lb : -(2 ^ 15 : Int) ≤ a.val - t.val := by + have h_natAbs : ((a.val - t.val : Int)).natAbs ≤ (B + 1) * 3328 := h_diff_abs + have h_bound : (B + 1) * 3328 ≤ 9 * 3328 := by + have : B + 1 ≤ 9 := by omega + exact Nat.mul_le_mul_right _ this + omega + have h_ub : a.val - t.val < (2 ^ 15 : Int) := by + have h_natAbs : ((a.val - t.val : Int)).natAbs ≤ (B + 1) * 3328 := h_diff_abs + have h_bound : (B + 1) * 3328 ≤ 9 * 3328 := by + have : B + 1 ≤ 9 := by omega + exact Nat.mul_le_mul_right _ this + omega + have h_bmod : Int.bmod (a.val - t.val) (2 ^ 16) = a.val - t.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + have h_val := Std.I16.wrapping_sub_val_eq a t + refine ⟨?_, ?_⟩ + · rw [h_val, h_bmod] + · rw [h_val, h_bmod]; exact h_diff_abs + +/-- Specialised form (B = 3) — preserves the original signature for callers. -/ +private theorem add_no_overflow_value (a t : Std.I16) + (h_a : a.val.natAbs ≤ 3 * 3328) (h_t : t.val.natAbs ≤ 3328) : + (Std.I16.wrapping_add a t).val = a.val + t.val + ∧ (Std.I16.wrapping_add a t).val.natAbs ≤ 4 * 3328 := + add_no_overflow_value_B a t 3 h_a h_t (by decide) + +/-- Specialised form (B = 3) — preserves the original signature for callers. -/ +private theorem sub_no_overflow_value (a t : Std.I16) + (h_a : a.val.natAbs ≤ 3 * 3328) (h_t : t.val.natAbs ≤ 3328) : + (Std.I16.wrapping_sub a t).val = a.val - t.val + ∧ (Std.I16.wrapping_sub a t).val.natAbs ≤ 4 * 3328 := + sub_no_overflow_value_B a t 3 h_a h_t (by decide) + +/-! ### Truly bnd-parameterised no-overflow lemmas. + + `add_no_overflow_value_B` / `sub_no_overflow_value_B` are convenient when + the lane bound has the multiplicative shape `B * 3328`. The L3.{1,2,3}_B + parameterisation upstream needs an arbitrary `Nat` bound. The + `_bnd` variants below replace `(B + 1) * 3328` with `bnd + 3328` and the + `B + 1 ≤ 9` bridge with `bnd + 3328 ≤ 32767` directly (equivalent to + `bnd ≤ 29439`). Numerically `29439 + 3328 = 32767 < 2^15`, so the I16 + wrapping is the identity. -/ + +/-- Under `|a.val| ≤ bnd`, `|t.val| ≤ 3328`, and `bnd ≤ 29439`, the I16-wrapped + sum `a + t` has `.val = a.val + t.val` and `.val.natAbs ≤ bnd + 3328`. -/ +private theorem add_no_overflow_value_bnd (a t : Std.I16) (bnd : Nat) + (h_a : a.val.natAbs ≤ bnd) (h_t : t.val.natAbs ≤ 3328) (h_bnd : bnd ≤ 29439) : + (Std.I16.wrapping_add a t).val = a.val + t.val + ∧ (Std.I16.wrapping_add a t).val.natAbs ≤ bnd + 3328 := by + -- |a + t| ≤ |a| + |t| ≤ bnd + 3328 ≤ 29439 + 3328 = 32767 < 2^15. + have h_sum_abs : ((a.val + t.val : Int)).natAbs ≤ bnd + 3328 := by + have h_tri : (a.val + t.val).natAbs ≤ a.val.natAbs + t.val.natAbs := Int.natAbs_add_le _ _ + omega + -- No-overflow ⇒ bmod is identity. + have h_lb : -(2 ^ 15 : Int) ≤ a.val + t.val := by + have h_natAbs : ((a.val + t.val : Int)).natAbs ≤ bnd + 3328 := h_sum_abs + have h_bound : bnd + 3328 ≤ 32767 := by omega + omega + have h_ub : a.val + t.val < (2 ^ 15 : Int) := by + have h_natAbs : ((a.val + t.val : Int)).natAbs ≤ bnd + 3328 := h_sum_abs + have h_bound : bnd + 3328 ≤ 32767 := by omega + omega + have h_bmod : Int.bmod (a.val + t.val) (2 ^ 16) = a.val + t.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + have h_val := Std.I16.wrapping_add_val_eq a t + refine ⟨?_, ?_⟩ + · rw [h_val, h_bmod] + · rw [h_val, h_bmod]; exact h_sum_abs + +/-- Under `|a.val| ≤ bnd`, `|t.val| ≤ 3328`, and `bnd ≤ 29439`, the I16-wrapped + diff `a - t` has `.val = a.val - t.val` and `.val.natAbs ≤ bnd + 3328`. -/ +private theorem sub_no_overflow_value_bnd (a t : Std.I16) (bnd : Nat) + (h_a : a.val.natAbs ≤ bnd) (h_t : t.val.natAbs ≤ 3328) (h_bnd : bnd ≤ 29439) : + (Std.I16.wrapping_sub a t).val = a.val - t.val + ∧ (Std.I16.wrapping_sub a t).val.natAbs ≤ bnd + 3328 := by + have h_diff_abs : ((a.val - t.val : Int)).natAbs ≤ bnd + 3328 := by + have h_neg_natAbs : (-t.val).natAbs = t.val.natAbs := Int.natAbs_neg _ + have h_eq : a.val - t.val = a.val + (-t.val) := by ring + rw [h_eq] + have h_tri : (a.val + (-t.val)).natAbs ≤ a.val.natAbs + (-t.val).natAbs := + Int.natAbs_add_le _ _ + rw [h_neg_natAbs] at h_tri + omega + have h_lb : -(2 ^ 15 : Int) ≤ a.val - t.val := by + have h_natAbs : ((a.val - t.val : Int)).natAbs ≤ bnd + 3328 := h_diff_abs + have h_bound : bnd + 3328 ≤ 32767 := by omega + omega + have h_ub : a.val - t.val < (2 ^ 15 : Int) := by + have h_natAbs : ((a.val - t.val : Int)).natAbs ≤ bnd + 3328 := h_diff_abs + have h_bound : bnd + 3328 ≤ 32767 := by omega + omega + have h_bmod : Int.bmod (a.val - t.val) (2 ^ 16) = a.val - t.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + have h_val := Std.I16.wrapping_sub_val_eq a t + refine ⟨?_, ?_⟩ + · rw [h_val, h_bmod] + · rw [h_val, h_bmod]; exact h_diff_abs + +@[spec] +theorem ntt_step_spec + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta : Std.I16) (i j : Std.Usize) + (h_i : i.val < 16) (h_j : j.val < 16) (h_ne : i.val ≠ j.val) + (h_zeta : zeta.val.natAbs ≤ 1664) + (h_a : (vec.elements.val[i.val]!).val.natAbs ≤ 3 * 3328) + (h_b : (vec.elements.val[j.val]!).val.natAbs ≤ 3 * 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.ntt_step vec zeta i j + ⦃ ⇓ r => ⌜ (∀ k : Nat, k < 16 → k ≠ i.val → k ≠ j.val → + r.elements.val[k]! = vec.elements.val[k]!) + ∧ (r.elements.val[i.val]!).val.natAbs ≤ 4 * 3328 + ∧ (r.elements.val[j.val]!).val.natAbs ≤ 4 * 3328 ⌝ ⦄ := by + have h_vec_len : vec.elements.length = 16 := PortableVector_elements_length vec + -- Step 1: read vec[j]. + have h_idx_j : + Aeneas.Std.Array.index_usize vec.elements j = .ok (vec.elements.val[j.val]!) := + array_index_usize_ok_eq vec.elements j (by rw [h_vec_len]; exact h_j) + -- Step 2: classify ζ. + have h_classify : libcrux_secrets.traits.Classify.Blanket.classify zeta = .ok zeta := + classify_ok_eq zeta + -- Step 3: L0.4 montgomery_multiply on (vec[j], ζ). + set b : Std.I16 := vec.elements.val[j.val]! with hb_def + obtain ⟨t, h_t_eq_ok, h_t_bd, _h_t_mod⟩ := + triple_exists_ok_l2 (montgomery_multiply_fe_by_fer_spec b zeta h_zeta) + -- Step 4: read vec[i]. + have h_idx_i : + Aeneas.Std.Array.index_usize vec.elements i = .ok (vec.elements.val[i.val]!) := + array_index_usize_ok_eq vec.elements i (by rw [h_vec_len]; exact h_i) + set a : Std.I16 := vec.elements.val[i.val]! with ha_def + -- Step 5: wrapping_sub a t. + have h_sub_eq : CoreModels.core.num.I16.wrapping_sub a t = .ok (Std.I16.wrapping_sub a t) := + cm_wrapping_sub_ok_eq a t + -- Step 6: wrapping_add a t. + have h_add_eq : CoreModels.core.num.I16.wrapping_add a t = .ok (Std.I16.wrapping_add a t) := + cm_wrapping_add_ok_eq a t + set a_minus_t : Std.I16 := Std.I16.wrapping_sub a t with hamt_def + set a_plus_t : Std.I16 := Std.I16.wrapping_add a t with hapt_def + -- Step 7: update vec at index j with a_minus_t. + have h_upd_j : + Aeneas.Std.Array.update vec.elements j a_minus_t + = .ok (vec.elements.set j a_minus_t) := + array_update_ok_eq vec.elements j a_minus_t (by rw [h_vec_len]; exact h_j) + -- Step 8: update at index i with a_plus_t. + have h_upd_i : + Aeneas.Std.Array.update (vec.elements.set j a_minus_t) i a_plus_t + = .ok ((vec.elements.set j a_minus_t).set i a_plus_t) := by + have h_len : (vec.elements.set j a_minus_t).length = 16 := by + rw [Std.Array.set_length]; exact h_vec_len + exact array_update_ok_eq _ i a_plus_t (by rw [h_len]; exact h_i) + -- Compose the whole do-block into one `.ok` equation. + set final_elements : Std.Array Std.I16 16#usize := + (vec.elements.set j a_minus_t).set i a_plus_t with hfe_def + set final_vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + { elements := final_elements } with hfv_def + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.ntt_step vec zeta i j = .ok final_vec := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.ntt_step + rw [h_idx_j]; simp only [bind_tc_ok] + rw [h_classify]; simp only [bind_tc_ok] + rw [h_t_eq_ok]; simp only [bind_tc_ok] + rw [h_idx_i]; simp only [bind_tc_ok] + rw [h_sub_eq]; simp only [bind_tc_ok] + rw [h_add_eq]; simp only [bind_tc_ok] + rw [h_upd_j]; simp only [bind_tc_ok] + rw [h_upd_i]; simp only [bind_tc_ok]; rfl + -- Bound facts on the two touched lanes. + have h_add := add_no_overflow_value a t h_a h_t_bd + have h_sub := sub_no_overflow_value a t h_a h_t_bd + -- Close the Triple. + apply triple_of_ok_l2 h_body + refine ⟨?_, ?_, ?_⟩ + · -- Unchanged lanes: k ≠ i.val and k ≠ j.val. + intro k hk_lt hk_ne_i hk_ne_j + have h_set_i_ne : + ((vec.elements.set j a_minus_t).set i a_plus_t)[k]! + = (vec.elements.set j a_minus_t)[k]! := + Aeneas.Std.Array.getElem!_Nat_set_ne _ i k _ (Ne.symm hk_ne_i) + have h_set_j_ne : + (vec.elements.set j a_minus_t)[k]! = (vec.elements)[k]! := + Aeneas.Std.Array.getElem!_Nat_set_ne _ j k _ (Ne.symm hk_ne_j) + show final_vec.elements.val[k]! = vec.elements.val[k]! + show ((vec.elements.set j a_minus_t).set i a_plus_t).val[k]! = vec.elements.val[k]! + -- Convert .val[k]! ↔ [k]! using getElem!_Nat_eq, then chain set_ne rewrites. + rw [← Aeneas.Std.Array.getElem!_Nat_eq, ← Aeneas.Std.Array.getElem!_Nat_eq, + h_set_i_ne, h_set_j_ne] + · -- Bound on r.elements.val[i.val]!. + show (final_vec.elements.val[i.val]!).val.natAbs ≤ 4 * 3328 + show (((vec.elements.set j a_minus_t).set i a_plus_t).val[i.val]!).val.natAbs ≤ 4 * 3328 + have h_eq1 : + ((vec.elements.set j a_minus_t).set i a_plus_t).val[i.val]! + = ((vec.elements.set j a_minus_t).set i a_plus_t)[i.val]! := by + simp [Std.Array.getElem!_Nat_eq] + have h_set_i_eq : + ((vec.elements.set j a_minus_t).set i a_plus_t)[i.val]! = a_plus_t := by + have h_len : (vec.elements.set j a_minus_t).length = 16 := by + rw [Std.Array.set_length]; exact h_vec_len + exact Aeneas.Std.Array.getElem!_Nat_set_eq _ i i.val _ ⟨rfl, by rw [h_len]; exact h_i⟩ + rw [h_eq1, h_set_i_eq] + exact h_add.2 + · -- Bound on r.elements.val[j.val]!. + show (final_vec.elements.val[j.val]!).val.natAbs ≤ 4 * 3328 + show (((vec.elements.set j a_minus_t).set i a_plus_t).val[j.val]!).val.natAbs ≤ 4 * 3328 + have h_eq1 : + ((vec.elements.set j a_minus_t).set i a_plus_t).val[j.val]! + = ((vec.elements.set j a_minus_t).set i a_plus_t)[j.val]! := by + simp [Std.Array.getElem!_Nat_eq] + have h_ne_ij : i.val ≠ j.val := h_ne + have h_set_i_ne : + ((vec.elements.set j a_minus_t).set i a_plus_t)[j.val]! + = (vec.elements.set j a_minus_t)[j.val]! := + Aeneas.Std.Array.getElem!_Nat_set_ne _ i j.val _ h_ne_ij + have h_set_j_eq : + (vec.elements.set j a_minus_t)[j.val]! = a_minus_t := by + exact Aeneas.Std.Array.getElem!_Nat_set_eq _ j j.val _ ⟨rfl, by rw [h_vec_len]; exact h_j⟩ + have h_eq2 : + (vec.elements.set j a_minus_t)[j.val]! + = (vec.elements.set j a_minus_t).val[j.val]! := by + simp [Std.Array.getElem!_Nat_eq] + rw [h_eq1, h_set_i_ne, h_set_j_eq] + exact h_sub.2 + +/-! ## Parameterised L2.1 — `ntt_step_spec_B` + + Same shape as `ntt_step_spec` but with a configurable lane bound `B ≤ 8`. + Each touched lane goes from `≤ B·3328` to `≤ (B+1)·3328`. Used by the + L2.2/L2.3/L2.4 bundled-butterfly proofs which need different inbound + bounds (7·3328 / 6·3328 / 5·3328 respectively). + + The proof body is identical to `ntt_step_spec` except that + `add_no_overflow_value` / `sub_no_overflow_value` are replaced by their + `_B` counterparts. + + `h_B : B ≤ 8` ensures no I16 overflow: `(B+1)·3328 ≤ 9·3328 = 29952 < 2^15`. +-/ + +@[spec] +theorem ntt_step_spec_B + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta : Std.I16) (i j : Std.Usize) (B : Nat) + (h_i : i.val < 16) (h_j : j.val < 16) (h_ne : i.val ≠ j.val) + (h_zeta : zeta.val.natAbs ≤ 1664) + (h_a : (vec.elements.val[i.val]!).val.natAbs ≤ B * 3328) + (h_b : (vec.elements.val[j.val]!).val.natAbs ≤ B * 3328) + (h_B : B ≤ 8) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.ntt_step vec zeta i j + ⦃ ⇓ r => ⌜ (∀ k : Nat, k < 16 → k ≠ i.val → k ≠ j.val → + r.elements.val[k]! = vec.elements.val[k]!) + ∧ (r.elements.val[i.val]!).val.natAbs ≤ (B + 1) * 3328 + ∧ (r.elements.val[j.val]!).val.natAbs ≤ (B + 1) * 3328 ⌝ ⦄ := by + have h_vec_len : vec.elements.length = 16 := PortableVector_elements_length vec + have h_idx_j : + Aeneas.Std.Array.index_usize vec.elements j = .ok (vec.elements.val[j.val]!) := + array_index_usize_ok_eq vec.elements j (by rw [h_vec_len]; exact h_j) + have h_classify : libcrux_secrets.traits.Classify.Blanket.classify zeta = .ok zeta := + classify_ok_eq zeta + set b : Std.I16 := vec.elements.val[j.val]! with hb_def + obtain ⟨t, h_t_eq_ok, h_t_bd, _h_t_mod⟩ := + triple_exists_ok_l2 (montgomery_multiply_fe_by_fer_spec b zeta h_zeta) + have h_idx_i : + Aeneas.Std.Array.index_usize vec.elements i = .ok (vec.elements.val[i.val]!) := + array_index_usize_ok_eq vec.elements i (by rw [h_vec_len]; exact h_i) + set a : Std.I16 := vec.elements.val[i.val]! with ha_def + have h_sub_eq : CoreModels.core.num.I16.wrapping_sub a t = .ok (Std.I16.wrapping_sub a t) := + cm_wrapping_sub_ok_eq a t + have h_add_eq : CoreModels.core.num.I16.wrapping_add a t = .ok (Std.I16.wrapping_add a t) := + cm_wrapping_add_ok_eq a t + set a_minus_t : Std.I16 := Std.I16.wrapping_sub a t with hamt_def + set a_plus_t : Std.I16 := Std.I16.wrapping_add a t with hapt_def + have h_upd_j : + Aeneas.Std.Array.update vec.elements j a_minus_t + = .ok (vec.elements.set j a_minus_t) := + array_update_ok_eq vec.elements j a_minus_t (by rw [h_vec_len]; exact h_j) + have h_upd_i : + Aeneas.Std.Array.update (vec.elements.set j a_minus_t) i a_plus_t + = .ok ((vec.elements.set j a_minus_t).set i a_plus_t) := by + have h_len : (vec.elements.set j a_minus_t).length = 16 := by + rw [Std.Array.set_length]; exact h_vec_len + exact array_update_ok_eq _ i a_plus_t (by rw [h_len]; exact h_i) + set final_elements : Std.Array Std.I16 16#usize := + (vec.elements.set j a_minus_t).set i a_plus_t with hfe_def + set final_vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + { elements := final_elements } with hfv_def + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.ntt_step vec zeta i j = .ok final_vec := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.ntt_step + rw [h_idx_j]; simp only [bind_tc_ok] + rw [h_classify]; simp only [bind_tc_ok] + rw [h_t_eq_ok]; simp only [bind_tc_ok] + rw [h_idx_i]; simp only [bind_tc_ok] + rw [h_sub_eq]; simp only [bind_tc_ok] + rw [h_add_eq]; simp only [bind_tc_ok] + rw [h_upd_j]; simp only [bind_tc_ok] + rw [h_upd_i]; simp only [bind_tc_ok]; rfl + have h_add := add_no_overflow_value_B a t B h_a h_t_bd h_B + have h_sub := sub_no_overflow_value_B a t B h_a h_t_bd h_B + apply triple_of_ok_l2 h_body + refine ⟨?_, ?_, ?_⟩ + · intro k hk_lt hk_ne_i hk_ne_j + have h_set_i_ne : + ((vec.elements.set j a_minus_t).set i a_plus_t)[k]! + = (vec.elements.set j a_minus_t)[k]! := + Aeneas.Std.Array.getElem!_Nat_set_ne _ i k _ (Ne.symm hk_ne_i) + have h_set_j_ne : + (vec.elements.set j a_minus_t)[k]! = (vec.elements)[k]! := + Aeneas.Std.Array.getElem!_Nat_set_ne _ j k _ (Ne.symm hk_ne_j) + show final_vec.elements.val[k]! = vec.elements.val[k]! + show ((vec.elements.set j a_minus_t).set i a_plus_t).val[k]! = vec.elements.val[k]! + rw [← Aeneas.Std.Array.getElem!_Nat_eq, ← Aeneas.Std.Array.getElem!_Nat_eq, + h_set_i_ne, h_set_j_ne] + · show (final_vec.elements.val[i.val]!).val.natAbs ≤ (B + 1) * 3328 + show (((vec.elements.set j a_minus_t).set i a_plus_t).val[i.val]!).val.natAbs ≤ (B + 1) * 3328 + have h_eq1 : + ((vec.elements.set j a_minus_t).set i a_plus_t).val[i.val]! + = ((vec.elements.set j a_minus_t).set i a_plus_t)[i.val]! := by + simp [Std.Array.getElem!_Nat_eq] + have h_set_i_eq : + ((vec.elements.set j a_minus_t).set i a_plus_t)[i.val]! = a_plus_t := by + have h_len : (vec.elements.set j a_minus_t).length = 16 := by + rw [Std.Array.set_length]; exact h_vec_len + exact Aeneas.Std.Array.getElem!_Nat_set_eq _ i i.val _ ⟨rfl, by rw [h_len]; exact h_i⟩ + rw [h_eq1, h_set_i_eq] + exact h_add.2 + · show (final_vec.elements.val[j.val]!).val.natAbs ≤ (B + 1) * 3328 + show (((vec.elements.set j a_minus_t).set i a_plus_t).val[j.val]!).val.natAbs ≤ (B + 1) * 3328 + have h_eq1 : + ((vec.elements.set j a_minus_t).set i a_plus_t).val[j.val]! + = ((vec.elements.set j a_minus_t).set i a_plus_t)[j.val]! := by + simp [Std.Array.getElem!_Nat_eq] + have h_ne_ij : i.val ≠ j.val := h_ne + have h_set_i_ne : + ((vec.elements.set j a_minus_t).set i a_plus_t)[j.val]! + = (vec.elements.set j a_minus_t)[j.val]! := + Aeneas.Std.Array.getElem!_Nat_set_ne _ i j.val _ h_ne_ij + have h_set_j_eq : + (vec.elements.set j a_minus_t)[j.val]! = a_minus_t := by + exact Aeneas.Std.Array.getElem!_Nat_set_eq _ j j.val _ ⟨rfl, by rw [h_vec_len]; exact h_j⟩ + have h_eq2 : + (vec.elements.set j a_minus_t)[j.val]! + = (vec.elements.set j a_minus_t).val[j.val]! := by + simp [Std.Array.getElem!_Nat_eq] + rw [h_eq1, h_set_i_ne, h_set_j_eq] + exact h_sub.2 + +/-! ## Truly bnd-parameterised L2.1 — `ntt_step_spec_bnd` + + Same shape as `ntt_step_spec_B` but with the lane bound stated as a raw + `Nat` `bnd` (instead of `B * 3328`). Used by the `L3.{1,2,3}_B` + parameterisations which carry an `bnd : Nat` invariant rather than a + multiplicative `B * 3328` shape. + + The proof body mirrors `ntt_step_spec_B` exactly, swapping the calls to + `add_no_overflow_value_B` / `sub_no_overflow_value_B` for their `_bnd` + counterparts and the `B ≤ 8` precondition for `bnd ≤ 29439`. + + The bound `29439 = 32767 - 3328` ensures `(bnd + 3328) ≤ 32767 < 2^15`, + so the I16 wrapping is the identity for both `a + t` and `a - t`. -/ + +@[spec] +theorem ntt_step_spec_bnd + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta : Std.I16) (i j : Std.Usize) (bnd : Nat) + (h_i : i.val < 16) (h_j : j.val < 16) (h_ne : i.val ≠ j.val) + (h_zeta : zeta.val.natAbs ≤ 1664) + (h_a : (vec.elements.val[i.val]!).val.natAbs ≤ bnd) + (h_b : (vec.elements.val[j.val]!).val.natAbs ≤ bnd) + (h_bnd : bnd ≤ 29439) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.ntt_step vec zeta i j + ⦃ ⇓ r => ⌜ (∀ k : Nat, k < 16 → k ≠ i.val → k ≠ j.val → + r.elements.val[k]! = vec.elements.val[k]!) + ∧ (r.elements.val[i.val]!).val.natAbs ≤ bnd + 3328 + ∧ (r.elements.val[j.val]!).val.natAbs ≤ bnd + 3328 ⌝ ⦄ := by + have h_vec_len : vec.elements.length = 16 := PortableVector_elements_length vec + have h_idx_j : + Aeneas.Std.Array.index_usize vec.elements j = .ok (vec.elements.val[j.val]!) := + array_index_usize_ok_eq vec.elements j (by rw [h_vec_len]; exact h_j) + have h_classify : libcrux_secrets.traits.Classify.Blanket.classify zeta = .ok zeta := + classify_ok_eq zeta + set b : Std.I16 := vec.elements.val[j.val]! with hb_def + obtain ⟨t, h_t_eq_ok, h_t_bd, _h_t_mod⟩ := + triple_exists_ok_l2 (montgomery_multiply_fe_by_fer_spec b zeta h_zeta) + have h_idx_i : + Aeneas.Std.Array.index_usize vec.elements i = .ok (vec.elements.val[i.val]!) := + array_index_usize_ok_eq vec.elements i (by rw [h_vec_len]; exact h_i) + set a : Std.I16 := vec.elements.val[i.val]! with ha_def + have h_sub_eq : CoreModels.core.num.I16.wrapping_sub a t = .ok (Std.I16.wrapping_sub a t) := + cm_wrapping_sub_ok_eq a t + have h_add_eq : CoreModels.core.num.I16.wrapping_add a t = .ok (Std.I16.wrapping_add a t) := + cm_wrapping_add_ok_eq a t + set a_minus_t : Std.I16 := Std.I16.wrapping_sub a t with hamt_def + set a_plus_t : Std.I16 := Std.I16.wrapping_add a t with hapt_def + have h_upd_j : + Aeneas.Std.Array.update vec.elements j a_minus_t + = .ok (vec.elements.set j a_minus_t) := + array_update_ok_eq vec.elements j a_minus_t (by rw [h_vec_len]; exact h_j) + have h_upd_i : + Aeneas.Std.Array.update (vec.elements.set j a_minus_t) i a_plus_t + = .ok ((vec.elements.set j a_minus_t).set i a_plus_t) := by + have h_len : (vec.elements.set j a_minus_t).length = 16 := by + rw [Std.Array.set_length]; exact h_vec_len + exact array_update_ok_eq _ i a_plus_t (by rw [h_len]; exact h_i) + set final_elements : Std.Array Std.I16 16#usize := + (vec.elements.set j a_minus_t).set i a_plus_t with hfe_def + set final_vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + { elements := final_elements } with hfv_def + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.ntt_step vec zeta i j = .ok final_vec := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.ntt_step + rw [h_idx_j]; simp only [bind_tc_ok] + rw [h_classify]; simp only [bind_tc_ok] + rw [h_t_eq_ok]; simp only [bind_tc_ok] + rw [h_idx_i]; simp only [bind_tc_ok] + rw [h_sub_eq]; simp only [bind_tc_ok] + rw [h_add_eq]; simp only [bind_tc_ok] + rw [h_upd_j]; simp only [bind_tc_ok] + rw [h_upd_i]; simp only [bind_tc_ok]; rfl + have h_add := add_no_overflow_value_bnd a t bnd h_a h_t_bd h_bnd + have h_sub := sub_no_overflow_value_bnd a t bnd h_a h_t_bd h_bnd + apply triple_of_ok_l2 h_body + refine ⟨?_, ?_, ?_⟩ + · intro k hk_lt hk_ne_i hk_ne_j + have h_set_i_ne : + ((vec.elements.set j a_minus_t).set i a_plus_t)[k]! + = (vec.elements.set j a_minus_t)[k]! := + Aeneas.Std.Array.getElem!_Nat_set_ne _ i k _ (Ne.symm hk_ne_i) + have h_set_j_ne : + (vec.elements.set j a_minus_t)[k]! = (vec.elements)[k]! := + Aeneas.Std.Array.getElem!_Nat_set_ne _ j k _ (Ne.symm hk_ne_j) + show final_vec.elements.val[k]! = vec.elements.val[k]! + show ((vec.elements.set j a_minus_t).set i a_plus_t).val[k]! = vec.elements.val[k]! + rw [← Aeneas.Std.Array.getElem!_Nat_eq, ← Aeneas.Std.Array.getElem!_Nat_eq, + h_set_i_ne, h_set_j_ne] + · show (final_vec.elements.val[i.val]!).val.natAbs ≤ bnd + 3328 + show (((vec.elements.set j a_minus_t).set i a_plus_t).val[i.val]!).val.natAbs ≤ bnd + 3328 + have h_eq1 : + ((vec.elements.set j a_minus_t).set i a_plus_t).val[i.val]! + = ((vec.elements.set j a_minus_t).set i a_plus_t)[i.val]! := by + simp [Std.Array.getElem!_Nat_eq] + have h_set_i_eq : + ((vec.elements.set j a_minus_t).set i a_plus_t)[i.val]! = a_plus_t := by + have h_len : (vec.elements.set j a_minus_t).length = 16 := by + rw [Std.Array.set_length]; exact h_vec_len + exact Aeneas.Std.Array.getElem!_Nat_set_eq _ i i.val _ ⟨rfl, by rw [h_len]; exact h_i⟩ + rw [h_eq1, h_set_i_eq] + exact h_add.2 + · show (final_vec.elements.val[j.val]!).val.natAbs ≤ bnd + 3328 + show (((vec.elements.set j a_minus_t).set i a_plus_t).val[j.val]!).val.natAbs ≤ bnd + 3328 + have h_eq1 : + ((vec.elements.set j a_minus_t).set i a_plus_t).val[j.val]! + = ((vec.elements.set j a_minus_t).set i a_plus_t)[j.val]! := by + simp [Std.Array.getElem!_Nat_eq] + have h_ne_ij : i.val ≠ j.val := h_ne + have h_set_i_ne : + ((vec.elements.set j a_minus_t).set i a_plus_t)[j.val]! + = (vec.elements.set j a_minus_t)[j.val]! := + Aeneas.Std.Array.getElem!_Nat_set_ne _ i j.val _ h_ne_ij + have h_set_j_eq : + (vec.elements.set j a_minus_t)[j.val]! = a_minus_t := by + exact Aeneas.Std.Array.getElem!_Nat_set_eq _ j j.val _ ⟨rfl, by rw [h_vec_len]; exact h_j⟩ + have h_eq2 : + (vec.elements.set j a_minus_t)[j.val]! + = (vec.elements.set j a_minus_t).val[j.val]! := by + simp [Std.Array.getElem!_Nat_eq] + rw [h_eq1, h_set_i_ne, h_set_j_eq] + exact h_sub.2 + +/-! ## L2.2 — `ntt_layer_1_step_spec` + + Eight disjoint butterflies on pairs `(0,2)ζ0`, `(1,3)ζ0`, `(4,6)ζ1`, + `(5,7)ζ1`, `(8,10)ζ2`, `(9,11)ζ2`, `(12,14)ζ3`, `(13,15)ζ3`. Each call + of `ntt_step_spec_B` with `B = 7` raises the two touched lanes from + `≤ 7·3328` to `≤ 8·3328`; pairs are pairwise disjoint and cover all + 16 lanes, so every lane is touched exactly once. + + Bookkeeping idiom: we maintain, after the k-th call, the invariant + "for every lane index ℓ ∈ [0,16), if ℓ is among the lanes touched + so far, lane_k[ℓ] ≤ 8·3328; else lane_k[ℓ] = vec[ℓ] (and so + ≤ 7·3328 ≤ 8·3328)". The post conjuncts of `ntt_step_spec_B` + immediately give us each step's contribution. +-/ + +@[spec] +theorem ntt_layer_1_step_spec + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta0 zeta1 zeta2 zeta3 : Std.I16) + (hz0 : zeta0.val.natAbs ≤ 1664) (hz1 : zeta1.val.natAbs ≤ 1664) + (hz2 : zeta2.val.natAbs ≤ 1664) (hz3 : zeta3.val.natAbs ≤ 1664) + (hpre : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 7 * 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_1_step vec zeta0 zeta1 zeta2 zeta3 + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → (r.elements.val[i]!).val.natAbs ≤ 8 * 3328 ⌝ ⦄ := by + -- Index abbreviations. + have h0lt : (0#usize : Std.Usize).val < 16 := by decide + have h1lt : (1#usize : Std.Usize).val < 16 := by decide + have h2lt : (2#usize : Std.Usize).val < 16 := by decide + have h3lt : (3#usize : Std.Usize).val < 16 := by decide + have h4lt : (4#usize : Std.Usize).val < 16 := by decide + have h5lt : (5#usize : Std.Usize).val < 16 := by decide + have h6lt : (6#usize : Std.Usize).val < 16 := by decide + have h7lt : (7#usize : Std.Usize).val < 16 := by decide + have h8lt : (8#usize : Std.Usize).val < 16 := by decide + have h9lt : (9#usize : Std.Usize).val < 16 := by decide + have h10lt : (10#usize : Std.Usize).val < 16 := by decide + have h11lt : (11#usize : Std.Usize).val < 16 := by decide + have h12lt : (12#usize : Std.Usize).val < 16 := by decide + have h13lt : (13#usize : Std.Usize).val < 16 := by decide + have h14lt : (14#usize : Std.Usize).val < 16 := by decide + have h15lt : (15#usize : Std.Usize).val < 16 := by decide + -- Initial bounds: hpre gives ≤ 7·3328 on all lanes. + have hb_init : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 7 * 3328 := hpre + -- Step 1: (0, 2) ζ0. Pair untouched ⇒ both lanes ≤ 7·3328 from hpre. + obtain ⟨v1, h_v1_eq, h_v1_unc, h_v1_i, h_v1_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B vec zeta0 0#usize 2#usize 7 + h0lt h2lt (by decide) hz0 (hb_init 0 h0lt) (hb_init 2 h2lt) (by decide)) + -- After step 1: lane 0 and lane 2 are ≤ 8·3328 (h_v1_i, h_v1_j); other lanes + -- unchanged from vec (h_v1_unc), so still ≤ 7·3328 from hpre. + -- Step 2: (1, 3) ζ0. Disjoint from {0, 2}, so lanes 1, 3 are still vec[1], vec[3] ≤ 7·3328. + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ 7 * 3328 := by + rw [h_v1_unc 1 h1lt (by decide) (by decide)]; exact hb_init 1 h1lt + have h_v1_3 : (v1.elements.val[3]!).val.natAbs ≤ 7 * 3328 := by + rw [h_v1_unc 3 h3lt (by decide) (by decide)]; exact hb_init 3 h3lt + obtain ⟨v2, h_v2_eq, h_v2_unc, h_v2_i, h_v2_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v1 zeta0 1#usize 3#usize 7 + h1lt h3lt (by decide) hz0 h_v1_1 h_v1_3 (by decide)) + -- Step 3: (4, 6) ζ1. v2[4], v2[6] not touched in steps 1,2 ⇒ = vec[4], vec[6] ≤ 7·3328. + have h_v2_4 : (v2.elements.val[4]!).val.natAbs ≤ 7 * 3328 := by + rw [h_v2_unc 4 h4lt (by decide) (by decide), h_v1_unc 4 h4lt (by decide) (by decide)] + exact hb_init 4 h4lt + have h_v2_6 : (v2.elements.val[6]!).val.natAbs ≤ 7 * 3328 := by + rw [h_v2_unc 6 h6lt (by decide) (by decide), h_v1_unc 6 h6lt (by decide) (by decide)] + exact hb_init 6 h6lt + obtain ⟨v3, h_v3_eq, h_v3_unc, h_v3_i, h_v3_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v2 zeta1 4#usize 6#usize 7 + h4lt h6lt (by decide) hz1 h_v2_4 h_v2_6 (by decide)) + -- Step 4: (5, 7) ζ1. + have h_v3_5 : (v3.elements.val[5]!).val.natAbs ≤ 7 * 3328 := by + rw [h_v3_unc 5 h5lt (by decide) (by decide), + h_v2_unc 5 h5lt (by decide) (by decide), + h_v1_unc 5 h5lt (by decide) (by decide)] + exact hb_init 5 h5lt + have h_v3_7 : (v3.elements.val[7]!).val.natAbs ≤ 7 * 3328 := by + rw [h_v3_unc 7 h7lt (by decide) (by decide), + h_v2_unc 7 h7lt (by decide) (by decide), + h_v1_unc 7 h7lt (by decide) (by decide)] + exact hb_init 7 h7lt + obtain ⟨v4, h_v4_eq, h_v4_unc, h_v4_i, h_v4_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v3 zeta1 5#usize 7#usize 7 + h5lt h7lt (by decide) hz1 h_v3_5 h_v3_7 (by decide)) + -- Step 5: (8, 10) ζ2. + have h_v4_8 : (v4.elements.val[8]!).val.natAbs ≤ 7 * 3328 := by + rw [h_v4_unc 8 h8lt (by decide) (by decide), + h_v3_unc 8 h8lt (by decide) (by decide), + h_v2_unc 8 h8lt (by decide) (by decide), + h_v1_unc 8 h8lt (by decide) (by decide)] + exact hb_init 8 h8lt + have h_v4_10 : (v4.elements.val[10]!).val.natAbs ≤ 7 * 3328 := by + rw [h_v4_unc 10 h10lt (by decide) (by decide), + h_v3_unc 10 h10lt (by decide) (by decide), + h_v2_unc 10 h10lt (by decide) (by decide), + h_v1_unc 10 h10lt (by decide) (by decide)] + exact hb_init 10 h10lt + obtain ⟨v5, h_v5_eq, h_v5_unc, h_v5_i, h_v5_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v4 zeta2 8#usize 10#usize 7 + h8lt h10lt (by decide) hz2 h_v4_8 h_v4_10 (by decide)) + -- Step 6: (9, 11) ζ2. + have h_v5_9 : (v5.elements.val[9]!).val.natAbs ≤ 7 * 3328 := by + rw [h_v5_unc 9 h9lt (by decide) (by decide), + h_v4_unc 9 h9lt (by decide) (by decide), + h_v3_unc 9 h9lt (by decide) (by decide), + h_v2_unc 9 h9lt (by decide) (by decide), + h_v1_unc 9 h9lt (by decide) (by decide)] + exact hb_init 9 h9lt + have h_v5_11 : (v5.elements.val[11]!).val.natAbs ≤ 7 * 3328 := by + rw [h_v5_unc 11 h11lt (by decide) (by decide), + h_v4_unc 11 h11lt (by decide) (by decide), + h_v3_unc 11 h11lt (by decide) (by decide), + h_v2_unc 11 h11lt (by decide) (by decide), + h_v1_unc 11 h11lt (by decide) (by decide)] + exact hb_init 11 h11lt + obtain ⟨v6, h_v6_eq, h_v6_unc, h_v6_i, h_v6_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v5 zeta2 9#usize 11#usize 7 + h9lt h11lt (by decide) hz2 h_v5_9 h_v5_11 (by decide)) + -- Step 7: (12, 14) ζ3. + have h_v6_12 : (v6.elements.val[12]!).val.natAbs ≤ 7 * 3328 := by + rw [h_v6_unc 12 h12lt (by decide) (by decide), + h_v5_unc 12 h12lt (by decide) (by decide), + h_v4_unc 12 h12lt (by decide) (by decide), + h_v3_unc 12 h12lt (by decide) (by decide), + h_v2_unc 12 h12lt (by decide) (by decide), + h_v1_unc 12 h12lt (by decide) (by decide)] + exact hb_init 12 h12lt + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ 7 * 3328 := by + rw [h_v6_unc 14 h14lt (by decide) (by decide), + h_v5_unc 14 h14lt (by decide) (by decide), + h_v4_unc 14 h14lt (by decide) (by decide), + h_v3_unc 14 h14lt (by decide) (by decide), + h_v2_unc 14 h14lt (by decide) (by decide), + h_v1_unc 14 h14lt (by decide) (by decide)] + exact hb_init 14 h14lt + obtain ⟨v7, h_v7_eq, h_v7_unc, h_v7_i, h_v7_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v6 zeta3 12#usize 14#usize 7 + h12lt h14lt (by decide) hz3 h_v6_12 h_v6_14 (by decide)) + -- Step 8: (13, 15) ζ3. + have h_v7_13 : (v7.elements.val[13]!).val.natAbs ≤ 7 * 3328 := by + rw [h_v7_unc 13 h13lt (by decide) (by decide), + h_v6_unc 13 h13lt (by decide) (by decide), + h_v5_unc 13 h13lt (by decide) (by decide), + h_v4_unc 13 h13lt (by decide) (by decide), + h_v3_unc 13 h13lt (by decide) (by decide), + h_v2_unc 13 h13lt (by decide) (by decide), + h_v1_unc 13 h13lt (by decide) (by decide)] + exact hb_init 13 h13lt + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ 7 * 3328 := by + rw [h_v7_unc 15 h15lt (by decide) (by decide), + h_v6_unc 15 h15lt (by decide) (by decide), + h_v5_unc 15 h15lt (by decide) (by decide), + h_v4_unc 15 h15lt (by decide) (by decide), + h_v3_unc 15 h15lt (by decide) (by decide), + h_v2_unc 15 h15lt (by decide) (by decide), + h_v1_unc 15 h15lt (by decide) (by decide)] + exact hb_init 15 h15lt + obtain ⟨v8, h_v8_eq, h_v8_unc, h_v8_i, h_v8_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v7 zeta3 13#usize 15#usize 7 + h13lt h15lt (by decide) hz3 h_v7_13 h_v7_15 (by decide)) + -- Compose the whole 8-step chain into one `.ok v8` equation. + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_1_step vec zeta0 zeta1 zeta2 zeta3 + = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_1_step + rw [h_v1_eq]; simp only [bind_tc_ok] + rw [h_v2_eq]; simp only [bind_tc_ok] + rw [h_v3_eq]; simp only [bind_tc_ok] + rw [h_v4_eq]; simp only [bind_tc_ok] + rw [h_v5_eq]; simp only [bind_tc_ok] + rw [h_v6_eq]; simp only [bind_tc_ok] + rw [h_v7_eq]; simp only [bind_tc_ok] + exact h_v8_eq + -- Close the Triple: prove every lane ≤ 8·3328 by case-split on which step touched it. + -- Strategy: for each lane ℓ, identify the step that touched it (giving h_v{k}_i or h_v{k}_j + -- with bound ≤ 8·3328), then propagate v_k[ℓ] = ... = v8[ℓ] via the later steps' h_v{m}_unc. + apply triple_of_ok_l2 h_body + intro i hi + interval_cases i + -- Lane 0: touched in step 1 as i-lane ⇒ v1[0] ≤ 8·3328. v8[0] = v1[0]. + · have h_eq : v8.elements.val[0]! = v1.elements.val[0]! := by + rw [h_v8_unc 0 h0lt (by decide) (by decide), + h_v7_unc 0 h0lt (by decide) (by decide), + h_v6_unc 0 h0lt (by decide) (by decide), + h_v5_unc 0 h0lt (by decide) (by decide), + h_v4_unc 0 h0lt (by decide) (by decide), + h_v3_unc 0 h0lt (by decide) (by decide), + h_v2_unc 0 h0lt (by decide) (by decide)] + rw [h_eq]; exact h_v1_i + -- Lane 1: touched in step 2 as i-lane. + · have h_eq : v8.elements.val[1]! = v2.elements.val[1]! := by + rw [h_v8_unc 1 h1lt (by decide) (by decide), + h_v7_unc 1 h1lt (by decide) (by decide), + h_v6_unc 1 h1lt (by decide) (by decide), + h_v5_unc 1 h1lt (by decide) (by decide), + h_v4_unc 1 h1lt (by decide) (by decide), + h_v3_unc 1 h1lt (by decide) (by decide)] + rw [h_eq]; exact h_v2_i + -- Lane 2: touched in step 1 as j-lane. + · have h_eq : v8.elements.val[2]! = v1.elements.val[2]! := by + rw [h_v8_unc 2 h2lt (by decide) (by decide), + h_v7_unc 2 h2lt (by decide) (by decide), + h_v6_unc 2 h2lt (by decide) (by decide), + h_v5_unc 2 h2lt (by decide) (by decide), + h_v4_unc 2 h2lt (by decide) (by decide), + h_v3_unc 2 h2lt (by decide) (by decide), + h_v2_unc 2 h2lt (by decide) (by decide)] + rw [h_eq]; exact h_v1_j + -- Lane 3: touched in step 2 as j-lane. + · have h_eq : v8.elements.val[3]! = v2.elements.val[3]! := by + rw [h_v8_unc 3 h3lt (by decide) (by decide), + h_v7_unc 3 h3lt (by decide) (by decide), + h_v6_unc 3 h3lt (by decide) (by decide), + h_v5_unc 3 h3lt (by decide) (by decide), + h_v4_unc 3 h3lt (by decide) (by decide), + h_v3_unc 3 h3lt (by decide) (by decide)] + rw [h_eq]; exact h_v2_j + -- Lane 4: touched in step 3 as i-lane. + · have h_eq : v8.elements.val[4]! = v3.elements.val[4]! := by + rw [h_v8_unc 4 h4lt (by decide) (by decide), + h_v7_unc 4 h4lt (by decide) (by decide), + h_v6_unc 4 h4lt (by decide) (by decide), + h_v5_unc 4 h4lt (by decide) (by decide), + h_v4_unc 4 h4lt (by decide) (by decide)] + rw [h_eq]; exact h_v3_i + -- Lane 5: touched in step 4 as i-lane. + · have h_eq : v8.elements.val[5]! = v4.elements.val[5]! := by + rw [h_v8_unc 5 h5lt (by decide) (by decide), + h_v7_unc 5 h5lt (by decide) (by decide), + h_v6_unc 5 h5lt (by decide) (by decide), + h_v5_unc 5 h5lt (by decide) (by decide)] + rw [h_eq]; exact h_v4_i + -- Lane 6: touched in step 3 as j-lane. + · have h_eq : v8.elements.val[6]! = v3.elements.val[6]! := by + rw [h_v8_unc 6 h6lt (by decide) (by decide), + h_v7_unc 6 h6lt (by decide) (by decide), + h_v6_unc 6 h6lt (by decide) (by decide), + h_v5_unc 6 h6lt (by decide) (by decide), + h_v4_unc 6 h6lt (by decide) (by decide)] + rw [h_eq]; exact h_v3_j + -- Lane 7: touched in step 4 as j-lane. + · have h_eq : v8.elements.val[7]! = v4.elements.val[7]! := by + rw [h_v8_unc 7 h7lt (by decide) (by decide), + h_v7_unc 7 h7lt (by decide) (by decide), + h_v6_unc 7 h7lt (by decide) (by decide), + h_v5_unc 7 h7lt (by decide) (by decide)] + rw [h_eq]; exact h_v4_j + -- Lane 8: touched in step 5 as i-lane. + · have h_eq : v8.elements.val[8]! = v5.elements.val[8]! := by + rw [h_v8_unc 8 h8lt (by decide) (by decide), + h_v7_unc 8 h8lt (by decide) (by decide), + h_v6_unc 8 h8lt (by decide) (by decide)] + rw [h_eq]; exact h_v5_i + -- Lane 9: touched in step 6 as i-lane. + · have h_eq : v8.elements.val[9]! = v6.elements.val[9]! := by + rw [h_v8_unc 9 h9lt (by decide) (by decide), + h_v7_unc 9 h9lt (by decide) (by decide)] + rw [h_eq]; exact h_v6_i + -- Lane 10: touched in step 5 as j-lane. + · have h_eq : v8.elements.val[10]! = v5.elements.val[10]! := by + rw [h_v8_unc 10 h10lt (by decide) (by decide), + h_v7_unc 10 h10lt (by decide) (by decide), + h_v6_unc 10 h10lt (by decide) (by decide)] + rw [h_eq]; exact h_v5_j + -- Lane 11: touched in step 6 as j-lane. + · have h_eq : v8.elements.val[11]! = v6.elements.val[11]! := by + rw [h_v8_unc 11 h11lt (by decide) (by decide), + h_v7_unc 11 h11lt (by decide) (by decide)] + rw [h_eq]; exact h_v6_j + -- Lane 12: touched in step 7 as i-lane. + · have h_eq : v8.elements.val[12]! = v7.elements.val[12]! := by + rw [h_v8_unc 12 h12lt (by decide) (by decide)] + rw [h_eq]; exact h_v7_i + -- Lane 13: touched in step 8 as i-lane. + · exact h_v8_i + -- Lane 14: touched in step 7 as j-lane. + · have h_eq : v8.elements.val[14]! = v7.elements.val[14]! := by + rw [h_v8_unc 14 h14lt (by decide) (by decide)] + rw [h_eq]; exact h_v7_j + -- Lane 15: touched in step 8 as j-lane. + · exact h_v8_j + +/-! ## L2.2.bnd — `ntt_layer_1_step_spec_bnd` + + Nat-bnd parameterised mirror of `ntt_layer_1_step_spec` (L2.2): same eight + disjoint butterflies on pairs `(0,2)ζ0`, `(1,3)ζ0`, `(4,6)ζ1`, `(5,7)ζ1`, + `(8,10)ζ2`, `(9,11)ζ2`, `(12,14)ζ3`, `(13,15)ζ3`, dispatched via the + `_bnd` form of the per-butterfly Triple. Each call raises the two touched + lanes from `≤ bnd` to `≤ bnd + 3328`. + + Precondition `bnd ≤ 8 * 3328 = 26624` keeps the output bound + `bnd + 3328 ≤ 9 * 3328 = 29952` within `ntt_step_spec_bnd`'s safe range + (`bnd' ≤ 29439` for the I16 no-overflow argument); `26624 ≤ 29439`. +-/ + +@[spec] +theorem ntt_layer_1_step_spec_bnd + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta0 zeta1 zeta2 zeta3 : Std.I16) + (bnd : Nat) (h_bnd : bnd ≤ 29439) + (hz0 : zeta0.val.natAbs ≤ 1664) (hz1 : zeta1.val.natAbs ≤ 1664) + (hz2 : zeta2.val.natAbs ≤ 1664) (hz3 : zeta3.val.natAbs ≤ 1664) + (hpre : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ bnd) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_1_step vec zeta0 zeta1 zeta2 zeta3 + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → (r.elements.val[i]!).val.natAbs ≤ bnd + 3328 ⌝ ⦄ := by + -- Index abbreviations. + have h0lt : (0#usize : Std.Usize).val < 16 := by decide + have h1lt : (1#usize : Std.Usize).val < 16 := by decide + have h2lt : (2#usize : Std.Usize).val < 16 := by decide + have h3lt : (3#usize : Std.Usize).val < 16 := by decide + have h4lt : (4#usize : Std.Usize).val < 16 := by decide + have h5lt : (5#usize : Std.Usize).val < 16 := by decide + have h6lt : (6#usize : Std.Usize).val < 16 := by decide + have h7lt : (7#usize : Std.Usize).val < 16 := by decide + have h8lt : (8#usize : Std.Usize).val < 16 := by decide + have h9lt : (9#usize : Std.Usize).val < 16 := by decide + have h10lt : (10#usize : Std.Usize).val < 16 := by decide + have h11lt : (11#usize : Std.Usize).val < 16 := by decide + have h12lt : (12#usize : Std.Usize).val < 16 := by decide + have h13lt : (13#usize : Std.Usize).val < 16 := by decide + have h14lt : (14#usize : Std.Usize).val < 16 := by decide + have h15lt : (15#usize : Std.Usize).val < 16 := by decide + -- Bridge: bnd ≤ 8*3328 = 26624 ≤ 29439 (= ntt_step_spec_bnd's max input bound). + have h_bnd29439 : bnd ≤ 29439 := by omega + -- Initial bounds: hpre gives ≤ bnd on all lanes. + have hb_init : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ bnd := hpre + -- Step 1: (0, 2) ζ0. Pair untouched ⇒ both lanes ≤ bnd from hpre. + obtain ⟨v1, h_v1_eq, h_v1_unc, h_v1_i, h_v1_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd vec zeta0 0#usize 2#usize bnd + h0lt h2lt (by decide) hz0 (hb_init 0 h0lt) (hb_init 2 h2lt) h_bnd29439) + -- After step 1: lane 0 and lane 2 are ≤ bnd + 3328 (h_v1_i, h_v1_j); other lanes + -- unchanged from vec (h_v1_unc), so still ≤ bnd from hpre. + -- Step 2: (1, 3) ζ0. Disjoint from {0, 2}, so lanes 1, 3 are still vec[1], vec[3] ≤ bnd. + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ bnd := by + rw [h_v1_unc 1 h1lt (by decide) (by decide)]; exact hb_init 1 h1lt + have h_v1_3 : (v1.elements.val[3]!).val.natAbs ≤ bnd := by + rw [h_v1_unc 3 h3lt (by decide) (by decide)]; exact hb_init 3 h3lt + obtain ⟨v2, h_v2_eq, h_v2_unc, h_v2_i, h_v2_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v1 zeta0 1#usize 3#usize bnd + h1lt h3lt (by decide) hz0 h_v1_1 h_v1_3 h_bnd29439) + -- Step 3: (4, 6) ζ1. v2[4], v2[6] not touched in steps 1,2 ⇒ = vec[4], vec[6] ≤ bnd. + have h_v2_4 : (v2.elements.val[4]!).val.natAbs ≤ bnd := by + rw [h_v2_unc 4 h4lt (by decide) (by decide), h_v1_unc 4 h4lt (by decide) (by decide)] + exact hb_init 4 h4lt + have h_v2_6 : (v2.elements.val[6]!).val.natAbs ≤ bnd := by + rw [h_v2_unc 6 h6lt (by decide) (by decide), h_v1_unc 6 h6lt (by decide) (by decide)] + exact hb_init 6 h6lt + obtain ⟨v3, h_v3_eq, h_v3_unc, h_v3_i, h_v3_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v2 zeta1 4#usize 6#usize bnd + h4lt h6lt (by decide) hz1 h_v2_4 h_v2_6 h_bnd29439) + -- Step 4: (5, 7) ζ1. + have h_v3_5 : (v3.elements.val[5]!).val.natAbs ≤ bnd := by + rw [h_v3_unc 5 h5lt (by decide) (by decide), + h_v2_unc 5 h5lt (by decide) (by decide), + h_v1_unc 5 h5lt (by decide) (by decide)] + exact hb_init 5 h5lt + have h_v3_7 : (v3.elements.val[7]!).val.natAbs ≤ bnd := by + rw [h_v3_unc 7 h7lt (by decide) (by decide), + h_v2_unc 7 h7lt (by decide) (by decide), + h_v1_unc 7 h7lt (by decide) (by decide)] + exact hb_init 7 h7lt + obtain ⟨v4, h_v4_eq, h_v4_unc, h_v4_i, h_v4_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v3 zeta1 5#usize 7#usize bnd + h5lt h7lt (by decide) hz1 h_v3_5 h_v3_7 h_bnd29439) + -- Step 5: (8, 10) ζ2. + have h_v4_8 : (v4.elements.val[8]!).val.natAbs ≤ bnd := by + rw [h_v4_unc 8 h8lt (by decide) (by decide), + h_v3_unc 8 h8lt (by decide) (by decide), + h_v2_unc 8 h8lt (by decide) (by decide), + h_v1_unc 8 h8lt (by decide) (by decide)] + exact hb_init 8 h8lt + have h_v4_10 : (v4.elements.val[10]!).val.natAbs ≤ bnd := by + rw [h_v4_unc 10 h10lt (by decide) (by decide), + h_v3_unc 10 h10lt (by decide) (by decide), + h_v2_unc 10 h10lt (by decide) (by decide), + h_v1_unc 10 h10lt (by decide) (by decide)] + exact hb_init 10 h10lt + obtain ⟨v5, h_v5_eq, h_v5_unc, h_v5_i, h_v5_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v4 zeta2 8#usize 10#usize bnd + h8lt h10lt (by decide) hz2 h_v4_8 h_v4_10 h_bnd29439) + -- Step 6: (9, 11) ζ2. + have h_v5_9 : (v5.elements.val[9]!).val.natAbs ≤ bnd := by + rw [h_v5_unc 9 h9lt (by decide) (by decide), + h_v4_unc 9 h9lt (by decide) (by decide), + h_v3_unc 9 h9lt (by decide) (by decide), + h_v2_unc 9 h9lt (by decide) (by decide), + h_v1_unc 9 h9lt (by decide) (by decide)] + exact hb_init 9 h9lt + have h_v5_11 : (v5.elements.val[11]!).val.natAbs ≤ bnd := by + rw [h_v5_unc 11 h11lt (by decide) (by decide), + h_v4_unc 11 h11lt (by decide) (by decide), + h_v3_unc 11 h11lt (by decide) (by decide), + h_v2_unc 11 h11lt (by decide) (by decide), + h_v1_unc 11 h11lt (by decide) (by decide)] + exact hb_init 11 h11lt + obtain ⟨v6, h_v6_eq, h_v6_unc, h_v6_i, h_v6_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v5 zeta2 9#usize 11#usize bnd + h9lt h11lt (by decide) hz2 h_v5_9 h_v5_11 h_bnd29439) + -- Step 7: (12, 14) ζ3. + have h_v6_12 : (v6.elements.val[12]!).val.natAbs ≤ bnd := by + rw [h_v6_unc 12 h12lt (by decide) (by decide), + h_v5_unc 12 h12lt (by decide) (by decide), + h_v4_unc 12 h12lt (by decide) (by decide), + h_v3_unc 12 h12lt (by decide) (by decide), + h_v2_unc 12 h12lt (by decide) (by decide), + h_v1_unc 12 h12lt (by decide) (by decide)] + exact hb_init 12 h12lt + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ bnd := by + rw [h_v6_unc 14 h14lt (by decide) (by decide), + h_v5_unc 14 h14lt (by decide) (by decide), + h_v4_unc 14 h14lt (by decide) (by decide), + h_v3_unc 14 h14lt (by decide) (by decide), + h_v2_unc 14 h14lt (by decide) (by decide), + h_v1_unc 14 h14lt (by decide) (by decide)] + exact hb_init 14 h14lt + obtain ⟨v7, h_v7_eq, h_v7_unc, h_v7_i, h_v7_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v6 zeta3 12#usize 14#usize bnd + h12lt h14lt (by decide) hz3 h_v6_12 h_v6_14 h_bnd29439) + -- Step 8: (13, 15) ζ3. + have h_v7_13 : (v7.elements.val[13]!).val.natAbs ≤ bnd := by + rw [h_v7_unc 13 h13lt (by decide) (by decide), + h_v6_unc 13 h13lt (by decide) (by decide), + h_v5_unc 13 h13lt (by decide) (by decide), + h_v4_unc 13 h13lt (by decide) (by decide), + h_v3_unc 13 h13lt (by decide) (by decide), + h_v2_unc 13 h13lt (by decide) (by decide), + h_v1_unc 13 h13lt (by decide) (by decide)] + exact hb_init 13 h13lt + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ bnd := by + rw [h_v7_unc 15 h15lt (by decide) (by decide), + h_v6_unc 15 h15lt (by decide) (by decide), + h_v5_unc 15 h15lt (by decide) (by decide), + h_v4_unc 15 h15lt (by decide) (by decide), + h_v3_unc 15 h15lt (by decide) (by decide), + h_v2_unc 15 h15lt (by decide) (by decide), + h_v1_unc 15 h15lt (by decide) (by decide)] + exact hb_init 15 h15lt + obtain ⟨v8, h_v8_eq, h_v8_unc, h_v8_i, h_v8_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v7 zeta3 13#usize 15#usize bnd + h13lt h15lt (by decide) hz3 h_v7_13 h_v7_15 h_bnd29439) + -- Compose the whole 8-step chain into one `.ok v8` equation. + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_1_step vec zeta0 zeta1 zeta2 zeta3 + = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_1_step + rw [h_v1_eq]; simp only [bind_tc_ok] + rw [h_v2_eq]; simp only [bind_tc_ok] + rw [h_v3_eq]; simp only [bind_tc_ok] + rw [h_v4_eq]; simp only [bind_tc_ok] + rw [h_v5_eq]; simp only [bind_tc_ok] + rw [h_v6_eq]; simp only [bind_tc_ok] + rw [h_v7_eq]; simp only [bind_tc_ok] + exact h_v8_eq + -- Close the Triple: prove every lane ≤ bnd + 3328 by case-split on which step touched it. + -- Strategy: for each lane ℓ, identify the step that touched it (giving h_v{k}_i or h_v{k}_j + -- with bound ≤ bnd + 3328), then propagate v_k[ℓ] = ... = v8[ℓ] via the later steps' h_v{m}_unc. + apply triple_of_ok_l2 h_body + intro i hi + interval_cases i + -- Lane 0: touched in step 1 as i-lane ⇒ v1[0] ≤ bnd + 3328. v8[0] = v1[0]. + · have h_eq : v8.elements.val[0]! = v1.elements.val[0]! := by + rw [h_v8_unc 0 h0lt (by decide) (by decide), + h_v7_unc 0 h0lt (by decide) (by decide), + h_v6_unc 0 h0lt (by decide) (by decide), + h_v5_unc 0 h0lt (by decide) (by decide), + h_v4_unc 0 h0lt (by decide) (by decide), + h_v3_unc 0 h0lt (by decide) (by decide), + h_v2_unc 0 h0lt (by decide) (by decide)] + rw [h_eq]; exact h_v1_i + -- Lane 1: touched in step 2 as i-lane. + · have h_eq : v8.elements.val[1]! = v2.elements.val[1]! := by + rw [h_v8_unc 1 h1lt (by decide) (by decide), + h_v7_unc 1 h1lt (by decide) (by decide), + h_v6_unc 1 h1lt (by decide) (by decide), + h_v5_unc 1 h1lt (by decide) (by decide), + h_v4_unc 1 h1lt (by decide) (by decide), + h_v3_unc 1 h1lt (by decide) (by decide)] + rw [h_eq]; exact h_v2_i + -- Lane 2: touched in step 1 as j-lane. + · have h_eq : v8.elements.val[2]! = v1.elements.val[2]! := by + rw [h_v8_unc 2 h2lt (by decide) (by decide), + h_v7_unc 2 h2lt (by decide) (by decide), + h_v6_unc 2 h2lt (by decide) (by decide), + h_v5_unc 2 h2lt (by decide) (by decide), + h_v4_unc 2 h2lt (by decide) (by decide), + h_v3_unc 2 h2lt (by decide) (by decide), + h_v2_unc 2 h2lt (by decide) (by decide)] + rw [h_eq]; exact h_v1_j + -- Lane 3: touched in step 2 as j-lane. + · have h_eq : v8.elements.val[3]! = v2.elements.val[3]! := by + rw [h_v8_unc 3 h3lt (by decide) (by decide), + h_v7_unc 3 h3lt (by decide) (by decide), + h_v6_unc 3 h3lt (by decide) (by decide), + h_v5_unc 3 h3lt (by decide) (by decide), + h_v4_unc 3 h3lt (by decide) (by decide), + h_v3_unc 3 h3lt (by decide) (by decide)] + rw [h_eq]; exact h_v2_j + -- Lane 4: touched in step 3 as i-lane. + · have h_eq : v8.elements.val[4]! = v3.elements.val[4]! := by + rw [h_v8_unc 4 h4lt (by decide) (by decide), + h_v7_unc 4 h4lt (by decide) (by decide), + h_v6_unc 4 h4lt (by decide) (by decide), + h_v5_unc 4 h4lt (by decide) (by decide), + h_v4_unc 4 h4lt (by decide) (by decide)] + rw [h_eq]; exact h_v3_i + -- Lane 5: touched in step 4 as i-lane. + · have h_eq : v8.elements.val[5]! = v4.elements.val[5]! := by + rw [h_v8_unc 5 h5lt (by decide) (by decide), + h_v7_unc 5 h5lt (by decide) (by decide), + h_v6_unc 5 h5lt (by decide) (by decide), + h_v5_unc 5 h5lt (by decide) (by decide)] + rw [h_eq]; exact h_v4_i + -- Lane 6: touched in step 3 as j-lane. + · have h_eq : v8.elements.val[6]! = v3.elements.val[6]! := by + rw [h_v8_unc 6 h6lt (by decide) (by decide), + h_v7_unc 6 h6lt (by decide) (by decide), + h_v6_unc 6 h6lt (by decide) (by decide), + h_v5_unc 6 h6lt (by decide) (by decide), + h_v4_unc 6 h6lt (by decide) (by decide)] + rw [h_eq]; exact h_v3_j + -- Lane 7: touched in step 4 as j-lane. + · have h_eq : v8.elements.val[7]! = v4.elements.val[7]! := by + rw [h_v8_unc 7 h7lt (by decide) (by decide), + h_v7_unc 7 h7lt (by decide) (by decide), + h_v6_unc 7 h7lt (by decide) (by decide), + h_v5_unc 7 h7lt (by decide) (by decide)] + rw [h_eq]; exact h_v4_j + -- Lane 8: touched in step 5 as i-lane. + · have h_eq : v8.elements.val[8]! = v5.elements.val[8]! := by + rw [h_v8_unc 8 h8lt (by decide) (by decide), + h_v7_unc 8 h8lt (by decide) (by decide), + h_v6_unc 8 h8lt (by decide) (by decide)] + rw [h_eq]; exact h_v5_i + -- Lane 9: touched in step 6 as i-lane. + · have h_eq : v8.elements.val[9]! = v6.elements.val[9]! := by + rw [h_v8_unc 9 h9lt (by decide) (by decide), + h_v7_unc 9 h9lt (by decide) (by decide)] + rw [h_eq]; exact h_v6_i + -- Lane 10: touched in step 5 as j-lane. + · have h_eq : v8.elements.val[10]! = v5.elements.val[10]! := by + rw [h_v8_unc 10 h10lt (by decide) (by decide), + h_v7_unc 10 h10lt (by decide) (by decide), + h_v6_unc 10 h10lt (by decide) (by decide)] + rw [h_eq]; exact h_v5_j + -- Lane 11: touched in step 6 as j-lane. + · have h_eq : v8.elements.val[11]! = v6.elements.val[11]! := by + rw [h_v8_unc 11 h11lt (by decide) (by decide), + h_v7_unc 11 h11lt (by decide) (by decide)] + rw [h_eq]; exact h_v6_j + -- Lane 12: touched in step 7 as i-lane. + · have h_eq : v8.elements.val[12]! = v7.elements.val[12]! := by + rw [h_v8_unc 12 h12lt (by decide) (by decide)] + rw [h_eq]; exact h_v7_i + -- Lane 13: touched in step 8 as i-lane. + · exact h_v8_i + -- Lane 14: touched in step 7 as j-lane. + · have h_eq : v8.elements.val[14]! = v7.elements.val[14]! := by + rw [h_v8_unc 14 h14lt (by decide) (by decide)] + rw [h_eq]; exact h_v7_j + -- Lane 15: touched in step 8 as j-lane. + · exact h_v8_j + +/-! ## L2.3 — `ntt_layer_2_step_spec` + + Eight disjoint butterflies on pairs `(0,4)ζ0`, `(1,5)ζ0`, `(2,6)ζ0`, + `(3,7)ζ0`, `(8,12)ζ1`, `(9,13)ζ1`, `(10,14)ζ1`, `(11,15)ζ1`. Same + bookkeeping pattern as L2.2 but with `B = 6`. +-/ + +@[spec] +theorem ntt_layer_2_step_spec + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta0 zeta1 : Std.I16) + (hz0 : zeta0.val.natAbs ≤ 1664) (hz1 : zeta1.val.natAbs ≤ 1664) + (hpre : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 6 * 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_2_step vec zeta0 zeta1 + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → (r.elements.val[i]!).val.natAbs ≤ 7 * 3328 ⌝ ⦄ := by + have h0lt : (0#usize : Std.Usize).val < 16 := by decide + have h1lt : (1#usize : Std.Usize).val < 16 := by decide + have h2lt : (2#usize : Std.Usize).val < 16 := by decide + have h3lt : (3#usize : Std.Usize).val < 16 := by decide + have h4lt : (4#usize : Std.Usize).val < 16 := by decide + have h5lt : (5#usize : Std.Usize).val < 16 := by decide + have h6lt : (6#usize : Std.Usize).val < 16 := by decide + have h7lt : (7#usize : Std.Usize).val < 16 := by decide + have h8lt : (8#usize : Std.Usize).val < 16 := by decide + have h9lt : (9#usize : Std.Usize).val < 16 := by decide + have h10lt : (10#usize : Std.Usize).val < 16 := by decide + have h11lt : (11#usize : Std.Usize).val < 16 := by decide + have h12lt : (12#usize : Std.Usize).val < 16 := by decide + have h13lt : (13#usize : Std.Usize).val < 16 := by decide + have h14lt : (14#usize : Std.Usize).val < 16 := by decide + have h15lt : (15#usize : Std.Usize).val < 16 := by decide + have hb_init : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 6 * 3328 := hpre + -- Step 1: (0, 4) ζ0. + obtain ⟨v1, h_v1_eq, h_v1_unc, h_v1_i, h_v1_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B vec zeta0 0#usize 4#usize 6 + h0lt h4lt (by decide) hz0 (hb_init 0 h0lt) (hb_init 4 h4lt) (by decide)) + -- Step 2: (1, 5) ζ0. + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ 6 * 3328 := by + rw [h_v1_unc 1 h1lt (by decide) (by decide)]; exact hb_init 1 h1lt + have h_v1_5 : (v1.elements.val[5]!).val.natAbs ≤ 6 * 3328 := by + rw [h_v1_unc 5 h5lt (by decide) (by decide)]; exact hb_init 5 h5lt + obtain ⟨v2, h_v2_eq, h_v2_unc, h_v2_i, h_v2_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v1 zeta0 1#usize 5#usize 6 + h1lt h5lt (by decide) hz0 h_v1_1 h_v1_5 (by decide)) + -- Step 3: (2, 6) ζ0. + have h_v2_2 : (v2.elements.val[2]!).val.natAbs ≤ 6 * 3328 := by + rw [h_v2_unc 2 h2lt (by decide) (by decide), h_v1_unc 2 h2lt (by decide) (by decide)] + exact hb_init 2 h2lt + have h_v2_6 : (v2.elements.val[6]!).val.natAbs ≤ 6 * 3328 := by + rw [h_v2_unc 6 h6lt (by decide) (by decide), h_v1_unc 6 h6lt (by decide) (by decide)] + exact hb_init 6 h6lt + obtain ⟨v3, h_v3_eq, h_v3_unc, h_v3_i, h_v3_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v2 zeta0 2#usize 6#usize 6 + h2lt h6lt (by decide) hz0 h_v2_2 h_v2_6 (by decide)) + -- Step 4: (3, 7) ζ0. + have h_v3_3 : (v3.elements.val[3]!).val.natAbs ≤ 6 * 3328 := by + rw [h_v3_unc 3 h3lt (by decide) (by decide), + h_v2_unc 3 h3lt (by decide) (by decide), + h_v1_unc 3 h3lt (by decide) (by decide)] + exact hb_init 3 h3lt + have h_v3_7 : (v3.elements.val[7]!).val.natAbs ≤ 6 * 3328 := by + rw [h_v3_unc 7 h7lt (by decide) (by decide), + h_v2_unc 7 h7lt (by decide) (by decide), + h_v1_unc 7 h7lt (by decide) (by decide)] + exact hb_init 7 h7lt + obtain ⟨v4, h_v4_eq, h_v4_unc, h_v4_i, h_v4_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v3 zeta0 3#usize 7#usize 6 + h3lt h7lt (by decide) hz0 h_v3_3 h_v3_7 (by decide)) + -- Step 5: (8, 12) ζ1. + have h_v4_8 : (v4.elements.val[8]!).val.natAbs ≤ 6 * 3328 := by + rw [h_v4_unc 8 h8lt (by decide) (by decide), + h_v3_unc 8 h8lt (by decide) (by decide), + h_v2_unc 8 h8lt (by decide) (by decide), + h_v1_unc 8 h8lt (by decide) (by decide)] + exact hb_init 8 h8lt + have h_v4_12 : (v4.elements.val[12]!).val.natAbs ≤ 6 * 3328 := by + rw [h_v4_unc 12 h12lt (by decide) (by decide), + h_v3_unc 12 h12lt (by decide) (by decide), + h_v2_unc 12 h12lt (by decide) (by decide), + h_v1_unc 12 h12lt (by decide) (by decide)] + exact hb_init 12 h12lt + obtain ⟨v5, h_v5_eq, h_v5_unc, h_v5_i, h_v5_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v4 zeta1 8#usize 12#usize 6 + h8lt h12lt (by decide) hz1 h_v4_8 h_v4_12 (by decide)) + -- Step 6: (9, 13) ζ1. + have h_v5_9 : (v5.elements.val[9]!).val.natAbs ≤ 6 * 3328 := by + rw [h_v5_unc 9 h9lt (by decide) (by decide), + h_v4_unc 9 h9lt (by decide) (by decide), + h_v3_unc 9 h9lt (by decide) (by decide), + h_v2_unc 9 h9lt (by decide) (by decide), + h_v1_unc 9 h9lt (by decide) (by decide)] + exact hb_init 9 h9lt + have h_v5_13 : (v5.elements.val[13]!).val.natAbs ≤ 6 * 3328 := by + rw [h_v5_unc 13 h13lt (by decide) (by decide), + h_v4_unc 13 h13lt (by decide) (by decide), + h_v3_unc 13 h13lt (by decide) (by decide), + h_v2_unc 13 h13lt (by decide) (by decide), + h_v1_unc 13 h13lt (by decide) (by decide)] + exact hb_init 13 h13lt + obtain ⟨v6, h_v6_eq, h_v6_unc, h_v6_i, h_v6_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v5 zeta1 9#usize 13#usize 6 + h9lt h13lt (by decide) hz1 h_v5_9 h_v5_13 (by decide)) + -- Step 7: (10, 14) ζ1. + have h_v6_10 : (v6.elements.val[10]!).val.natAbs ≤ 6 * 3328 := by + rw [h_v6_unc 10 h10lt (by decide) (by decide), + h_v5_unc 10 h10lt (by decide) (by decide), + h_v4_unc 10 h10lt (by decide) (by decide), + h_v3_unc 10 h10lt (by decide) (by decide), + h_v2_unc 10 h10lt (by decide) (by decide), + h_v1_unc 10 h10lt (by decide) (by decide)] + exact hb_init 10 h10lt + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ 6 * 3328 := by + rw [h_v6_unc 14 h14lt (by decide) (by decide), + h_v5_unc 14 h14lt (by decide) (by decide), + h_v4_unc 14 h14lt (by decide) (by decide), + h_v3_unc 14 h14lt (by decide) (by decide), + h_v2_unc 14 h14lt (by decide) (by decide), + h_v1_unc 14 h14lt (by decide) (by decide)] + exact hb_init 14 h14lt + obtain ⟨v7, h_v7_eq, h_v7_unc, h_v7_i, h_v7_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v6 zeta1 10#usize 14#usize 6 + h10lt h14lt (by decide) hz1 h_v6_10 h_v6_14 (by decide)) + -- Step 8: (11, 15) ζ1. + have h_v7_11 : (v7.elements.val[11]!).val.natAbs ≤ 6 * 3328 := by + rw [h_v7_unc 11 h11lt (by decide) (by decide), + h_v6_unc 11 h11lt (by decide) (by decide), + h_v5_unc 11 h11lt (by decide) (by decide), + h_v4_unc 11 h11lt (by decide) (by decide), + h_v3_unc 11 h11lt (by decide) (by decide), + h_v2_unc 11 h11lt (by decide) (by decide), + h_v1_unc 11 h11lt (by decide) (by decide)] + exact hb_init 11 h11lt + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ 6 * 3328 := by + rw [h_v7_unc 15 h15lt (by decide) (by decide), + h_v6_unc 15 h15lt (by decide) (by decide), + h_v5_unc 15 h15lt (by decide) (by decide), + h_v4_unc 15 h15lt (by decide) (by decide), + h_v3_unc 15 h15lt (by decide) (by decide), + h_v2_unc 15 h15lt (by decide) (by decide), + h_v1_unc 15 h15lt (by decide) (by decide)] + exact hb_init 15 h15lt + obtain ⟨v8, h_v8_eq, h_v8_unc, h_v8_i, h_v8_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v7 zeta1 11#usize 15#usize 6 + h11lt h15lt (by decide) hz1 h_v7_11 h_v7_15 (by decide)) + -- Compose into one `.ok v8` equation. + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_2_step vec zeta0 zeta1 + = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_2_step + rw [h_v1_eq]; simp only [bind_tc_ok] + rw [h_v2_eq]; simp only [bind_tc_ok] + rw [h_v3_eq]; simp only [bind_tc_ok] + rw [h_v4_eq]; simp only [bind_tc_ok] + rw [h_v5_eq]; simp only [bind_tc_ok] + rw [h_v6_eq]; simp only [bind_tc_ok] + rw [h_v7_eq]; simp only [bind_tc_ok] + exact h_v8_eq + -- Close: per-lane case split. + apply triple_of_ok_l2 h_body + intro i hi + interval_cases i + -- Lane 0: step 1 i-lane. + · have h_eq : v8.elements.val[0]! = v1.elements.val[0]! := by + rw [h_v8_unc 0 h0lt (by decide) (by decide), + h_v7_unc 0 h0lt (by decide) (by decide), + h_v6_unc 0 h0lt (by decide) (by decide), + h_v5_unc 0 h0lt (by decide) (by decide), + h_v4_unc 0 h0lt (by decide) (by decide), + h_v3_unc 0 h0lt (by decide) (by decide), + h_v2_unc 0 h0lt (by decide) (by decide)] + rw [h_eq]; exact h_v1_i + -- Lane 1: step 2 i-lane. + · have h_eq : v8.elements.val[1]! = v2.elements.val[1]! := by + rw [h_v8_unc 1 h1lt (by decide) (by decide), + h_v7_unc 1 h1lt (by decide) (by decide), + h_v6_unc 1 h1lt (by decide) (by decide), + h_v5_unc 1 h1lt (by decide) (by decide), + h_v4_unc 1 h1lt (by decide) (by decide), + h_v3_unc 1 h1lt (by decide) (by decide)] + rw [h_eq]; exact h_v2_i + -- Lane 2: step 3 i-lane. + · have h_eq : v8.elements.val[2]! = v3.elements.val[2]! := by + rw [h_v8_unc 2 h2lt (by decide) (by decide), + h_v7_unc 2 h2lt (by decide) (by decide), + h_v6_unc 2 h2lt (by decide) (by decide), + h_v5_unc 2 h2lt (by decide) (by decide), + h_v4_unc 2 h2lt (by decide) (by decide)] + rw [h_eq]; exact h_v3_i + -- Lane 3: step 4 i-lane. + · have h_eq : v8.elements.val[3]! = v4.elements.val[3]! := by + rw [h_v8_unc 3 h3lt (by decide) (by decide), + h_v7_unc 3 h3lt (by decide) (by decide), + h_v6_unc 3 h3lt (by decide) (by decide), + h_v5_unc 3 h3lt (by decide) (by decide)] + rw [h_eq]; exact h_v4_i + -- Lane 4: step 1 j-lane. + · have h_eq : v8.elements.val[4]! = v1.elements.val[4]! := by + rw [h_v8_unc 4 h4lt (by decide) (by decide), + h_v7_unc 4 h4lt (by decide) (by decide), + h_v6_unc 4 h4lt (by decide) (by decide), + h_v5_unc 4 h4lt (by decide) (by decide), + h_v4_unc 4 h4lt (by decide) (by decide), + h_v3_unc 4 h4lt (by decide) (by decide), + h_v2_unc 4 h4lt (by decide) (by decide)] + rw [h_eq]; exact h_v1_j + -- Lane 5: step 2 j-lane. + · have h_eq : v8.elements.val[5]! = v2.elements.val[5]! := by + rw [h_v8_unc 5 h5lt (by decide) (by decide), + h_v7_unc 5 h5lt (by decide) (by decide), + h_v6_unc 5 h5lt (by decide) (by decide), + h_v5_unc 5 h5lt (by decide) (by decide), + h_v4_unc 5 h5lt (by decide) (by decide), + h_v3_unc 5 h5lt (by decide) (by decide)] + rw [h_eq]; exact h_v2_j + -- Lane 6: step 3 j-lane. + · have h_eq : v8.elements.val[6]! = v3.elements.val[6]! := by + rw [h_v8_unc 6 h6lt (by decide) (by decide), + h_v7_unc 6 h6lt (by decide) (by decide), + h_v6_unc 6 h6lt (by decide) (by decide), + h_v5_unc 6 h6lt (by decide) (by decide), + h_v4_unc 6 h6lt (by decide) (by decide)] + rw [h_eq]; exact h_v3_j + -- Lane 7: step 4 j-lane. + · have h_eq : v8.elements.val[7]! = v4.elements.val[7]! := by + rw [h_v8_unc 7 h7lt (by decide) (by decide), + h_v7_unc 7 h7lt (by decide) (by decide), + h_v6_unc 7 h7lt (by decide) (by decide), + h_v5_unc 7 h7lt (by decide) (by decide)] + rw [h_eq]; exact h_v4_j + -- Lane 8: step 5 i-lane. + · have h_eq : v8.elements.val[8]! = v5.elements.val[8]! := by + rw [h_v8_unc 8 h8lt (by decide) (by decide), + h_v7_unc 8 h8lt (by decide) (by decide), + h_v6_unc 8 h8lt (by decide) (by decide)] + rw [h_eq]; exact h_v5_i + -- Lane 9: step 6 i-lane. + · have h_eq : v8.elements.val[9]! = v6.elements.val[9]! := by + rw [h_v8_unc 9 h9lt (by decide) (by decide), + h_v7_unc 9 h9lt (by decide) (by decide)] + rw [h_eq]; exact h_v6_i + -- Lane 10: step 7 i-lane. + · have h_eq : v8.elements.val[10]! = v7.elements.val[10]! := by + rw [h_v8_unc 10 h10lt (by decide) (by decide)] + rw [h_eq]; exact h_v7_i + -- Lane 11: step 8 i-lane. + · exact h_v8_i + -- Lane 12: step 5 j-lane. + · have h_eq : v8.elements.val[12]! = v5.elements.val[12]! := by + rw [h_v8_unc 12 h12lt (by decide) (by decide), + h_v7_unc 12 h12lt (by decide) (by decide), + h_v6_unc 12 h12lt (by decide) (by decide)] + rw [h_eq]; exact h_v5_j + -- Lane 13: step 6 j-lane. + · have h_eq : v8.elements.val[13]! = v6.elements.val[13]! := by + rw [h_v8_unc 13 h13lt (by decide) (by decide), + h_v7_unc 13 h13lt (by decide) (by decide)] + rw [h_eq]; exact h_v6_j + -- Lane 14: step 7 j-lane. + · have h_eq : v8.elements.val[14]! = v7.elements.val[14]! := by + rw [h_v8_unc 14 h14lt (by decide) (by decide)] + rw [h_eq]; exact h_v7_j + -- Lane 15: step 8 j-lane. + · exact h_v8_j + +/-! ## L2.3.bnd — `ntt_layer_2_step_spec_bnd` + + Nat-bnd parameterised mirror of `ntt_layer_2_step_spec` (L2.3): same + eight disjoint butterflies on pairs `(0,4)ζ0`, `(1,5)ζ0`, `(2,6)ζ0`, + `(3,7)ζ0`, `(8,12)ζ1`, `(9,13)ζ1`, `(10,14)ζ1`, `(11,15)ζ1`, dispatched + via the `_bnd` form. Same `bnd ≤ 8 * 3328` precondition as + `ntt_layer_1_step_spec_bnd`. -/ + +@[spec] +theorem ntt_layer_2_step_spec_bnd + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta0 zeta1 : Std.I16) + (bnd : Nat) (h_bnd : bnd ≤ 29439) + (hz0 : zeta0.val.natAbs ≤ 1664) (hz1 : zeta1.val.natAbs ≤ 1664) + (hpre : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ bnd) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_2_step vec zeta0 zeta1 + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → (r.elements.val[i]!).val.natAbs ≤ bnd + 3328 ⌝ ⦄ := by + have h0lt : (0#usize : Std.Usize).val < 16 := by decide + have h1lt : (1#usize : Std.Usize).val < 16 := by decide + have h2lt : (2#usize : Std.Usize).val < 16 := by decide + have h3lt : (3#usize : Std.Usize).val < 16 := by decide + have h4lt : (4#usize : Std.Usize).val < 16 := by decide + have h5lt : (5#usize : Std.Usize).val < 16 := by decide + have h6lt : (6#usize : Std.Usize).val < 16 := by decide + have h7lt : (7#usize : Std.Usize).val < 16 := by decide + have h8lt : (8#usize : Std.Usize).val < 16 := by decide + have h9lt : (9#usize : Std.Usize).val < 16 := by decide + have h10lt : (10#usize : Std.Usize).val < 16 := by decide + have h11lt : (11#usize : Std.Usize).val < 16 := by decide + have h12lt : (12#usize : Std.Usize).val < 16 := by decide + have h13lt : (13#usize : Std.Usize).val < 16 := by decide + have h14lt : (14#usize : Std.Usize).val < 16 := by decide + have h15lt : (15#usize : Std.Usize).val < 16 := by decide + -- Bridge: bnd ≤ 8*3328 = 26624 ≤ 29439 (= ntt_step_spec_bnd's max input bound). + have h_bnd29439 : bnd ≤ 29439 := by omega + have hb_init : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ bnd := hpre + -- Step 1: (0, 4) ζ0. + obtain ⟨v1, h_v1_eq, h_v1_unc, h_v1_i, h_v1_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd vec zeta0 0#usize 4#usize bnd + h0lt h4lt (by decide) hz0 (hb_init 0 h0lt) (hb_init 4 h4lt) h_bnd29439) + -- Step 2: (1, 5) ζ0. + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ bnd := by + rw [h_v1_unc 1 h1lt (by decide) (by decide)]; exact hb_init 1 h1lt + have h_v1_5 : (v1.elements.val[5]!).val.natAbs ≤ bnd := by + rw [h_v1_unc 5 h5lt (by decide) (by decide)]; exact hb_init 5 h5lt + obtain ⟨v2, h_v2_eq, h_v2_unc, h_v2_i, h_v2_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v1 zeta0 1#usize 5#usize bnd + h1lt h5lt (by decide) hz0 h_v1_1 h_v1_5 h_bnd29439) + -- Step 3: (2, 6) ζ0. + have h_v2_2 : (v2.elements.val[2]!).val.natAbs ≤ bnd := by + rw [h_v2_unc 2 h2lt (by decide) (by decide), h_v1_unc 2 h2lt (by decide) (by decide)] + exact hb_init 2 h2lt + have h_v2_6 : (v2.elements.val[6]!).val.natAbs ≤ bnd := by + rw [h_v2_unc 6 h6lt (by decide) (by decide), h_v1_unc 6 h6lt (by decide) (by decide)] + exact hb_init 6 h6lt + obtain ⟨v3, h_v3_eq, h_v3_unc, h_v3_i, h_v3_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v2 zeta0 2#usize 6#usize bnd + h2lt h6lt (by decide) hz0 h_v2_2 h_v2_6 h_bnd29439) + -- Step 4: (3, 7) ζ0. + have h_v3_3 : (v3.elements.val[3]!).val.natAbs ≤ bnd := by + rw [h_v3_unc 3 h3lt (by decide) (by decide), + h_v2_unc 3 h3lt (by decide) (by decide), + h_v1_unc 3 h3lt (by decide) (by decide)] + exact hb_init 3 h3lt + have h_v3_7 : (v3.elements.val[7]!).val.natAbs ≤ bnd := by + rw [h_v3_unc 7 h7lt (by decide) (by decide), + h_v2_unc 7 h7lt (by decide) (by decide), + h_v1_unc 7 h7lt (by decide) (by decide)] + exact hb_init 7 h7lt + obtain ⟨v4, h_v4_eq, h_v4_unc, h_v4_i, h_v4_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v3 zeta0 3#usize 7#usize bnd + h3lt h7lt (by decide) hz0 h_v3_3 h_v3_7 h_bnd29439) + -- Step 5: (8, 12) ζ1. + have h_v4_8 : (v4.elements.val[8]!).val.natAbs ≤ bnd := by + rw [h_v4_unc 8 h8lt (by decide) (by decide), + h_v3_unc 8 h8lt (by decide) (by decide), + h_v2_unc 8 h8lt (by decide) (by decide), + h_v1_unc 8 h8lt (by decide) (by decide)] + exact hb_init 8 h8lt + have h_v4_12 : (v4.elements.val[12]!).val.natAbs ≤ bnd := by + rw [h_v4_unc 12 h12lt (by decide) (by decide), + h_v3_unc 12 h12lt (by decide) (by decide), + h_v2_unc 12 h12lt (by decide) (by decide), + h_v1_unc 12 h12lt (by decide) (by decide)] + exact hb_init 12 h12lt + obtain ⟨v5, h_v5_eq, h_v5_unc, h_v5_i, h_v5_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v4 zeta1 8#usize 12#usize bnd + h8lt h12lt (by decide) hz1 h_v4_8 h_v4_12 h_bnd29439) + -- Step 6: (9, 13) ζ1. + have h_v5_9 : (v5.elements.val[9]!).val.natAbs ≤ bnd := by + rw [h_v5_unc 9 h9lt (by decide) (by decide), + h_v4_unc 9 h9lt (by decide) (by decide), + h_v3_unc 9 h9lt (by decide) (by decide), + h_v2_unc 9 h9lt (by decide) (by decide), + h_v1_unc 9 h9lt (by decide) (by decide)] + exact hb_init 9 h9lt + have h_v5_13 : (v5.elements.val[13]!).val.natAbs ≤ bnd := by + rw [h_v5_unc 13 h13lt (by decide) (by decide), + h_v4_unc 13 h13lt (by decide) (by decide), + h_v3_unc 13 h13lt (by decide) (by decide), + h_v2_unc 13 h13lt (by decide) (by decide), + h_v1_unc 13 h13lt (by decide) (by decide)] + exact hb_init 13 h13lt + obtain ⟨v6, h_v6_eq, h_v6_unc, h_v6_i, h_v6_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v5 zeta1 9#usize 13#usize bnd + h9lt h13lt (by decide) hz1 h_v5_9 h_v5_13 h_bnd29439) + -- Step 7: (10, 14) ζ1. + have h_v6_10 : (v6.elements.val[10]!).val.natAbs ≤ bnd := by + rw [h_v6_unc 10 h10lt (by decide) (by decide), + h_v5_unc 10 h10lt (by decide) (by decide), + h_v4_unc 10 h10lt (by decide) (by decide), + h_v3_unc 10 h10lt (by decide) (by decide), + h_v2_unc 10 h10lt (by decide) (by decide), + h_v1_unc 10 h10lt (by decide) (by decide)] + exact hb_init 10 h10lt + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ bnd := by + rw [h_v6_unc 14 h14lt (by decide) (by decide), + h_v5_unc 14 h14lt (by decide) (by decide), + h_v4_unc 14 h14lt (by decide) (by decide), + h_v3_unc 14 h14lt (by decide) (by decide), + h_v2_unc 14 h14lt (by decide) (by decide), + h_v1_unc 14 h14lt (by decide) (by decide)] + exact hb_init 14 h14lt + obtain ⟨v7, h_v7_eq, h_v7_unc, h_v7_i, h_v7_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v6 zeta1 10#usize 14#usize bnd + h10lt h14lt (by decide) hz1 h_v6_10 h_v6_14 h_bnd29439) + -- Step 8: (11, 15) ζ1. + have h_v7_11 : (v7.elements.val[11]!).val.natAbs ≤ bnd := by + rw [h_v7_unc 11 h11lt (by decide) (by decide), + h_v6_unc 11 h11lt (by decide) (by decide), + h_v5_unc 11 h11lt (by decide) (by decide), + h_v4_unc 11 h11lt (by decide) (by decide), + h_v3_unc 11 h11lt (by decide) (by decide), + h_v2_unc 11 h11lt (by decide) (by decide), + h_v1_unc 11 h11lt (by decide) (by decide)] + exact hb_init 11 h11lt + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ bnd := by + rw [h_v7_unc 15 h15lt (by decide) (by decide), + h_v6_unc 15 h15lt (by decide) (by decide), + h_v5_unc 15 h15lt (by decide) (by decide), + h_v4_unc 15 h15lt (by decide) (by decide), + h_v3_unc 15 h15lt (by decide) (by decide), + h_v2_unc 15 h15lt (by decide) (by decide), + h_v1_unc 15 h15lt (by decide) (by decide)] + exact hb_init 15 h15lt + obtain ⟨v8, h_v8_eq, h_v8_unc, h_v8_i, h_v8_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v7 zeta1 11#usize 15#usize bnd + h11lt h15lt (by decide) hz1 h_v7_11 h_v7_15 h_bnd29439) + -- Compose into one `.ok v8` equation. + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_2_step vec zeta0 zeta1 + = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_2_step + rw [h_v1_eq]; simp only [bind_tc_ok] + rw [h_v2_eq]; simp only [bind_tc_ok] + rw [h_v3_eq]; simp only [bind_tc_ok] + rw [h_v4_eq]; simp only [bind_tc_ok] + rw [h_v5_eq]; simp only [bind_tc_ok] + rw [h_v6_eq]; simp only [bind_tc_ok] + rw [h_v7_eq]; simp only [bind_tc_ok] + exact h_v8_eq + -- Close: per-lane case split. + apply triple_of_ok_l2 h_body + intro i hi + interval_cases i + -- Lane 0: step 1 i-lane. + · have h_eq : v8.elements.val[0]! = v1.elements.val[0]! := by + rw [h_v8_unc 0 h0lt (by decide) (by decide), + h_v7_unc 0 h0lt (by decide) (by decide), + h_v6_unc 0 h0lt (by decide) (by decide), + h_v5_unc 0 h0lt (by decide) (by decide), + h_v4_unc 0 h0lt (by decide) (by decide), + h_v3_unc 0 h0lt (by decide) (by decide), + h_v2_unc 0 h0lt (by decide) (by decide)] + rw [h_eq]; exact h_v1_i + -- Lane 1: step 2 i-lane. + · have h_eq : v8.elements.val[1]! = v2.elements.val[1]! := by + rw [h_v8_unc 1 h1lt (by decide) (by decide), + h_v7_unc 1 h1lt (by decide) (by decide), + h_v6_unc 1 h1lt (by decide) (by decide), + h_v5_unc 1 h1lt (by decide) (by decide), + h_v4_unc 1 h1lt (by decide) (by decide), + h_v3_unc 1 h1lt (by decide) (by decide)] + rw [h_eq]; exact h_v2_i + -- Lane 2: step 3 i-lane. + · have h_eq : v8.elements.val[2]! = v3.elements.val[2]! := by + rw [h_v8_unc 2 h2lt (by decide) (by decide), + h_v7_unc 2 h2lt (by decide) (by decide), + h_v6_unc 2 h2lt (by decide) (by decide), + h_v5_unc 2 h2lt (by decide) (by decide), + h_v4_unc 2 h2lt (by decide) (by decide)] + rw [h_eq]; exact h_v3_i + -- Lane 3: step 4 i-lane. + · have h_eq : v8.elements.val[3]! = v4.elements.val[3]! := by + rw [h_v8_unc 3 h3lt (by decide) (by decide), + h_v7_unc 3 h3lt (by decide) (by decide), + h_v6_unc 3 h3lt (by decide) (by decide), + h_v5_unc 3 h3lt (by decide) (by decide)] + rw [h_eq]; exact h_v4_i + -- Lane 4: step 1 j-lane. + · have h_eq : v8.elements.val[4]! = v1.elements.val[4]! := by + rw [h_v8_unc 4 h4lt (by decide) (by decide), + h_v7_unc 4 h4lt (by decide) (by decide), + h_v6_unc 4 h4lt (by decide) (by decide), + h_v5_unc 4 h4lt (by decide) (by decide), + h_v4_unc 4 h4lt (by decide) (by decide), + h_v3_unc 4 h4lt (by decide) (by decide), + h_v2_unc 4 h4lt (by decide) (by decide)] + rw [h_eq]; exact h_v1_j + -- Lane 5: step 2 j-lane. + · have h_eq : v8.elements.val[5]! = v2.elements.val[5]! := by + rw [h_v8_unc 5 h5lt (by decide) (by decide), + h_v7_unc 5 h5lt (by decide) (by decide), + h_v6_unc 5 h5lt (by decide) (by decide), + h_v5_unc 5 h5lt (by decide) (by decide), + h_v4_unc 5 h5lt (by decide) (by decide), + h_v3_unc 5 h5lt (by decide) (by decide)] + rw [h_eq]; exact h_v2_j + -- Lane 6: step 3 j-lane. + · have h_eq : v8.elements.val[6]! = v3.elements.val[6]! := by + rw [h_v8_unc 6 h6lt (by decide) (by decide), + h_v7_unc 6 h6lt (by decide) (by decide), + h_v6_unc 6 h6lt (by decide) (by decide), + h_v5_unc 6 h6lt (by decide) (by decide), + h_v4_unc 6 h6lt (by decide) (by decide)] + rw [h_eq]; exact h_v3_j + -- Lane 7: step 4 j-lane. + · have h_eq : v8.elements.val[7]! = v4.elements.val[7]! := by + rw [h_v8_unc 7 h7lt (by decide) (by decide), + h_v7_unc 7 h7lt (by decide) (by decide), + h_v6_unc 7 h7lt (by decide) (by decide), + h_v5_unc 7 h7lt (by decide) (by decide)] + rw [h_eq]; exact h_v4_j + -- Lane 8: step 5 i-lane. + · have h_eq : v8.elements.val[8]! = v5.elements.val[8]! := by + rw [h_v8_unc 8 h8lt (by decide) (by decide), + h_v7_unc 8 h8lt (by decide) (by decide), + h_v6_unc 8 h8lt (by decide) (by decide)] + rw [h_eq]; exact h_v5_i + -- Lane 9: step 6 i-lane. + · have h_eq : v8.elements.val[9]! = v6.elements.val[9]! := by + rw [h_v8_unc 9 h9lt (by decide) (by decide), + h_v7_unc 9 h9lt (by decide) (by decide)] + rw [h_eq]; exact h_v6_i + -- Lane 10: step 7 i-lane. + · have h_eq : v8.elements.val[10]! = v7.elements.val[10]! := by + rw [h_v8_unc 10 h10lt (by decide) (by decide)] + rw [h_eq]; exact h_v7_i + -- Lane 11: step 8 i-lane. + · exact h_v8_i + -- Lane 12: step 5 j-lane. + · have h_eq : v8.elements.val[12]! = v5.elements.val[12]! := by + rw [h_v8_unc 12 h12lt (by decide) (by decide), + h_v7_unc 12 h12lt (by decide) (by decide), + h_v6_unc 12 h12lt (by decide) (by decide)] + rw [h_eq]; exact h_v5_j + -- Lane 13: step 6 j-lane. + · have h_eq : v8.elements.val[13]! = v6.elements.val[13]! := by + rw [h_v8_unc 13 h13lt (by decide) (by decide), + h_v7_unc 13 h13lt (by decide) (by decide)] + rw [h_eq]; exact h_v6_j + -- Lane 14: step 7 j-lane. + · have h_eq : v8.elements.val[14]! = v7.elements.val[14]! := by + rw [h_v8_unc 14 h14lt (by decide) (by decide)] + rw [h_eq]; exact h_v7_j + -- Lane 15: step 8 j-lane. + · exact h_v8_j + +/-! ## L2.4 — `ntt_layer_3_step_spec` + + Eight disjoint butterflies on pairs `(0,8)`, `(1,9)`, `(2,10)`, `(3,11)`, + `(4,12)`, `(5,13)`, `(6,14)`, `(7,15)` — all with the single ζ. Same + bookkeeping pattern as L2.2/L2.3 but with `B = 5`. +-/ + +@[spec] +theorem ntt_layer_3_step_spec + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta : Std.I16) (hz : zeta.val.natAbs ≤ 1664) + (hpre : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 5 * 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_3_step vec zeta + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → (r.elements.val[i]!).val.natAbs ≤ 6 * 3328 ⌝ ⦄ := by + have h0lt : (0#usize : Std.Usize).val < 16 := by decide + have h1lt : (1#usize : Std.Usize).val < 16 := by decide + have h2lt : (2#usize : Std.Usize).val < 16 := by decide + have h3lt : (3#usize : Std.Usize).val < 16 := by decide + have h4lt : (4#usize : Std.Usize).val < 16 := by decide + have h5lt : (5#usize : Std.Usize).val < 16 := by decide + have h6lt : (6#usize : Std.Usize).val < 16 := by decide + have h7lt : (7#usize : Std.Usize).val < 16 := by decide + have h8lt : (8#usize : Std.Usize).val < 16 := by decide + have h9lt : (9#usize : Std.Usize).val < 16 := by decide + have h10lt : (10#usize : Std.Usize).val < 16 := by decide + have h11lt : (11#usize : Std.Usize).val < 16 := by decide + have h12lt : (12#usize : Std.Usize).val < 16 := by decide + have h13lt : (13#usize : Std.Usize).val < 16 := by decide + have h14lt : (14#usize : Std.Usize).val < 16 := by decide + have h15lt : (15#usize : Std.Usize).val < 16 := by decide + have hb_init : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 5 * 3328 := hpre + -- Step 1: (0, 8). + obtain ⟨v1, h_v1_eq, h_v1_unc, h_v1_i, h_v1_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B vec zeta 0#usize 8#usize 5 + h0lt h8lt (by decide) hz (hb_init 0 h0lt) (hb_init 8 h8lt) (by decide)) + -- Step 2: (1, 9). + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ 5 * 3328 := by + rw [h_v1_unc 1 h1lt (by decide) (by decide)]; exact hb_init 1 h1lt + have h_v1_9 : (v1.elements.val[9]!).val.natAbs ≤ 5 * 3328 := by + rw [h_v1_unc 9 h9lt (by decide) (by decide)]; exact hb_init 9 h9lt + obtain ⟨v2, h_v2_eq, h_v2_unc, h_v2_i, h_v2_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v1 zeta 1#usize 9#usize 5 + h1lt h9lt (by decide) hz h_v1_1 h_v1_9 (by decide)) + -- Step 3: (2, 10). + have h_v2_2 : (v2.elements.val[2]!).val.natAbs ≤ 5 * 3328 := by + rw [h_v2_unc 2 h2lt (by decide) (by decide), h_v1_unc 2 h2lt (by decide) (by decide)] + exact hb_init 2 h2lt + have h_v2_10 : (v2.elements.val[10]!).val.natAbs ≤ 5 * 3328 := by + rw [h_v2_unc 10 h10lt (by decide) (by decide), h_v1_unc 10 h10lt (by decide) (by decide)] + exact hb_init 10 h10lt + obtain ⟨v3, h_v3_eq, h_v3_unc, h_v3_i, h_v3_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v2 zeta 2#usize 10#usize 5 + h2lt h10lt (by decide) hz h_v2_2 h_v2_10 (by decide)) + -- Step 4: (3, 11). + have h_v3_3 : (v3.elements.val[3]!).val.natAbs ≤ 5 * 3328 := by + rw [h_v3_unc 3 h3lt (by decide) (by decide), + h_v2_unc 3 h3lt (by decide) (by decide), + h_v1_unc 3 h3lt (by decide) (by decide)] + exact hb_init 3 h3lt + have h_v3_11 : (v3.elements.val[11]!).val.natAbs ≤ 5 * 3328 := by + rw [h_v3_unc 11 h11lt (by decide) (by decide), + h_v2_unc 11 h11lt (by decide) (by decide), + h_v1_unc 11 h11lt (by decide) (by decide)] + exact hb_init 11 h11lt + obtain ⟨v4, h_v4_eq, h_v4_unc, h_v4_i, h_v4_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v3 zeta 3#usize 11#usize 5 + h3lt h11lt (by decide) hz h_v3_3 h_v3_11 (by decide)) + -- Step 5: (4, 12). + have h_v4_4 : (v4.elements.val[4]!).val.natAbs ≤ 5 * 3328 := by + rw [h_v4_unc 4 h4lt (by decide) (by decide), + h_v3_unc 4 h4lt (by decide) (by decide), + h_v2_unc 4 h4lt (by decide) (by decide), + h_v1_unc 4 h4lt (by decide) (by decide)] + exact hb_init 4 h4lt + have h_v4_12 : (v4.elements.val[12]!).val.natAbs ≤ 5 * 3328 := by + rw [h_v4_unc 12 h12lt (by decide) (by decide), + h_v3_unc 12 h12lt (by decide) (by decide), + h_v2_unc 12 h12lt (by decide) (by decide), + h_v1_unc 12 h12lt (by decide) (by decide)] + exact hb_init 12 h12lt + obtain ⟨v5, h_v5_eq, h_v5_unc, h_v5_i, h_v5_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v4 zeta 4#usize 12#usize 5 + h4lt h12lt (by decide) hz h_v4_4 h_v4_12 (by decide)) + -- Step 6: (5, 13). + have h_v5_5 : (v5.elements.val[5]!).val.natAbs ≤ 5 * 3328 := by + rw [h_v5_unc 5 h5lt (by decide) (by decide), + h_v4_unc 5 h5lt (by decide) (by decide), + h_v3_unc 5 h5lt (by decide) (by decide), + h_v2_unc 5 h5lt (by decide) (by decide), + h_v1_unc 5 h5lt (by decide) (by decide)] + exact hb_init 5 h5lt + have h_v5_13 : (v5.elements.val[13]!).val.natAbs ≤ 5 * 3328 := by + rw [h_v5_unc 13 h13lt (by decide) (by decide), + h_v4_unc 13 h13lt (by decide) (by decide), + h_v3_unc 13 h13lt (by decide) (by decide), + h_v2_unc 13 h13lt (by decide) (by decide), + h_v1_unc 13 h13lt (by decide) (by decide)] + exact hb_init 13 h13lt + obtain ⟨v6, h_v6_eq, h_v6_unc, h_v6_i, h_v6_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v5 zeta 5#usize 13#usize 5 + h5lt h13lt (by decide) hz h_v5_5 h_v5_13 (by decide)) + -- Step 7: (6, 14). + have h_v6_6 : (v6.elements.val[6]!).val.natAbs ≤ 5 * 3328 := by + rw [h_v6_unc 6 h6lt (by decide) (by decide), + h_v5_unc 6 h6lt (by decide) (by decide), + h_v4_unc 6 h6lt (by decide) (by decide), + h_v3_unc 6 h6lt (by decide) (by decide), + h_v2_unc 6 h6lt (by decide) (by decide), + h_v1_unc 6 h6lt (by decide) (by decide)] + exact hb_init 6 h6lt + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ 5 * 3328 := by + rw [h_v6_unc 14 h14lt (by decide) (by decide), + h_v5_unc 14 h14lt (by decide) (by decide), + h_v4_unc 14 h14lt (by decide) (by decide), + h_v3_unc 14 h14lt (by decide) (by decide), + h_v2_unc 14 h14lt (by decide) (by decide), + h_v1_unc 14 h14lt (by decide) (by decide)] + exact hb_init 14 h14lt + obtain ⟨v7, h_v7_eq, h_v7_unc, h_v7_i, h_v7_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v6 zeta 6#usize 14#usize 5 + h6lt h14lt (by decide) hz h_v6_6 h_v6_14 (by decide)) + -- Step 8: (7, 15). + have h_v7_7 : (v7.elements.val[7]!).val.natAbs ≤ 5 * 3328 := by + rw [h_v7_unc 7 h7lt (by decide) (by decide), + h_v6_unc 7 h7lt (by decide) (by decide), + h_v5_unc 7 h7lt (by decide) (by decide), + h_v4_unc 7 h7lt (by decide) (by decide), + h_v3_unc 7 h7lt (by decide) (by decide), + h_v2_unc 7 h7lt (by decide) (by decide), + h_v1_unc 7 h7lt (by decide) (by decide)] + exact hb_init 7 h7lt + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ 5 * 3328 := by + rw [h_v7_unc 15 h15lt (by decide) (by decide), + h_v6_unc 15 h15lt (by decide) (by decide), + h_v5_unc 15 h15lt (by decide) (by decide), + h_v4_unc 15 h15lt (by decide) (by decide), + h_v3_unc 15 h15lt (by decide) (by decide), + h_v2_unc 15 h15lt (by decide) (by decide), + h_v1_unc 15 h15lt (by decide) (by decide)] + exact hb_init 15 h15lt + obtain ⟨v8, h_v8_eq, h_v8_unc, h_v8_i, h_v8_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_B v7 zeta 7#usize 15#usize 5 + h7lt h15lt (by decide) hz h_v7_7 h_v7_15 (by decide)) + -- Compose into one `.ok v8` equation. + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_3_step vec zeta = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_3_step + rw [h_v1_eq]; simp only [bind_tc_ok] + rw [h_v2_eq]; simp only [bind_tc_ok] + rw [h_v3_eq]; simp only [bind_tc_ok] + rw [h_v4_eq]; simp only [bind_tc_ok] + rw [h_v5_eq]; simp only [bind_tc_ok] + rw [h_v6_eq]; simp only [bind_tc_ok] + rw [h_v7_eq]; simp only [bind_tc_ok] + exact h_v8_eq + -- Close: per-lane case split. + apply triple_of_ok_l2 h_body + intro i hi + interval_cases i + -- Lane 0: step 1 i-lane. + · have h_eq : v8.elements.val[0]! = v1.elements.val[0]! := by + rw [h_v8_unc 0 h0lt (by decide) (by decide), + h_v7_unc 0 h0lt (by decide) (by decide), + h_v6_unc 0 h0lt (by decide) (by decide), + h_v5_unc 0 h0lt (by decide) (by decide), + h_v4_unc 0 h0lt (by decide) (by decide), + h_v3_unc 0 h0lt (by decide) (by decide), + h_v2_unc 0 h0lt (by decide) (by decide)] + rw [h_eq]; exact h_v1_i + -- Lane 1: step 2 i-lane. + · have h_eq : v8.elements.val[1]! = v2.elements.val[1]! := by + rw [h_v8_unc 1 h1lt (by decide) (by decide), + h_v7_unc 1 h1lt (by decide) (by decide), + h_v6_unc 1 h1lt (by decide) (by decide), + h_v5_unc 1 h1lt (by decide) (by decide), + h_v4_unc 1 h1lt (by decide) (by decide), + h_v3_unc 1 h1lt (by decide) (by decide)] + rw [h_eq]; exact h_v2_i + -- Lane 2: step 3 i-lane. + · have h_eq : v8.elements.val[2]! = v3.elements.val[2]! := by + rw [h_v8_unc 2 h2lt (by decide) (by decide), + h_v7_unc 2 h2lt (by decide) (by decide), + h_v6_unc 2 h2lt (by decide) (by decide), + h_v5_unc 2 h2lt (by decide) (by decide), + h_v4_unc 2 h2lt (by decide) (by decide)] + rw [h_eq]; exact h_v3_i + -- Lane 3: step 4 i-lane. + · have h_eq : v8.elements.val[3]! = v4.elements.val[3]! := by + rw [h_v8_unc 3 h3lt (by decide) (by decide), + h_v7_unc 3 h3lt (by decide) (by decide), + h_v6_unc 3 h3lt (by decide) (by decide), + h_v5_unc 3 h3lt (by decide) (by decide)] + rw [h_eq]; exact h_v4_i + -- Lane 4: step 5 i-lane. + · have h_eq : v8.elements.val[4]! = v5.elements.val[4]! := by + rw [h_v8_unc 4 h4lt (by decide) (by decide), + h_v7_unc 4 h4lt (by decide) (by decide), + h_v6_unc 4 h4lt (by decide) (by decide)] + rw [h_eq]; exact h_v5_i + -- Lane 5: step 6 i-lane. + · have h_eq : v8.elements.val[5]! = v6.elements.val[5]! := by + rw [h_v8_unc 5 h5lt (by decide) (by decide), + h_v7_unc 5 h5lt (by decide) (by decide)] + rw [h_eq]; exact h_v6_i + -- Lane 6: step 7 i-lane. + · have h_eq : v8.elements.val[6]! = v7.elements.val[6]! := by + rw [h_v8_unc 6 h6lt (by decide) (by decide)] + rw [h_eq]; exact h_v7_i + -- Lane 7: step 8 i-lane. + · exact h_v8_i + -- Lane 8: step 1 j-lane. + · have h_eq : v8.elements.val[8]! = v1.elements.val[8]! := by + rw [h_v8_unc 8 h8lt (by decide) (by decide), + h_v7_unc 8 h8lt (by decide) (by decide), + h_v6_unc 8 h8lt (by decide) (by decide), + h_v5_unc 8 h8lt (by decide) (by decide), + h_v4_unc 8 h8lt (by decide) (by decide), + h_v3_unc 8 h8lt (by decide) (by decide), + h_v2_unc 8 h8lt (by decide) (by decide)] + rw [h_eq]; exact h_v1_j + -- Lane 9: step 2 j-lane. + · have h_eq : v8.elements.val[9]! = v2.elements.val[9]! := by + rw [h_v8_unc 9 h9lt (by decide) (by decide), + h_v7_unc 9 h9lt (by decide) (by decide), + h_v6_unc 9 h9lt (by decide) (by decide), + h_v5_unc 9 h9lt (by decide) (by decide), + h_v4_unc 9 h9lt (by decide) (by decide), + h_v3_unc 9 h9lt (by decide) (by decide)] + rw [h_eq]; exact h_v2_j + -- Lane 10: step 3 j-lane. + · have h_eq : v8.elements.val[10]! = v3.elements.val[10]! := by + rw [h_v8_unc 10 h10lt (by decide) (by decide), + h_v7_unc 10 h10lt (by decide) (by decide), + h_v6_unc 10 h10lt (by decide) (by decide), + h_v5_unc 10 h10lt (by decide) (by decide), + h_v4_unc 10 h10lt (by decide) (by decide)] + rw [h_eq]; exact h_v3_j + -- Lane 11: step 4 j-lane. + · have h_eq : v8.elements.val[11]! = v4.elements.val[11]! := by + rw [h_v8_unc 11 h11lt (by decide) (by decide), + h_v7_unc 11 h11lt (by decide) (by decide), + h_v6_unc 11 h11lt (by decide) (by decide), + h_v5_unc 11 h11lt (by decide) (by decide)] + rw [h_eq]; exact h_v4_j + -- Lane 12: step 5 j-lane. + · have h_eq : v8.elements.val[12]! = v5.elements.val[12]! := by + rw [h_v8_unc 12 h12lt (by decide) (by decide), + h_v7_unc 12 h12lt (by decide) (by decide), + h_v6_unc 12 h12lt (by decide) (by decide)] + rw [h_eq]; exact h_v5_j + -- Lane 13: step 6 j-lane. + · have h_eq : v8.elements.val[13]! = v6.elements.val[13]! := by + rw [h_v8_unc 13 h13lt (by decide) (by decide), + h_v7_unc 13 h13lt (by decide) (by decide)] + rw [h_eq]; exact h_v6_j + -- Lane 14: step 7 j-lane. + · have h_eq : v8.elements.val[14]! = v7.elements.val[14]! := by + rw [h_v8_unc 14 h14lt (by decide) (by decide)] + rw [h_eq]; exact h_v7_j + -- Lane 15: step 8 j-lane. + · exact h_v8_j + +/-! ## L2.4.bnd — `ntt_layer_3_step_spec_bnd` + + Nat-bnd parameterised mirror of `ntt_layer_3_step_spec` (L2.4): same + eight disjoint butterflies on pairs `(0,8)`, `(1,9)`, `(2,10)`, `(3,11)`, + `(4,12)`, `(5,13)`, `(6,14)`, `(7,15)` — single ζ. Same + `bnd ≤ 8 * 3328` precondition as `ntt_layer_{1,2}_step_spec_bnd`. -/ + +@[spec] +theorem ntt_layer_3_step_spec_bnd + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta : Std.I16) + (bnd : Nat) (h_bnd : bnd ≤ 29439) + (hz : zeta.val.natAbs ≤ 1664) + (hpre : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ bnd) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_3_step vec zeta + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → (r.elements.val[i]!).val.natAbs ≤ bnd + 3328 ⌝ ⦄ := by + have h0lt : (0#usize : Std.Usize).val < 16 := by decide + have h1lt : (1#usize : Std.Usize).val < 16 := by decide + have h2lt : (2#usize : Std.Usize).val < 16 := by decide + have h3lt : (3#usize : Std.Usize).val < 16 := by decide + have h4lt : (4#usize : Std.Usize).val < 16 := by decide + have h5lt : (5#usize : Std.Usize).val < 16 := by decide + have h6lt : (6#usize : Std.Usize).val < 16 := by decide + have h7lt : (7#usize : Std.Usize).val < 16 := by decide + have h8lt : (8#usize : Std.Usize).val < 16 := by decide + have h9lt : (9#usize : Std.Usize).val < 16 := by decide + have h10lt : (10#usize : Std.Usize).val < 16 := by decide + have h11lt : (11#usize : Std.Usize).val < 16 := by decide + have h12lt : (12#usize : Std.Usize).val < 16 := by decide + have h13lt : (13#usize : Std.Usize).val < 16 := by decide + have h14lt : (14#usize : Std.Usize).val < 16 := by decide + have h15lt : (15#usize : Std.Usize).val < 16 := by decide + -- Bridge: bnd ≤ 8*3328 = 26624 ≤ 29439 (= ntt_step_spec_bnd's max input bound). + have h_bnd29439 : bnd ≤ 29439 := by omega + have hb_init : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ bnd := hpre + -- Step 1: (0, 8). + obtain ⟨v1, h_v1_eq, h_v1_unc, h_v1_i, h_v1_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd vec zeta 0#usize 8#usize bnd + h0lt h8lt (by decide) hz (hb_init 0 h0lt) (hb_init 8 h8lt) h_bnd29439) + -- Step 2: (1, 9). + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ bnd := by + rw [h_v1_unc 1 h1lt (by decide) (by decide)]; exact hb_init 1 h1lt + have h_v1_9 : (v1.elements.val[9]!).val.natAbs ≤ bnd := by + rw [h_v1_unc 9 h9lt (by decide) (by decide)]; exact hb_init 9 h9lt + obtain ⟨v2, h_v2_eq, h_v2_unc, h_v2_i, h_v2_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v1 zeta 1#usize 9#usize bnd + h1lt h9lt (by decide) hz h_v1_1 h_v1_9 h_bnd29439) + -- Step 3: (2, 10). + have h_v2_2 : (v2.elements.val[2]!).val.natAbs ≤ bnd := by + rw [h_v2_unc 2 h2lt (by decide) (by decide), h_v1_unc 2 h2lt (by decide) (by decide)] + exact hb_init 2 h2lt + have h_v2_10 : (v2.elements.val[10]!).val.natAbs ≤ bnd := by + rw [h_v2_unc 10 h10lt (by decide) (by decide), h_v1_unc 10 h10lt (by decide) (by decide)] + exact hb_init 10 h10lt + obtain ⟨v3, h_v3_eq, h_v3_unc, h_v3_i, h_v3_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v2 zeta 2#usize 10#usize bnd + h2lt h10lt (by decide) hz h_v2_2 h_v2_10 h_bnd29439) + -- Step 4: (3, 11). + have h_v3_3 : (v3.elements.val[3]!).val.natAbs ≤ bnd := by + rw [h_v3_unc 3 h3lt (by decide) (by decide), + h_v2_unc 3 h3lt (by decide) (by decide), + h_v1_unc 3 h3lt (by decide) (by decide)] + exact hb_init 3 h3lt + have h_v3_11 : (v3.elements.val[11]!).val.natAbs ≤ bnd := by + rw [h_v3_unc 11 h11lt (by decide) (by decide), + h_v2_unc 11 h11lt (by decide) (by decide), + h_v1_unc 11 h11lt (by decide) (by decide)] + exact hb_init 11 h11lt + obtain ⟨v4, h_v4_eq, h_v4_unc, h_v4_i, h_v4_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v3 zeta 3#usize 11#usize bnd + h3lt h11lt (by decide) hz h_v3_3 h_v3_11 h_bnd29439) + -- Step 5: (4, 12). + have h_v4_4 : (v4.elements.val[4]!).val.natAbs ≤ bnd := by + rw [h_v4_unc 4 h4lt (by decide) (by decide), + h_v3_unc 4 h4lt (by decide) (by decide), + h_v2_unc 4 h4lt (by decide) (by decide), + h_v1_unc 4 h4lt (by decide) (by decide)] + exact hb_init 4 h4lt + have h_v4_12 : (v4.elements.val[12]!).val.natAbs ≤ bnd := by + rw [h_v4_unc 12 h12lt (by decide) (by decide), + h_v3_unc 12 h12lt (by decide) (by decide), + h_v2_unc 12 h12lt (by decide) (by decide), + h_v1_unc 12 h12lt (by decide) (by decide)] + exact hb_init 12 h12lt + obtain ⟨v5, h_v5_eq, h_v5_unc, h_v5_i, h_v5_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v4 zeta 4#usize 12#usize bnd + h4lt h12lt (by decide) hz h_v4_4 h_v4_12 h_bnd29439) + -- Step 6: (5, 13). + have h_v5_5 : (v5.elements.val[5]!).val.natAbs ≤ bnd := by + rw [h_v5_unc 5 h5lt (by decide) (by decide), + h_v4_unc 5 h5lt (by decide) (by decide), + h_v3_unc 5 h5lt (by decide) (by decide), + h_v2_unc 5 h5lt (by decide) (by decide), + h_v1_unc 5 h5lt (by decide) (by decide)] + exact hb_init 5 h5lt + have h_v5_13 : (v5.elements.val[13]!).val.natAbs ≤ bnd := by + rw [h_v5_unc 13 h13lt (by decide) (by decide), + h_v4_unc 13 h13lt (by decide) (by decide), + h_v3_unc 13 h13lt (by decide) (by decide), + h_v2_unc 13 h13lt (by decide) (by decide), + h_v1_unc 13 h13lt (by decide) (by decide)] + exact hb_init 13 h13lt + obtain ⟨v6, h_v6_eq, h_v6_unc, h_v6_i, h_v6_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v5 zeta 5#usize 13#usize bnd + h5lt h13lt (by decide) hz h_v5_5 h_v5_13 h_bnd29439) + -- Step 7: (6, 14). + have h_v6_6 : (v6.elements.val[6]!).val.natAbs ≤ bnd := by + rw [h_v6_unc 6 h6lt (by decide) (by decide), + h_v5_unc 6 h6lt (by decide) (by decide), + h_v4_unc 6 h6lt (by decide) (by decide), + h_v3_unc 6 h6lt (by decide) (by decide), + h_v2_unc 6 h6lt (by decide) (by decide), + h_v1_unc 6 h6lt (by decide) (by decide)] + exact hb_init 6 h6lt + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ bnd := by + rw [h_v6_unc 14 h14lt (by decide) (by decide), + h_v5_unc 14 h14lt (by decide) (by decide), + h_v4_unc 14 h14lt (by decide) (by decide), + h_v3_unc 14 h14lt (by decide) (by decide), + h_v2_unc 14 h14lt (by decide) (by decide), + h_v1_unc 14 h14lt (by decide) (by decide)] + exact hb_init 14 h14lt + obtain ⟨v7, h_v7_eq, h_v7_unc, h_v7_i, h_v7_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v6 zeta 6#usize 14#usize bnd + h6lt h14lt (by decide) hz h_v6_6 h_v6_14 h_bnd29439) + -- Step 8: (7, 15). + have h_v7_7 : (v7.elements.val[7]!).val.natAbs ≤ bnd := by + rw [h_v7_unc 7 h7lt (by decide) (by decide), + h_v6_unc 7 h7lt (by decide) (by decide), + h_v5_unc 7 h7lt (by decide) (by decide), + h_v4_unc 7 h7lt (by decide) (by decide), + h_v3_unc 7 h7lt (by decide) (by decide), + h_v2_unc 7 h7lt (by decide) (by decide), + h_v1_unc 7 h7lt (by decide) (by decide)] + exact hb_init 7 h7lt + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ bnd := by + rw [h_v7_unc 15 h15lt (by decide) (by decide), + h_v6_unc 15 h15lt (by decide) (by decide), + h_v5_unc 15 h15lt (by decide) (by decide), + h_v4_unc 15 h15lt (by decide) (by decide), + h_v3_unc 15 h15lt (by decide) (by decide), + h_v2_unc 15 h15lt (by decide) (by decide), + h_v1_unc 15 h15lt (by decide) (by decide)] + exact hb_init 15 h15lt + obtain ⟨v8, h_v8_eq, h_v8_unc, h_v8_i, h_v8_j⟩ := + triple_exists_ok_l2 (ntt_step_spec_bnd v7 zeta 7#usize 15#usize bnd + h7lt h15lt (by decide) hz h_v7_7 h_v7_15 h_bnd29439) + -- Compose into one `.ok v8` equation. + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_3_step vec zeta = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_3_step + rw [h_v1_eq]; simp only [bind_tc_ok] + rw [h_v2_eq]; simp only [bind_tc_ok] + rw [h_v3_eq]; simp only [bind_tc_ok] + rw [h_v4_eq]; simp only [bind_tc_ok] + rw [h_v5_eq]; simp only [bind_tc_ok] + rw [h_v6_eq]; simp only [bind_tc_ok] + rw [h_v7_eq]; simp only [bind_tc_ok] + exact h_v8_eq + -- Close: per-lane case split. + apply triple_of_ok_l2 h_body + intro i hi + interval_cases i + -- Lane 0: step 1 i-lane. + · have h_eq : v8.elements.val[0]! = v1.elements.val[0]! := by + rw [h_v8_unc 0 h0lt (by decide) (by decide), + h_v7_unc 0 h0lt (by decide) (by decide), + h_v6_unc 0 h0lt (by decide) (by decide), + h_v5_unc 0 h0lt (by decide) (by decide), + h_v4_unc 0 h0lt (by decide) (by decide), + h_v3_unc 0 h0lt (by decide) (by decide), + h_v2_unc 0 h0lt (by decide) (by decide)] + rw [h_eq]; exact h_v1_i + -- Lane 1: step 2 i-lane. + · have h_eq : v8.elements.val[1]! = v2.elements.val[1]! := by + rw [h_v8_unc 1 h1lt (by decide) (by decide), + h_v7_unc 1 h1lt (by decide) (by decide), + h_v6_unc 1 h1lt (by decide) (by decide), + h_v5_unc 1 h1lt (by decide) (by decide), + h_v4_unc 1 h1lt (by decide) (by decide), + h_v3_unc 1 h1lt (by decide) (by decide)] + rw [h_eq]; exact h_v2_i + -- Lane 2: step 3 i-lane. + · have h_eq : v8.elements.val[2]! = v3.elements.val[2]! := by + rw [h_v8_unc 2 h2lt (by decide) (by decide), + h_v7_unc 2 h2lt (by decide) (by decide), + h_v6_unc 2 h2lt (by decide) (by decide), + h_v5_unc 2 h2lt (by decide) (by decide), + h_v4_unc 2 h2lt (by decide) (by decide)] + rw [h_eq]; exact h_v3_i + -- Lane 3: step 4 i-lane. + · have h_eq : v8.elements.val[3]! = v4.elements.val[3]! := by + rw [h_v8_unc 3 h3lt (by decide) (by decide), + h_v7_unc 3 h3lt (by decide) (by decide), + h_v6_unc 3 h3lt (by decide) (by decide), + h_v5_unc 3 h3lt (by decide) (by decide)] + rw [h_eq]; exact h_v4_i + -- Lane 4: step 5 i-lane. + · have h_eq : v8.elements.val[4]! = v5.elements.val[4]! := by + rw [h_v8_unc 4 h4lt (by decide) (by decide), + h_v7_unc 4 h4lt (by decide) (by decide), + h_v6_unc 4 h4lt (by decide) (by decide)] + rw [h_eq]; exact h_v5_i + -- Lane 5: step 6 i-lane. + · have h_eq : v8.elements.val[5]! = v6.elements.val[5]! := by + rw [h_v8_unc 5 h5lt (by decide) (by decide), + h_v7_unc 5 h5lt (by decide) (by decide)] + rw [h_eq]; exact h_v6_i + -- Lane 6: step 7 i-lane. + · have h_eq : v8.elements.val[6]! = v7.elements.val[6]! := by + rw [h_v8_unc 6 h6lt (by decide) (by decide)] + rw [h_eq]; exact h_v7_i + -- Lane 7: step 8 i-lane. + · exact h_v8_i + -- Lane 8: step 1 j-lane. + · have h_eq : v8.elements.val[8]! = v1.elements.val[8]! := by + rw [h_v8_unc 8 h8lt (by decide) (by decide), + h_v7_unc 8 h8lt (by decide) (by decide), + h_v6_unc 8 h8lt (by decide) (by decide), + h_v5_unc 8 h8lt (by decide) (by decide), + h_v4_unc 8 h8lt (by decide) (by decide), + h_v3_unc 8 h8lt (by decide) (by decide), + h_v2_unc 8 h8lt (by decide) (by decide)] + rw [h_eq]; exact h_v1_j + -- Lane 9: step 2 j-lane. + · have h_eq : v8.elements.val[9]! = v2.elements.val[9]! := by + rw [h_v8_unc 9 h9lt (by decide) (by decide), + h_v7_unc 9 h9lt (by decide) (by decide), + h_v6_unc 9 h9lt (by decide) (by decide), + h_v5_unc 9 h9lt (by decide) (by decide), + h_v4_unc 9 h9lt (by decide) (by decide), + h_v3_unc 9 h9lt (by decide) (by decide)] + rw [h_eq]; exact h_v2_j + -- Lane 10: step 3 j-lane. + · have h_eq : v8.elements.val[10]! = v3.elements.val[10]! := by + rw [h_v8_unc 10 h10lt (by decide) (by decide), + h_v7_unc 10 h10lt (by decide) (by decide), + h_v6_unc 10 h10lt (by decide) (by decide), + h_v5_unc 10 h10lt (by decide) (by decide), + h_v4_unc 10 h10lt (by decide) (by decide)] + rw [h_eq]; exact h_v3_j + -- Lane 11: step 4 j-lane. + · have h_eq : v8.elements.val[11]! = v4.elements.val[11]! := by + rw [h_v8_unc 11 h11lt (by decide) (by decide), + h_v7_unc 11 h11lt (by decide) (by decide), + h_v6_unc 11 h11lt (by decide) (by decide), + h_v5_unc 11 h11lt (by decide) (by decide)] + rw [h_eq]; exact h_v4_j + -- Lane 12: step 5 j-lane. + · have h_eq : v8.elements.val[12]! = v5.elements.val[12]! := by + rw [h_v8_unc 12 h12lt (by decide) (by decide), + h_v7_unc 12 h12lt (by decide) (by decide), + h_v6_unc 12 h12lt (by decide) (by decide)] + rw [h_eq]; exact h_v5_j + -- Lane 13: step 6 j-lane. + · have h_eq : v8.elements.val[13]! = v6.elements.val[13]! := by + rw [h_v8_unc 13 h13lt (by decide) (by decide), + h_v7_unc 13 h13lt (by decide) (by decide)] + rw [h_eq]; exact h_v6_j + -- Lane 14: step 7 j-lane. + · have h_eq : v8.elements.val[14]! = v7.elements.val[14]! := by + rw [h_v8_unc 14 h14lt (by decide) (by decide)] + rw [h_eq]; exact h_v7_j + -- Lane 15: step 8 j-lane. + · exact h_v8_j + +/-! ## L2.5 — `inv_ntt_step_spec` (Gentleman–Sande inverse butterfly) + + The impl is a straight chain of 9 binds: + 1. read `vec[j]` (= `i1`) + 2. read `vec[i]` (= `i2`) + 3. `a_minus_b = vec[j] - vec[i]` (wrapping_sub i1 i2) + 4. `a_plus_b = vec[j] + vec[i]` (wrapping_add i1 i2) + 5. `o0 = barrett_reduce_element a_plus_b` (L0.2) + 6. classify ζ (identity) + 7. `o1 = montgomery_multiply_fe_by_fer a_minus_b ζ` (L0.4) + 8. write `vec[i] := o0` (barrett-reduced sum) + 9. write `vec[j] := o1` (mont-mul diff by ζ) + + With pre `|vec[i]|, |vec[j]| ≤ B·3328` and `B ≤ 4`: + `|a_plus_b|, |a_minus_b| ≤ 2·B·3328 ≤ 8·3328 = 26624 ≤ 28296` (L0.2 pre). + L0.2 post: `|o0| ≤ 3328`. + L0.4 post: `|o1| ≤ 3328`. + So both touched lanes end ≤ 3328 (tight). Coarsenings (≤ 2·3328, ≤ 4·3328) + are derived at the layer level. + + The shape of the parameterized post mirrors `ntt_step_spec_B`: explicit + unchanged-lane equality + two tight bounds on the touched lanes. +-/ + +/-- Tight parameterized inverse butterfly Triple. `B` is the input lane bound; + output bound on touched lanes is the tight `3328`. -/ +@[spec] +theorem inv_ntt_step_spec_B + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta : Std.I16) (i j : Std.Usize) (B : Nat) + (h_i : i.val < 16) (h_j : j.val < 16) (h_ne : i.val ≠ j.val) + (h_zeta : zeta.val.natAbs ≤ 1664) + (h_a : (vec.elements.val[i.val]!).val.natAbs ≤ B * 3328) + (h_b : (vec.elements.val[j.val]!).val.natAbs ≤ B * 3328) + (h_B : B ≤ 4) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_step vec zeta i j + ⦃ ⇓ r => ⌜ (∀ k : Nat, k < 16 → k ≠ i.val → k ≠ j.val → + r.elements.val[k]! = vec.elements.val[k]!) + ∧ (r.elements.val[i.val]!).val.natAbs ≤ 3328 + ∧ (r.elements.val[j.val]!).val.natAbs ≤ 3328 ⌝ ⦄ := by + have h_vec_len : vec.elements.length = 16 := PortableVector_elements_length vec + -- Step 1: read vec[j] (= i1 in impl). + have h_idx_j : + Aeneas.Std.Array.index_usize vec.elements j = .ok (vec.elements.val[j.val]!) := + array_index_usize_ok_eq vec.elements j (by rw [h_vec_len]; exact h_j) + -- Step 2: read vec[i] (= i2 in impl). + have h_idx_i : + Aeneas.Std.Array.index_usize vec.elements i = .ok (vec.elements.val[i.val]!) := + array_index_usize_ok_eq vec.elements i (by rw [h_vec_len]; exact h_i) + set a : Std.I16 := vec.elements.val[i.val]! with ha_def + set b : Std.I16 := vec.elements.val[j.val]! with hb_def + -- Step 3: wrapping_sub b a (= a_minus_b). + have h_sub_eq : CoreModels.core.num.I16.wrapping_sub b a = .ok (Std.I16.wrapping_sub b a) := + cm_wrapping_sub_ok_eq b a + -- Step 4: wrapping_add b a (= a_plus_b). + have h_add_eq : CoreModels.core.num.I16.wrapping_add b a = .ok (Std.I16.wrapping_add b a) := + cm_wrapping_add_ok_eq b a + set a_minus_b : Std.I16 := Std.I16.wrapping_sub b a with hamb_def + set a_plus_b : Std.I16 := Std.I16.wrapping_add b a with hapb_def + -- Direct no-overflow proof for b ± a with |b|, |a| ≤ B·3328, B ≤ 4. + -- (`add_no_overflow_value_B` requires |t| ≤ 3328 on the second arg; we have + -- the stronger |a| ≤ B·3328, so we re-derive in this generality.) + have h_add_eq_val : (Std.I16.wrapping_add b a).val = b.val + a.val + ∧ (Std.I16.wrapping_add b a).val.natAbs ≤ 2 * B * 3328 := by + have h_sum_abs : ((b.val + a.val : Int)).natAbs ≤ 2 * B * 3328 := by + have h_tri : (b.val + a.val).natAbs ≤ b.val.natAbs + a.val.natAbs := + Int.natAbs_add_le _ _ + omega + have h_lb : -(2 ^ 15 : Int) ≤ b.val + a.val := by + have h_bound : 2 * B * 3328 ≤ 8 * 3328 := by + have : 2 * B ≤ 8 := by omega + exact Nat.mul_le_mul_right _ this + omega + have h_ub : b.val + a.val < (2 ^ 15 : Int) := by + have h_bound : 2 * B * 3328 ≤ 8 * 3328 := by + have : 2 * B ≤ 8 := by omega + exact Nat.mul_le_mul_right _ this + omega + have h_bmod : Int.bmod (b.val + a.val) (2 ^ 16) = b.val + a.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + have h_val := Std.I16.wrapping_add_val_eq b a + exact ⟨by rw [h_val, h_bmod], by rw [h_val, h_bmod]; exact h_sum_abs⟩ + have h_sub_eq_val : (Std.I16.wrapping_sub b a).val = b.val - a.val + ∧ (Std.I16.wrapping_sub b a).val.natAbs ≤ 2 * B * 3328 := by + have h_diff_abs : ((b.val - a.val : Int)).natAbs ≤ 2 * B * 3328 := by + have h_neg_natAbs : (-a.val).natAbs = a.val.natAbs := Int.natAbs_neg _ + have h_eq : b.val - a.val = b.val + (-a.val) := by ring + rw [h_eq] + have h_tri : (b.val + (-a.val)).natAbs ≤ b.val.natAbs + (-a.val).natAbs := + Int.natAbs_add_le _ _ + rw [h_neg_natAbs] at h_tri + omega + have h_lb : -(2 ^ 15 : Int) ≤ b.val - a.val := by + have h_bound : 2 * B * 3328 ≤ 8 * 3328 := by + have : 2 * B ≤ 8 := by omega + exact Nat.mul_le_mul_right _ this + omega + have h_ub : b.val - a.val < (2 ^ 15 : Int) := by + have h_bound : 2 * B * 3328 ≤ 8 * 3328 := by + have : 2 * B ≤ 8 := by omega + exact Nat.mul_le_mul_right _ this + omega + have h_bmod : Int.bmod (b.val - a.val) (2 ^ 16) = b.val - a.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + have h_val := Std.I16.wrapping_sub_val_eq b a + exact ⟨by rw [h_val, h_bmod], by rw [h_val, h_bmod]; exact h_diff_abs⟩ + -- Bound on a_plus_b for L0.2: 2·B·3328 ≤ 8·3328 = 26624 ≤ 28296. + have h_apb_bd : a_plus_b.val.natAbs ≤ 28296 := by + rw [hapb_def] + have h_step : 2 * B * 3328 ≤ 8 * 3328 := by + have : 2 * B ≤ 8 := by omega + exact Nat.mul_le_mul_right _ this + have h := h_add_eq_val.2 + omega + -- Step 5: L0.2 barrett_reduce_element on a_plus_b. + obtain ⟨o0, h_o0_eq_ok, _h_o0_mod, h_o0_bd⟩ := + triple_exists_ok_l2 (barrett_reduce_element_spec a_plus_b + (Nat.le_trans h_apb_bd (by decide : 28296 ≤ 32767))) + -- Step 6: classify ζ. + have h_classify : libcrux_secrets.traits.Classify.Blanket.classify zeta = .ok zeta := + classify_ok_eq zeta + -- Step 7: L0.4 montgomery_multiply on (a_minus_b, ζ). + obtain ⟨o1, h_o1_eq_ok, h_o1_bd, _h_o1_mod⟩ := + triple_exists_ok_l2 (montgomery_multiply_fe_by_fer_spec a_minus_b zeta h_zeta) + -- Step 8: update vec at index i with o0. + have h_upd_i : + Aeneas.Std.Array.update vec.elements i o0 + = .ok (vec.elements.set i o0) := + array_update_ok_eq vec.elements i o0 (by rw [h_vec_len]; exact h_i) + -- Step 9: update at index j with o1. + have h_upd_j : + Aeneas.Std.Array.update (vec.elements.set i o0) j o1 + = .ok ((vec.elements.set i o0).set j o1) := by + have h_len : (vec.elements.set i o0).length = 16 := by + rw [Std.Array.set_length]; exact h_vec_len + exact array_update_ok_eq _ j o1 (by rw [h_len]; exact h_j) + -- Compose into one `.ok` equation. + set final_elements : Std.Array Std.I16 16#usize := + (vec.elements.set i o0).set j o1 with hfe_def + set final_vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + { elements := final_elements } with hfv_def + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_step vec zeta i j = .ok final_vec := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_step + rw [h_idx_j]; simp only [bind_tc_ok] + rw [h_idx_i]; simp only [bind_tc_ok] + rw [h_sub_eq]; simp only [bind_tc_ok] + rw [h_add_eq]; simp only [bind_tc_ok] + rw [h_o0_eq_ok]; simp only [bind_tc_ok] + rw [h_classify]; simp only [bind_tc_ok] + rw [h_o1_eq_ok]; simp only [bind_tc_ok] + rw [h_upd_i]; simp only [bind_tc_ok] + rw [h_upd_j]; simp only [bind_tc_ok]; rfl + -- Close the Triple. + apply triple_of_ok_l2 h_body + refine ⟨?_, ?_, ?_⟩ + · -- Unchanged lanes: k ≠ i.val and k ≠ j.val. + intro k hk_lt hk_ne_i hk_ne_j + have h_set_j_ne : + ((vec.elements.set i o0).set j o1)[k]! + = (vec.elements.set i o0)[k]! := + Aeneas.Std.Array.getElem!_Nat_set_ne _ j k _ (Ne.symm hk_ne_j) + have h_set_i_ne : + (vec.elements.set i o0)[k]! = (vec.elements)[k]! := + Aeneas.Std.Array.getElem!_Nat_set_ne _ i k _ (Ne.symm hk_ne_i) + show final_vec.elements.val[k]! = vec.elements.val[k]! + show ((vec.elements.set i o0).set j o1).val[k]! = vec.elements.val[k]! + rw [← Aeneas.Std.Array.getElem!_Nat_eq, ← Aeneas.Std.Array.getElem!_Nat_eq, + h_set_j_ne, h_set_i_ne] + · -- Bound on r.elements.val[i.val]!. Lane i gets o0 (barrett-reduced sum). + show (final_vec.elements.val[i.val]!).val.natAbs ≤ 3328 + show (((vec.elements.set i o0).set j o1).val[i.val]!).val.natAbs ≤ 3328 + have h_eq1 : + ((vec.elements.set i o0).set j o1).val[i.val]! + = ((vec.elements.set i o0).set j o1)[i.val]! := by + simp [Std.Array.getElem!_Nat_eq] + have h_ne_ji : j.val ≠ i.val := (Ne.symm h_ne) + have h_set_j_ne : + ((vec.elements.set i o0).set j o1)[i.val]! + = (vec.elements.set i o0)[i.val]! := + Aeneas.Std.Array.getElem!_Nat_set_ne _ j i.val _ h_ne_ji + have h_set_i_eq : + (vec.elements.set i o0)[i.val]! = o0 := + Aeneas.Std.Array.getElem!_Nat_set_eq _ i i.val _ ⟨rfl, by rw [h_vec_len]; exact h_i⟩ + rw [h_eq1, h_set_j_ne, h_set_i_eq] + exact h_o0_bd + · -- Bound on r.elements.val[j.val]!. Lane j gets o1 (mont-mul diff). + show (final_vec.elements.val[j.val]!).val.natAbs ≤ 3328 + show (((vec.elements.set i o0).set j o1).val[j.val]!).val.natAbs ≤ 3328 + have h_eq1 : + ((vec.elements.set i o0).set j o1).val[j.val]! + = ((vec.elements.set i o0).set j o1)[j.val]! := by + simp [Std.Array.getElem!_Nat_eq] + have h_set_j_eq : + ((vec.elements.set i o0).set j o1)[j.val]! = o1 := by + have h_len : (vec.elements.set i o0).length = 16 := by + rw [Std.Array.set_length]; exact h_vec_len + exact Aeneas.Std.Array.getElem!_Nat_set_eq _ j j.val _ ⟨rfl, by rw [h_len]; exact h_j⟩ + rw [h_eq1, h_set_j_eq] + exact h_o1_bd + +/-- Named (B = 4) form matching the brief: `≤ 4·3328 → ≤ 4·3328` on all lanes. -/ +@[spec] +theorem inv_ntt_step_spec + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta : Std.I16) (i j : Std.Usize) + (h_i : i.val < 16) (h_j : j.val < 16) (h_ne : i.val ≠ j.val) + (h_zeta : zeta.val.natAbs ≤ 1664) + (hpre : ∀ k : Nat, k < 16 → (vec.elements.val[k]!).val.natAbs ≤ 4 * 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_step vec zeta i j + ⦃ ⇓ r => ⌜ ∀ k : Nat, k < 16 → (r.elements.val[k]!).val.natAbs ≤ 4 * 3328 ⌝ ⦄ := by + obtain ⟨r, h_eq, h_unc, h_i_bd, h_j_bd⟩ := + triple_exists_ok_l2 + (inv_ntt_step_spec_B vec zeta i j 4 h_i h_j h_ne h_zeta (hpre i.val h_i) + (hpre j.val h_j) (by decide)) + apply triple_of_ok_l2 h_eq + intro k hk + by_cases h_ki : k = i.val + · subst h_ki + have h_3328_le : (3328 : Nat) ≤ 4 * 3328 := by decide + exact le_trans h_i_bd h_3328_le + · by_cases h_kj : k = j.val + · subst h_kj + have h_3328_le : (3328 : Nat) ≤ 4 * 3328 := by decide + exact le_trans h_j_bd h_3328_le + · rw [h_unc k hk h_ki h_kj] + exact hpre k hk + +/-! ## L2.6 — `inv_ntt_layer_1_step_spec` + + Eight disjoint inverse butterflies on pairs `(0,2)ζ0`, `(1,3)ζ0`, `(4,6)ζ1`, + `(5,7)ζ1`, `(8,10)ζ2`, `(9,11)ζ2`, `(12,14)ζ3`, `(13,15)ζ3`. Input lanes + `≤ 3328` (B=1); each touched lane lands at `≤ 3328`; pairs cover all 16 + lanes, so post is `≤ 3328 ≤ 2·3328`. +-/ + +@[spec] +theorem inv_ntt_layer_1_step_spec + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta0 zeta1 zeta2 zeta3 : Std.I16) + (hz0 : zeta0.val.natAbs ≤ 1664) (hz1 : zeta1.val.natAbs ≤ 1664) + (hz2 : zeta2.val.natAbs ≤ 1664) (hz3 : zeta3.val.natAbs ≤ 1664) + (hpre : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_1_step + vec zeta0 zeta1 zeta2 zeta3 + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → (r.elements.val[i]!).val.natAbs ≤ 2 * 3328 ⌝ ⦄ := by + have h0lt : (0#usize : Std.Usize).val < 16 := by decide + have h1lt : (1#usize : Std.Usize).val < 16 := by decide + have h2lt : (2#usize : Std.Usize).val < 16 := by decide + have h3lt : (3#usize : Std.Usize).val < 16 := by decide + have h4lt : (4#usize : Std.Usize).val < 16 := by decide + have h5lt : (5#usize : Std.Usize).val < 16 := by decide + have h6lt : (6#usize : Std.Usize).val < 16 := by decide + have h7lt : (7#usize : Std.Usize).val < 16 := by decide + have h8lt : (8#usize : Std.Usize).val < 16 := by decide + have h9lt : (9#usize : Std.Usize).val < 16 := by decide + have h10lt : (10#usize : Std.Usize).val < 16 := by decide + have h11lt : (11#usize : Std.Usize).val < 16 := by decide + have h12lt : (12#usize : Std.Usize).val < 16 := by decide + have h13lt : (13#usize : Std.Usize).val < 16 := by decide + have h14lt : (14#usize : Std.Usize).val < 16 := by decide + have h15lt : (15#usize : Std.Usize).val < 16 := by decide + have hb_init : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 1 * 3328 := by + intro i hi; have := hpre i hi; omega + -- Step 1: (0, 2) ζ0. + obtain ⟨v1, h_v1_eq, h_v1_unc, h_v1_i, h_v1_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B vec zeta0 0#usize 2#usize 1 + h0lt h2lt (by decide) hz0 (hb_init 0 h0lt) (hb_init 2 h2lt) (by decide)) + -- Step 2: (1, 3) ζ0. + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ 1 * 3328 := by + rw [h_v1_unc 1 h1lt (by decide) (by decide)]; exact hb_init 1 h1lt + have h_v1_3 : (v1.elements.val[3]!).val.natAbs ≤ 1 * 3328 := by + rw [h_v1_unc 3 h3lt (by decide) (by decide)]; exact hb_init 3 h3lt + obtain ⟨v2, h_v2_eq, h_v2_unc, h_v2_i, h_v2_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v1 zeta0 1#usize 3#usize 1 + h1lt h3lt (by decide) hz0 h_v1_1 h_v1_3 (by decide)) + -- Step 3: (4, 6) ζ1. + have h_v2_4 : (v2.elements.val[4]!).val.natAbs ≤ 1 * 3328 := by + rw [h_v2_unc 4 h4lt (by decide) (by decide), h_v1_unc 4 h4lt (by decide) (by decide)] + exact hb_init 4 h4lt + have h_v2_6 : (v2.elements.val[6]!).val.natAbs ≤ 1 * 3328 := by + rw [h_v2_unc 6 h6lt (by decide) (by decide), h_v1_unc 6 h6lt (by decide) (by decide)] + exact hb_init 6 h6lt + obtain ⟨v3, h_v3_eq, h_v3_unc, h_v3_i, h_v3_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v2 zeta1 4#usize 6#usize 1 + h4lt h6lt (by decide) hz1 h_v2_4 h_v2_6 (by decide)) + -- Step 4: (5, 7) ζ1. + have h_v3_5 : (v3.elements.val[5]!).val.natAbs ≤ 1 * 3328 := by + rw [h_v3_unc 5 h5lt (by decide) (by decide), + h_v2_unc 5 h5lt (by decide) (by decide), + h_v1_unc 5 h5lt (by decide) (by decide)] + exact hb_init 5 h5lt + have h_v3_7 : (v3.elements.val[7]!).val.natAbs ≤ 1 * 3328 := by + rw [h_v3_unc 7 h7lt (by decide) (by decide), + h_v2_unc 7 h7lt (by decide) (by decide), + h_v1_unc 7 h7lt (by decide) (by decide)] + exact hb_init 7 h7lt + obtain ⟨v4, h_v4_eq, h_v4_unc, h_v4_i, h_v4_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v3 zeta1 5#usize 7#usize 1 + h5lt h7lt (by decide) hz1 h_v3_5 h_v3_7 (by decide)) + -- Step 5: (8, 10) ζ2. + have h_v4_8 : (v4.elements.val[8]!).val.natAbs ≤ 1 * 3328 := by + rw [h_v4_unc 8 h8lt (by decide) (by decide), + h_v3_unc 8 h8lt (by decide) (by decide), + h_v2_unc 8 h8lt (by decide) (by decide), + h_v1_unc 8 h8lt (by decide) (by decide)] + exact hb_init 8 h8lt + have h_v4_10 : (v4.elements.val[10]!).val.natAbs ≤ 1 * 3328 := by + rw [h_v4_unc 10 h10lt (by decide) (by decide), + h_v3_unc 10 h10lt (by decide) (by decide), + h_v2_unc 10 h10lt (by decide) (by decide), + h_v1_unc 10 h10lt (by decide) (by decide)] + exact hb_init 10 h10lt + obtain ⟨v5, h_v5_eq, h_v5_unc, h_v5_i, h_v5_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v4 zeta2 8#usize 10#usize 1 + h8lt h10lt (by decide) hz2 h_v4_8 h_v4_10 (by decide)) + -- Step 6: (9, 11) ζ2. + have h_v5_9 : (v5.elements.val[9]!).val.natAbs ≤ 1 * 3328 := by + rw [h_v5_unc 9 h9lt (by decide) (by decide), + h_v4_unc 9 h9lt (by decide) (by decide), + h_v3_unc 9 h9lt (by decide) (by decide), + h_v2_unc 9 h9lt (by decide) (by decide), + h_v1_unc 9 h9lt (by decide) (by decide)] + exact hb_init 9 h9lt + have h_v5_11 : (v5.elements.val[11]!).val.natAbs ≤ 1 * 3328 := by + rw [h_v5_unc 11 h11lt (by decide) (by decide), + h_v4_unc 11 h11lt (by decide) (by decide), + h_v3_unc 11 h11lt (by decide) (by decide), + h_v2_unc 11 h11lt (by decide) (by decide), + h_v1_unc 11 h11lt (by decide) (by decide)] + exact hb_init 11 h11lt + obtain ⟨v6, h_v6_eq, h_v6_unc, h_v6_i, h_v6_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v5 zeta2 9#usize 11#usize 1 + h9lt h11lt (by decide) hz2 h_v5_9 h_v5_11 (by decide)) + -- Step 7: (12, 14) ζ3. + have h_v6_12 : (v6.elements.val[12]!).val.natAbs ≤ 1 * 3328 := by + rw [h_v6_unc 12 h12lt (by decide) (by decide), + h_v5_unc 12 h12lt (by decide) (by decide), + h_v4_unc 12 h12lt (by decide) (by decide), + h_v3_unc 12 h12lt (by decide) (by decide), + h_v2_unc 12 h12lt (by decide) (by decide), + h_v1_unc 12 h12lt (by decide) (by decide)] + exact hb_init 12 h12lt + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ 1 * 3328 := by + rw [h_v6_unc 14 h14lt (by decide) (by decide), + h_v5_unc 14 h14lt (by decide) (by decide), + h_v4_unc 14 h14lt (by decide) (by decide), + h_v3_unc 14 h14lt (by decide) (by decide), + h_v2_unc 14 h14lt (by decide) (by decide), + h_v1_unc 14 h14lt (by decide) (by decide)] + exact hb_init 14 h14lt + obtain ⟨v7, h_v7_eq, h_v7_unc, h_v7_i, h_v7_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v6 zeta3 12#usize 14#usize 1 + h12lt h14lt (by decide) hz3 h_v6_12 h_v6_14 (by decide)) + -- Step 8: (13, 15) ζ3. + have h_v7_13 : (v7.elements.val[13]!).val.natAbs ≤ 1 * 3328 := by + rw [h_v7_unc 13 h13lt (by decide) (by decide), + h_v6_unc 13 h13lt (by decide) (by decide), + h_v5_unc 13 h13lt (by decide) (by decide), + h_v4_unc 13 h13lt (by decide) (by decide), + h_v3_unc 13 h13lt (by decide) (by decide), + h_v2_unc 13 h13lt (by decide) (by decide), + h_v1_unc 13 h13lt (by decide) (by decide)] + exact hb_init 13 h13lt + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ 1 * 3328 := by + rw [h_v7_unc 15 h15lt (by decide) (by decide), + h_v6_unc 15 h15lt (by decide) (by decide), + h_v5_unc 15 h15lt (by decide) (by decide), + h_v4_unc 15 h15lt (by decide) (by decide), + h_v3_unc 15 h15lt (by decide) (by decide), + h_v2_unc 15 h15lt (by decide) (by decide), + h_v1_unc 15 h15lt (by decide) (by decide)] + exact hb_init 15 h15lt + obtain ⟨v8, h_v8_eq, h_v8_unc, h_v8_i, h_v8_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v7 zeta3 13#usize 15#usize 1 + h13lt h15lt (by decide) hz3 h_v7_13 h_v7_15 (by decide)) + -- Compose into one `.ok v8` equation. + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_1_step vec zeta0 zeta1 zeta2 zeta3 + = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_1_step + rw [h_v1_eq]; simp only [bind_tc_ok] + rw [h_v2_eq]; simp only [bind_tc_ok] + rw [h_v3_eq]; simp only [bind_tc_ok] + rw [h_v4_eq]; simp only [bind_tc_ok] + rw [h_v5_eq]; simp only [bind_tc_ok] + rw [h_v6_eq]; simp only [bind_tc_ok] + rw [h_v7_eq]; simp only [bind_tc_ok] + exact h_v8_eq + apply triple_of_ok_l2 h_body + intro i hi + -- Each lane is touched exactly once and ends at ≤ 3328 ≤ 2·3328. + have h_3328 : (3328 : Nat) ≤ 2 * 3328 := by decide + interval_cases i + -- Lane 0: step 1 i-lane. + · have h_eq : v8.elements.val[0]! = v1.elements.val[0]! := by + rw [h_v8_unc 0 h0lt (by decide) (by decide), + h_v7_unc 0 h0lt (by decide) (by decide), + h_v6_unc 0 h0lt (by decide) (by decide), + h_v5_unc 0 h0lt (by decide) (by decide), + h_v4_unc 0 h0lt (by decide) (by decide), + h_v3_unc 0 h0lt (by decide) (by decide), + h_v2_unc 0 h0lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v1_i h_3328 + -- Lane 1: step 2 i-lane. + · have h_eq : v8.elements.val[1]! = v2.elements.val[1]! := by + rw [h_v8_unc 1 h1lt (by decide) (by decide), + h_v7_unc 1 h1lt (by decide) (by decide), + h_v6_unc 1 h1lt (by decide) (by decide), + h_v5_unc 1 h1lt (by decide) (by decide), + h_v4_unc 1 h1lt (by decide) (by decide), + h_v3_unc 1 h1lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v2_i h_3328 + -- Lane 2: step 1 j-lane. + · have h_eq : v8.elements.val[2]! = v1.elements.val[2]! := by + rw [h_v8_unc 2 h2lt (by decide) (by decide), + h_v7_unc 2 h2lt (by decide) (by decide), + h_v6_unc 2 h2lt (by decide) (by decide), + h_v5_unc 2 h2lt (by decide) (by decide), + h_v4_unc 2 h2lt (by decide) (by decide), + h_v3_unc 2 h2lt (by decide) (by decide), + h_v2_unc 2 h2lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v1_j h_3328 + -- Lane 3: step 2 j-lane. + · have h_eq : v8.elements.val[3]! = v2.elements.val[3]! := by + rw [h_v8_unc 3 h3lt (by decide) (by decide), + h_v7_unc 3 h3lt (by decide) (by decide), + h_v6_unc 3 h3lt (by decide) (by decide), + h_v5_unc 3 h3lt (by decide) (by decide), + h_v4_unc 3 h3lt (by decide) (by decide), + h_v3_unc 3 h3lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v2_j h_3328 + -- Lane 4: step 3 i-lane. + · have h_eq : v8.elements.val[4]! = v3.elements.val[4]! := by + rw [h_v8_unc 4 h4lt (by decide) (by decide), + h_v7_unc 4 h4lt (by decide) (by decide), + h_v6_unc 4 h4lt (by decide) (by decide), + h_v5_unc 4 h4lt (by decide) (by decide), + h_v4_unc 4 h4lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v3_i h_3328 + -- Lane 5: step 4 i-lane. + · have h_eq : v8.elements.val[5]! = v4.elements.val[5]! := by + rw [h_v8_unc 5 h5lt (by decide) (by decide), + h_v7_unc 5 h5lt (by decide) (by decide), + h_v6_unc 5 h5lt (by decide) (by decide), + h_v5_unc 5 h5lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v4_i h_3328 + -- Lane 6: step 3 j-lane. + · have h_eq : v8.elements.val[6]! = v3.elements.val[6]! := by + rw [h_v8_unc 6 h6lt (by decide) (by decide), + h_v7_unc 6 h6lt (by decide) (by decide), + h_v6_unc 6 h6lt (by decide) (by decide), + h_v5_unc 6 h6lt (by decide) (by decide), + h_v4_unc 6 h6lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v3_j h_3328 + -- Lane 7: step 4 j-lane. + · have h_eq : v8.elements.val[7]! = v4.elements.val[7]! := by + rw [h_v8_unc 7 h7lt (by decide) (by decide), + h_v7_unc 7 h7lt (by decide) (by decide), + h_v6_unc 7 h7lt (by decide) (by decide), + h_v5_unc 7 h7lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v4_j h_3328 + -- Lane 8: step 5 i-lane. + · have h_eq : v8.elements.val[8]! = v5.elements.val[8]! := by + rw [h_v8_unc 8 h8lt (by decide) (by decide), + h_v7_unc 8 h8lt (by decide) (by decide), + h_v6_unc 8 h8lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v5_i h_3328 + -- Lane 9: step 6 i-lane. + · have h_eq : v8.elements.val[9]! = v6.elements.val[9]! := by + rw [h_v8_unc 9 h9lt (by decide) (by decide), + h_v7_unc 9 h9lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v6_i h_3328 + -- Lane 10: step 5 j-lane. + · have h_eq : v8.elements.val[10]! = v5.elements.val[10]! := by + rw [h_v8_unc 10 h10lt (by decide) (by decide), + h_v7_unc 10 h10lt (by decide) (by decide), + h_v6_unc 10 h10lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v5_j h_3328 + -- Lane 11: step 6 j-lane. + · have h_eq : v8.elements.val[11]! = v6.elements.val[11]! := by + rw [h_v8_unc 11 h11lt (by decide) (by decide), + h_v7_unc 11 h11lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v6_j h_3328 + -- Lane 12: step 7 i-lane. + · have h_eq : v8.elements.val[12]! = v7.elements.val[12]! := by + rw [h_v8_unc 12 h12lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v7_i h_3328 + -- Lane 13: step 8 i-lane. + · exact le_trans h_v8_i h_3328 + -- Lane 14: step 7 j-lane. + · have h_eq : v8.elements.val[14]! = v7.elements.val[14]! := by + rw [h_v8_unc 14 h14lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v7_j h_3328 + -- Lane 15: step 8 j-lane. + · exact le_trans h_v8_j h_3328 + +/-! ## L2.7a — `inv_ntt_layer_2_step_spec` + + Eight disjoint inverse butterflies on pairs `(0,4)ζ0`, `(1,5)ζ0`, `(2,6)ζ0`, + `(3,7)ζ0`, `(8,12)ζ1`, `(9,13)ζ1`, `(10,14)ζ1`, `(11,15)ζ1`. Input bound + `≤ 2·3328` (B=2); each touched lane lands at `≤ 3328 ≤ 4·3328`. +-/ + +@[spec] +theorem inv_ntt_layer_2_step_spec + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta0 zeta1 : Std.I16) + (hz0 : zeta0.val.natAbs ≤ 1664) (hz1 : zeta1.val.natAbs ≤ 1664) + (hpre : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 2 * 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_2_step vec zeta0 zeta1 + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → (r.elements.val[i]!).val.natAbs ≤ 4 * 3328 ⌝ ⦄ := by + have h0lt : (0#usize : Std.Usize).val < 16 := by decide + have h1lt : (1#usize : Std.Usize).val < 16 := by decide + have h2lt : (2#usize : Std.Usize).val < 16 := by decide + have h3lt : (3#usize : Std.Usize).val < 16 := by decide + have h4lt : (4#usize : Std.Usize).val < 16 := by decide + have h5lt : (5#usize : Std.Usize).val < 16 := by decide + have h6lt : (6#usize : Std.Usize).val < 16 := by decide + have h7lt : (7#usize : Std.Usize).val < 16 := by decide + have h8lt : (8#usize : Std.Usize).val < 16 := by decide + have h9lt : (9#usize : Std.Usize).val < 16 := by decide + have h10lt : (10#usize : Std.Usize).val < 16 := by decide + have h11lt : (11#usize : Std.Usize).val < 16 := by decide + have h12lt : (12#usize : Std.Usize).val < 16 := by decide + have h13lt : (13#usize : Std.Usize).val < 16 := by decide + have h14lt : (14#usize : Std.Usize).val < 16 := by decide + have h15lt : (15#usize : Std.Usize).val < 16 := by decide + have hb_init : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 2 * 3328 := hpre + -- Step 1: (0, 4) ζ0. + obtain ⟨v1, h_v1_eq, h_v1_unc, h_v1_i, h_v1_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B vec zeta0 0#usize 4#usize 2 + h0lt h4lt (by decide) hz0 (hb_init 0 h0lt) (hb_init 4 h4lt) (by decide)) + -- Step 2: (1, 5) ζ0. + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v1_unc 1 h1lt (by decide) (by decide)]; exact hb_init 1 h1lt + have h_v1_5 : (v1.elements.val[5]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v1_unc 5 h5lt (by decide) (by decide)]; exact hb_init 5 h5lt + obtain ⟨v2, h_v2_eq, h_v2_unc, h_v2_i, h_v2_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v1 zeta0 1#usize 5#usize 2 + h1lt h5lt (by decide) hz0 h_v1_1 h_v1_5 (by decide)) + -- Step 3: (2, 6) ζ0. + have h_v2_2 : (v2.elements.val[2]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v2_unc 2 h2lt (by decide) (by decide), h_v1_unc 2 h2lt (by decide) (by decide)] + exact hb_init 2 h2lt + have h_v2_6 : (v2.elements.val[6]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v2_unc 6 h6lt (by decide) (by decide), h_v1_unc 6 h6lt (by decide) (by decide)] + exact hb_init 6 h6lt + obtain ⟨v3, h_v3_eq, h_v3_unc, h_v3_i, h_v3_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v2 zeta0 2#usize 6#usize 2 + h2lt h6lt (by decide) hz0 h_v2_2 h_v2_6 (by decide)) + -- Step 4: (3, 7) ζ0. + have h_v3_3 : (v3.elements.val[3]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v3_unc 3 h3lt (by decide) (by decide), + h_v2_unc 3 h3lt (by decide) (by decide), + h_v1_unc 3 h3lt (by decide) (by decide)] + exact hb_init 3 h3lt + have h_v3_7 : (v3.elements.val[7]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v3_unc 7 h7lt (by decide) (by decide), + h_v2_unc 7 h7lt (by decide) (by decide), + h_v1_unc 7 h7lt (by decide) (by decide)] + exact hb_init 7 h7lt + obtain ⟨v4, h_v4_eq, h_v4_unc, h_v4_i, h_v4_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v3 zeta0 3#usize 7#usize 2 + h3lt h7lt (by decide) hz0 h_v3_3 h_v3_7 (by decide)) + -- Step 5: (8, 12) ζ1. + have h_v4_8 : (v4.elements.val[8]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v4_unc 8 h8lt (by decide) (by decide), + h_v3_unc 8 h8lt (by decide) (by decide), + h_v2_unc 8 h8lt (by decide) (by decide), + h_v1_unc 8 h8lt (by decide) (by decide)] + exact hb_init 8 h8lt + have h_v4_12 : (v4.elements.val[12]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v4_unc 12 h12lt (by decide) (by decide), + h_v3_unc 12 h12lt (by decide) (by decide), + h_v2_unc 12 h12lt (by decide) (by decide), + h_v1_unc 12 h12lt (by decide) (by decide)] + exact hb_init 12 h12lt + obtain ⟨v5, h_v5_eq, h_v5_unc, h_v5_i, h_v5_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v4 zeta1 8#usize 12#usize 2 + h8lt h12lt (by decide) hz1 h_v4_8 h_v4_12 (by decide)) + -- Step 6: (9, 13) ζ1. + have h_v5_9 : (v5.elements.val[9]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v5_unc 9 h9lt (by decide) (by decide), + h_v4_unc 9 h9lt (by decide) (by decide), + h_v3_unc 9 h9lt (by decide) (by decide), + h_v2_unc 9 h9lt (by decide) (by decide), + h_v1_unc 9 h9lt (by decide) (by decide)] + exact hb_init 9 h9lt + have h_v5_13 : (v5.elements.val[13]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v5_unc 13 h13lt (by decide) (by decide), + h_v4_unc 13 h13lt (by decide) (by decide), + h_v3_unc 13 h13lt (by decide) (by decide), + h_v2_unc 13 h13lt (by decide) (by decide), + h_v1_unc 13 h13lt (by decide) (by decide)] + exact hb_init 13 h13lt + obtain ⟨v6, h_v6_eq, h_v6_unc, h_v6_i, h_v6_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v5 zeta1 9#usize 13#usize 2 + h9lt h13lt (by decide) hz1 h_v5_9 h_v5_13 (by decide)) + -- Step 7: (10, 14) ζ1. + have h_v6_10 : (v6.elements.val[10]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v6_unc 10 h10lt (by decide) (by decide), + h_v5_unc 10 h10lt (by decide) (by decide), + h_v4_unc 10 h10lt (by decide) (by decide), + h_v3_unc 10 h10lt (by decide) (by decide), + h_v2_unc 10 h10lt (by decide) (by decide), + h_v1_unc 10 h10lt (by decide) (by decide)] + exact hb_init 10 h10lt + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v6_unc 14 h14lt (by decide) (by decide), + h_v5_unc 14 h14lt (by decide) (by decide), + h_v4_unc 14 h14lt (by decide) (by decide), + h_v3_unc 14 h14lt (by decide) (by decide), + h_v2_unc 14 h14lt (by decide) (by decide), + h_v1_unc 14 h14lt (by decide) (by decide)] + exact hb_init 14 h14lt + obtain ⟨v7, h_v7_eq, h_v7_unc, h_v7_i, h_v7_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v6 zeta1 10#usize 14#usize 2 + h10lt h14lt (by decide) hz1 h_v6_10 h_v6_14 (by decide)) + -- Step 8: (11, 15) ζ1. + have h_v7_11 : (v7.elements.val[11]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v7_unc 11 h11lt (by decide) (by decide), + h_v6_unc 11 h11lt (by decide) (by decide), + h_v5_unc 11 h11lt (by decide) (by decide), + h_v4_unc 11 h11lt (by decide) (by decide), + h_v3_unc 11 h11lt (by decide) (by decide), + h_v2_unc 11 h11lt (by decide) (by decide), + h_v1_unc 11 h11lt (by decide) (by decide)] + exact hb_init 11 h11lt + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v7_unc 15 h15lt (by decide) (by decide), + h_v6_unc 15 h15lt (by decide) (by decide), + h_v5_unc 15 h15lt (by decide) (by decide), + h_v4_unc 15 h15lt (by decide) (by decide), + h_v3_unc 15 h15lt (by decide) (by decide), + h_v2_unc 15 h15lt (by decide) (by decide), + h_v1_unc 15 h15lt (by decide) (by decide)] + exact hb_init 15 h15lt + obtain ⟨v8, h_v8_eq, h_v8_unc, h_v8_i, h_v8_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v7 zeta1 11#usize 15#usize 2 + h11lt h15lt (by decide) hz1 h_v7_11 h_v7_15 (by decide)) + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_2_step vec zeta0 zeta1 + = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_2_step + rw [h_v1_eq]; simp only [bind_tc_ok] + rw [h_v2_eq]; simp only [bind_tc_ok] + rw [h_v3_eq]; simp only [bind_tc_ok] + rw [h_v4_eq]; simp only [bind_tc_ok] + rw [h_v5_eq]; simp only [bind_tc_ok] + rw [h_v6_eq]; simp only [bind_tc_ok] + rw [h_v7_eq]; simp only [bind_tc_ok] + exact h_v8_eq + apply triple_of_ok_l2 h_body + intro i hi + have h_3328 : (3328 : Nat) ≤ 4 * 3328 := by decide + interval_cases i + -- Lane 0: step 1 i-lane. + · have h_eq : v8.elements.val[0]! = v1.elements.val[0]! := by + rw [h_v8_unc 0 h0lt (by decide) (by decide), + h_v7_unc 0 h0lt (by decide) (by decide), + h_v6_unc 0 h0lt (by decide) (by decide), + h_v5_unc 0 h0lt (by decide) (by decide), + h_v4_unc 0 h0lt (by decide) (by decide), + h_v3_unc 0 h0lt (by decide) (by decide), + h_v2_unc 0 h0lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v1_i h_3328 + -- Lane 1: step 2 i-lane. + · have h_eq : v8.elements.val[1]! = v2.elements.val[1]! := by + rw [h_v8_unc 1 h1lt (by decide) (by decide), + h_v7_unc 1 h1lt (by decide) (by decide), + h_v6_unc 1 h1lt (by decide) (by decide), + h_v5_unc 1 h1lt (by decide) (by decide), + h_v4_unc 1 h1lt (by decide) (by decide), + h_v3_unc 1 h1lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v2_i h_3328 + -- Lane 2: step 3 i-lane. + · have h_eq : v8.elements.val[2]! = v3.elements.val[2]! := by + rw [h_v8_unc 2 h2lt (by decide) (by decide), + h_v7_unc 2 h2lt (by decide) (by decide), + h_v6_unc 2 h2lt (by decide) (by decide), + h_v5_unc 2 h2lt (by decide) (by decide), + h_v4_unc 2 h2lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v3_i h_3328 + -- Lane 3: step 4 i-lane. + · have h_eq : v8.elements.val[3]! = v4.elements.val[3]! := by + rw [h_v8_unc 3 h3lt (by decide) (by decide), + h_v7_unc 3 h3lt (by decide) (by decide), + h_v6_unc 3 h3lt (by decide) (by decide), + h_v5_unc 3 h3lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v4_i h_3328 + -- Lane 4: step 1 j-lane. + · have h_eq : v8.elements.val[4]! = v1.elements.val[4]! := by + rw [h_v8_unc 4 h4lt (by decide) (by decide), + h_v7_unc 4 h4lt (by decide) (by decide), + h_v6_unc 4 h4lt (by decide) (by decide), + h_v5_unc 4 h4lt (by decide) (by decide), + h_v4_unc 4 h4lt (by decide) (by decide), + h_v3_unc 4 h4lt (by decide) (by decide), + h_v2_unc 4 h4lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v1_j h_3328 + -- Lane 5: step 2 j-lane. + · have h_eq : v8.elements.val[5]! = v2.elements.val[5]! := by + rw [h_v8_unc 5 h5lt (by decide) (by decide), + h_v7_unc 5 h5lt (by decide) (by decide), + h_v6_unc 5 h5lt (by decide) (by decide), + h_v5_unc 5 h5lt (by decide) (by decide), + h_v4_unc 5 h5lt (by decide) (by decide), + h_v3_unc 5 h5lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v2_j h_3328 + -- Lane 6: step 3 j-lane. + · have h_eq : v8.elements.val[6]! = v3.elements.val[6]! := by + rw [h_v8_unc 6 h6lt (by decide) (by decide), + h_v7_unc 6 h6lt (by decide) (by decide), + h_v6_unc 6 h6lt (by decide) (by decide), + h_v5_unc 6 h6lt (by decide) (by decide), + h_v4_unc 6 h6lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v3_j h_3328 + -- Lane 7: step 4 j-lane. + · have h_eq : v8.elements.val[7]! = v4.elements.val[7]! := by + rw [h_v8_unc 7 h7lt (by decide) (by decide), + h_v7_unc 7 h7lt (by decide) (by decide), + h_v6_unc 7 h7lt (by decide) (by decide), + h_v5_unc 7 h7lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v4_j h_3328 + -- Lane 8: step 5 i-lane. + · have h_eq : v8.elements.val[8]! = v5.elements.val[8]! := by + rw [h_v8_unc 8 h8lt (by decide) (by decide), + h_v7_unc 8 h8lt (by decide) (by decide), + h_v6_unc 8 h8lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v5_i h_3328 + -- Lane 9: step 6 i-lane. + · have h_eq : v8.elements.val[9]! = v6.elements.val[9]! := by + rw [h_v8_unc 9 h9lt (by decide) (by decide), + h_v7_unc 9 h9lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v6_i h_3328 + -- Lane 10: step 7 i-lane. + · have h_eq : v8.elements.val[10]! = v7.elements.val[10]! := by + rw [h_v8_unc 10 h10lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v7_i h_3328 + -- Lane 11: step 8 i-lane. + · exact le_trans h_v8_i h_3328 + -- Lane 12: step 5 j-lane. + · have h_eq : v8.elements.val[12]! = v5.elements.val[12]! := by + rw [h_v8_unc 12 h12lt (by decide) (by decide), + h_v7_unc 12 h12lt (by decide) (by decide), + h_v6_unc 12 h12lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v5_j h_3328 + -- Lane 13: step 6 j-lane. + · have h_eq : v8.elements.val[13]! = v6.elements.val[13]! := by + rw [h_v8_unc 13 h13lt (by decide) (by decide), + h_v7_unc 13 h13lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v6_j h_3328 + -- Lane 14: step 7 j-lane. + · have h_eq : v8.elements.val[14]! = v7.elements.val[14]! := by + rw [h_v8_unc 14 h14lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v7_j h_3328 + -- Lane 15: step 8 j-lane. + · exact le_trans h_v8_j h_3328 + +/-! ## L2.7b — `inv_ntt_layer_3_step_spec` + + Eight disjoint inverse butterflies on pairs `(0,8)`, `(1,9)`, …, `(7,15)` + all with the single ζ. Input bound `≤ 2·3328` (B=2); each touched lane + lands at `≤ 3328 ≤ 4·3328`. +-/ + +@[spec] +theorem inv_ntt_layer_3_step_spec + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta : Std.I16) (hz : zeta.val.natAbs ≤ 1664) + (hpre : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 2 * 3328) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_3_step vec zeta + ⦃ ⇓ r => ⌜ ∀ i : Nat, i < 16 → (r.elements.val[i]!).val.natAbs ≤ 4 * 3328 ⌝ ⦄ := by + have h0lt : (0#usize : Std.Usize).val < 16 := by decide + have h1lt : (1#usize : Std.Usize).val < 16 := by decide + have h2lt : (2#usize : Std.Usize).val < 16 := by decide + have h3lt : (3#usize : Std.Usize).val < 16 := by decide + have h4lt : (4#usize : Std.Usize).val < 16 := by decide + have h5lt : (5#usize : Std.Usize).val < 16 := by decide + have h6lt : (6#usize : Std.Usize).val < 16 := by decide + have h7lt : (7#usize : Std.Usize).val < 16 := by decide + have h8lt : (8#usize : Std.Usize).val < 16 := by decide + have h9lt : (9#usize : Std.Usize).val < 16 := by decide + have h10lt : (10#usize : Std.Usize).val < 16 := by decide + have h11lt : (11#usize : Std.Usize).val < 16 := by decide + have h12lt : (12#usize : Std.Usize).val < 16 := by decide + have h13lt : (13#usize : Std.Usize).val < 16 := by decide + have h14lt : (14#usize : Std.Usize).val < 16 := by decide + have h15lt : (15#usize : Std.Usize).val < 16 := by decide + have hb_init : ∀ i : Nat, i < 16 → (vec.elements.val[i]!).val.natAbs ≤ 2 * 3328 := hpre + -- Step 1: (0, 8). + obtain ⟨v1, h_v1_eq, h_v1_unc, h_v1_i, h_v1_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B vec zeta 0#usize 8#usize 2 + h0lt h8lt (by decide) hz (hb_init 0 h0lt) (hb_init 8 h8lt) (by decide)) + -- Step 2: (1, 9). + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v1_unc 1 h1lt (by decide) (by decide)]; exact hb_init 1 h1lt + have h_v1_9 : (v1.elements.val[9]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v1_unc 9 h9lt (by decide) (by decide)]; exact hb_init 9 h9lt + obtain ⟨v2, h_v2_eq, h_v2_unc, h_v2_i, h_v2_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v1 zeta 1#usize 9#usize 2 + h1lt h9lt (by decide) hz h_v1_1 h_v1_9 (by decide)) + -- Step 3: (2, 10). + have h_v2_2 : (v2.elements.val[2]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v2_unc 2 h2lt (by decide) (by decide), h_v1_unc 2 h2lt (by decide) (by decide)] + exact hb_init 2 h2lt + have h_v2_10 : (v2.elements.val[10]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v2_unc 10 h10lt (by decide) (by decide), h_v1_unc 10 h10lt (by decide) (by decide)] + exact hb_init 10 h10lt + obtain ⟨v3, h_v3_eq, h_v3_unc, h_v3_i, h_v3_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v2 zeta 2#usize 10#usize 2 + h2lt h10lt (by decide) hz h_v2_2 h_v2_10 (by decide)) + -- Step 4: (3, 11). + have h_v3_3 : (v3.elements.val[3]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v3_unc 3 h3lt (by decide) (by decide), + h_v2_unc 3 h3lt (by decide) (by decide), + h_v1_unc 3 h3lt (by decide) (by decide)] + exact hb_init 3 h3lt + have h_v3_11 : (v3.elements.val[11]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v3_unc 11 h11lt (by decide) (by decide), + h_v2_unc 11 h11lt (by decide) (by decide), + h_v1_unc 11 h11lt (by decide) (by decide)] + exact hb_init 11 h11lt + obtain ⟨v4, h_v4_eq, h_v4_unc, h_v4_i, h_v4_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v3 zeta 3#usize 11#usize 2 + h3lt h11lt (by decide) hz h_v3_3 h_v3_11 (by decide)) + -- Step 5: (4, 12). + have h_v4_4 : (v4.elements.val[4]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v4_unc 4 h4lt (by decide) (by decide), + h_v3_unc 4 h4lt (by decide) (by decide), + h_v2_unc 4 h4lt (by decide) (by decide), + h_v1_unc 4 h4lt (by decide) (by decide)] + exact hb_init 4 h4lt + have h_v4_12 : (v4.elements.val[12]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v4_unc 12 h12lt (by decide) (by decide), + h_v3_unc 12 h12lt (by decide) (by decide), + h_v2_unc 12 h12lt (by decide) (by decide), + h_v1_unc 12 h12lt (by decide) (by decide)] + exact hb_init 12 h12lt + obtain ⟨v5, h_v5_eq, h_v5_unc, h_v5_i, h_v5_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v4 zeta 4#usize 12#usize 2 + h4lt h12lt (by decide) hz h_v4_4 h_v4_12 (by decide)) + -- Step 6: (5, 13). + have h_v5_5 : (v5.elements.val[5]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v5_unc 5 h5lt (by decide) (by decide), + h_v4_unc 5 h5lt (by decide) (by decide), + h_v3_unc 5 h5lt (by decide) (by decide), + h_v2_unc 5 h5lt (by decide) (by decide), + h_v1_unc 5 h5lt (by decide) (by decide)] + exact hb_init 5 h5lt + have h_v5_13 : (v5.elements.val[13]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v5_unc 13 h13lt (by decide) (by decide), + h_v4_unc 13 h13lt (by decide) (by decide), + h_v3_unc 13 h13lt (by decide) (by decide), + h_v2_unc 13 h13lt (by decide) (by decide), + h_v1_unc 13 h13lt (by decide) (by decide)] + exact hb_init 13 h13lt + obtain ⟨v6, h_v6_eq, h_v6_unc, h_v6_i, h_v6_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v5 zeta 5#usize 13#usize 2 + h5lt h13lt (by decide) hz h_v5_5 h_v5_13 (by decide)) + -- Step 7: (6, 14). + have h_v6_6 : (v6.elements.val[6]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v6_unc 6 h6lt (by decide) (by decide), + h_v5_unc 6 h6lt (by decide) (by decide), + h_v4_unc 6 h6lt (by decide) (by decide), + h_v3_unc 6 h6lt (by decide) (by decide), + h_v2_unc 6 h6lt (by decide) (by decide), + h_v1_unc 6 h6lt (by decide) (by decide)] + exact hb_init 6 h6lt + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v6_unc 14 h14lt (by decide) (by decide), + h_v5_unc 14 h14lt (by decide) (by decide), + h_v4_unc 14 h14lt (by decide) (by decide), + h_v3_unc 14 h14lt (by decide) (by decide), + h_v2_unc 14 h14lt (by decide) (by decide), + h_v1_unc 14 h14lt (by decide) (by decide)] + exact hb_init 14 h14lt + obtain ⟨v7, h_v7_eq, h_v7_unc, h_v7_i, h_v7_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v6 zeta 6#usize 14#usize 2 + h6lt h14lt (by decide) hz h_v6_6 h_v6_14 (by decide)) + -- Step 8: (7, 15). + have h_v7_7 : (v7.elements.val[7]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v7_unc 7 h7lt (by decide) (by decide), + h_v6_unc 7 h7lt (by decide) (by decide), + h_v5_unc 7 h7lt (by decide) (by decide), + h_v4_unc 7 h7lt (by decide) (by decide), + h_v3_unc 7 h7lt (by decide) (by decide), + h_v2_unc 7 h7lt (by decide) (by decide), + h_v1_unc 7 h7lt (by decide) (by decide)] + exact hb_init 7 h7lt + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ 2 * 3328 := by + rw [h_v7_unc 15 h15lt (by decide) (by decide), + h_v6_unc 15 h15lt (by decide) (by decide), + h_v5_unc 15 h15lt (by decide) (by decide), + h_v4_unc 15 h15lt (by decide) (by decide), + h_v3_unc 15 h15lt (by decide) (by decide), + h_v2_unc 15 h15lt (by decide) (by decide), + h_v1_unc 15 h15lt (by decide) (by decide)] + exact hb_init 15 h15lt + obtain ⟨v8, h_v8_eq, h_v8_unc, h_v8_i, h_v8_j⟩ := + triple_exists_ok_l2 (inv_ntt_step_spec_B v7 zeta 7#usize 15#usize 2 + h7lt h15lt (by decide) hz h_v7_7 h_v7_15 (by decide)) + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_3_step vec zeta = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_3_step + rw [h_v1_eq]; simp only [bind_tc_ok] + rw [h_v2_eq]; simp only [bind_tc_ok] + rw [h_v3_eq]; simp only [bind_tc_ok] + rw [h_v4_eq]; simp only [bind_tc_ok] + rw [h_v5_eq]; simp only [bind_tc_ok] + rw [h_v6_eq]; simp only [bind_tc_ok] + rw [h_v7_eq]; simp only [bind_tc_ok] + exact h_v8_eq + apply triple_of_ok_l2 h_body + intro i hi + have h_3328 : (3328 : Nat) ≤ 4 * 3328 := by decide + interval_cases i + -- Lane 0: step 1 i-lane. + · have h_eq : v8.elements.val[0]! = v1.elements.val[0]! := by + rw [h_v8_unc 0 h0lt (by decide) (by decide), + h_v7_unc 0 h0lt (by decide) (by decide), + h_v6_unc 0 h0lt (by decide) (by decide), + h_v5_unc 0 h0lt (by decide) (by decide), + h_v4_unc 0 h0lt (by decide) (by decide), + h_v3_unc 0 h0lt (by decide) (by decide), + h_v2_unc 0 h0lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v1_i h_3328 + -- Lane 1: step 2 i-lane. + · have h_eq : v8.elements.val[1]! = v2.elements.val[1]! := by + rw [h_v8_unc 1 h1lt (by decide) (by decide), + h_v7_unc 1 h1lt (by decide) (by decide), + h_v6_unc 1 h1lt (by decide) (by decide), + h_v5_unc 1 h1lt (by decide) (by decide), + h_v4_unc 1 h1lt (by decide) (by decide), + h_v3_unc 1 h1lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v2_i h_3328 + -- Lane 2: step 3 i-lane. + · have h_eq : v8.elements.val[2]! = v3.elements.val[2]! := by + rw [h_v8_unc 2 h2lt (by decide) (by decide), + h_v7_unc 2 h2lt (by decide) (by decide), + h_v6_unc 2 h2lt (by decide) (by decide), + h_v5_unc 2 h2lt (by decide) (by decide), + h_v4_unc 2 h2lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v3_i h_3328 + -- Lane 3: step 4 i-lane. + · have h_eq : v8.elements.val[3]! = v4.elements.val[3]! := by + rw [h_v8_unc 3 h3lt (by decide) (by decide), + h_v7_unc 3 h3lt (by decide) (by decide), + h_v6_unc 3 h3lt (by decide) (by decide), + h_v5_unc 3 h3lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v4_i h_3328 + -- Lane 4: step 5 i-lane. + · have h_eq : v8.elements.val[4]! = v5.elements.val[4]! := by + rw [h_v8_unc 4 h4lt (by decide) (by decide), + h_v7_unc 4 h4lt (by decide) (by decide), + h_v6_unc 4 h4lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v5_i h_3328 + -- Lane 5: step 6 i-lane. + · have h_eq : v8.elements.val[5]! = v6.elements.val[5]! := by + rw [h_v8_unc 5 h5lt (by decide) (by decide), + h_v7_unc 5 h5lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v6_i h_3328 + -- Lane 6: step 7 i-lane. + · have h_eq : v8.elements.val[6]! = v7.elements.val[6]! := by + rw [h_v8_unc 6 h6lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v7_i h_3328 + -- Lane 7: step 8 i-lane. + · exact le_trans h_v8_i h_3328 + -- Lane 8: step 1 j-lane. + · have h_eq : v8.elements.val[8]! = v1.elements.val[8]! := by + rw [h_v8_unc 8 h8lt (by decide) (by decide), + h_v7_unc 8 h8lt (by decide) (by decide), + h_v6_unc 8 h8lt (by decide) (by decide), + h_v5_unc 8 h8lt (by decide) (by decide), + h_v4_unc 8 h8lt (by decide) (by decide), + h_v3_unc 8 h8lt (by decide) (by decide), + h_v2_unc 8 h8lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v1_j h_3328 + -- Lane 9: step 2 j-lane. + · have h_eq : v8.elements.val[9]! = v2.elements.val[9]! := by + rw [h_v8_unc 9 h9lt (by decide) (by decide), + h_v7_unc 9 h9lt (by decide) (by decide), + h_v6_unc 9 h9lt (by decide) (by decide), + h_v5_unc 9 h9lt (by decide) (by decide), + h_v4_unc 9 h9lt (by decide) (by decide), + h_v3_unc 9 h9lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v2_j h_3328 + -- Lane 10: step 3 j-lane. + · have h_eq : v8.elements.val[10]! = v3.elements.val[10]! := by + rw [h_v8_unc 10 h10lt (by decide) (by decide), + h_v7_unc 10 h10lt (by decide) (by decide), + h_v6_unc 10 h10lt (by decide) (by decide), + h_v5_unc 10 h10lt (by decide) (by decide), + h_v4_unc 10 h10lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v3_j h_3328 + -- Lane 11: step 4 j-lane. + · have h_eq : v8.elements.val[11]! = v4.elements.val[11]! := by + rw [h_v8_unc 11 h11lt (by decide) (by decide), + h_v7_unc 11 h11lt (by decide) (by decide), + h_v6_unc 11 h11lt (by decide) (by decide), + h_v5_unc 11 h11lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v4_j h_3328 + -- Lane 12: step 5 j-lane. + · have h_eq : v8.elements.val[12]! = v5.elements.val[12]! := by + rw [h_v8_unc 12 h12lt (by decide) (by decide), + h_v7_unc 12 h12lt (by decide) (by decide), + h_v6_unc 12 h12lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v5_j h_3328 + -- Lane 13: step 6 j-lane. + · have h_eq : v8.elements.val[13]! = v6.elements.val[13]! := by + rw [h_v8_unc 13 h13lt (by decide) (by decide), + h_v7_unc 13 h13lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v6_j h_3328 + -- Lane 14: step 7 j-lane. + · have h_eq : v8.elements.val[14]! = v7.elements.val[14]! := by + rw [h_v8_unc 14 h14lt (by decide) (by decide)] + rw [h_eq]; exact le_trans h_v7_j h_3328 + -- Lane 15: step 8 j-lane. + · exact le_trans h_v8_j h_3328 + +end libcrux_iot_ml_kem.Vector.Portable.Ntt +/-! ### Extracted from FCTargets.lean (§vector_ntt). -/ + +namespace libcrux_iot_ml_kem.Vector.Portable.Ntt +open libcrux_iot_ml_kem.Spec.Lift libcrux_iot_ml_kem.Vector.Portable.Arithmetic.Element libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement +open CoreModels Aeneas Aeneas.Std Std.Do +open libcrux_iot_ml_kem.Spec + +/-! ## §L2 — NTT step ops (5 theorems). -/ + +/-! ### L2.1 — `ntt_step` private helpers. -/ + +/-- Reduction of `core.num.I16.wrapping_sub` to the underlying + Aeneas `Std.I16.wrapping_sub`. Mirror of L2's helper, scoped to FCTargets. -/ +theorem ntt_step_fc.cm_wrapping_sub_ok_eq (x y : Std.I16) : + CoreModels.core.num.I16.wrapping_sub x y = .ok (Std.I16.wrapping_sub x y) := by + unfold CoreModels.core.num.I16.wrapping_sub + unfold rust_primitives.arithmetic.wrapping_sub_i16 + rfl + +/-- Reduction of `core.num.I16.wrapping_add` to the underlying + Aeneas `Std.I16.wrapping_add`. Mirror of L2's helper. -/ +theorem ntt_step_fc.cm_wrapping_add_ok_eq (x y : Std.I16) : + CoreModels.core.num.I16.wrapping_add x y = .ok (Std.I16.wrapping_add x y) := by + unfold CoreModels.core.num.I16.wrapping_add + unfold rust_primitives.arithmetic.wrapping_add_i16 + rfl + +/-- Reduction of `classify` to identity. Mirror of L2's helper. -/ +theorem ntt_step_fc.classify_ok_eq {T : Type} (x : T) : + libcrux_secrets.traits.Classify.Blanket.classify x = .ok x := rfl + +/-- Under `|a.val| ≤ bnd`, `|t.val| ≤ 3328`, and `bnd ≤ 29439`, the I16-wrapped + sum `a + t` has `.val = a.val + t.val` (no overflow). Mirror of L2's + `add_no_overflow_value_bnd`, scoped to FCTargets — only the value + equation is exposed (bound conjunct dropped, not needed here). -/ +theorem ntt_step_fc.add_no_overflow_value (a t : Std.I16) (bnd : Nat) + (h_a : a.val.natAbs ≤ bnd) (h_t : t.val.natAbs ≤ 3328) (h_bnd : bnd ≤ 29439) : + (Std.I16.wrapping_add a t).val = a.val + t.val := by + have h_sum_abs : ((a.val + t.val : Int)).natAbs ≤ bnd + 3328 := by + have h_tri : (a.val + t.val).natAbs ≤ a.val.natAbs + t.val.natAbs := Int.natAbs_add_le _ _ + omega + have h_lb : -(2 ^ 15 : Int) ≤ a.val + t.val := by + have h_bound : bnd + 3328 ≤ 32767 := by omega + omega + have h_ub : a.val + t.val < (2 ^ 15 : Int) := by + have h_bound : bnd + 3328 ≤ 32767 := by omega + omega + have h_bmod : Int.bmod (a.val + t.val) (2 ^ 16) = a.val + t.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + have h_val := Std.I16.wrapping_add_val_eq a t + rw [h_val, h_bmod] + +/-- Diff variant of `add_no_overflow_value`. -/ +theorem ntt_step_fc.sub_no_overflow_value (a t : Std.I16) (bnd : Nat) + (h_a : a.val.natAbs ≤ bnd) (h_t : t.val.natAbs ≤ 3328) (h_bnd : bnd ≤ 29439) : + (Std.I16.wrapping_sub a t).val = a.val - t.val := by + have h_diff_abs : ((a.val - t.val : Int)).natAbs ≤ bnd + 3328 := by + have h_neg_natAbs : (-t.val).natAbs = t.val.natAbs := Int.natAbs_neg _ + have h_eq : a.val - t.val = a.val + (-t.val) := by ring + rw [h_eq] + have h_tri : (a.val + (-t.val)).natAbs ≤ a.val.natAbs + (-t.val).natAbs := + Int.natAbs_add_le _ _ + rw [h_neg_natAbs] at h_tri + omega + have h_lb : -(2 ^ 15 : Int) ≤ a.val - t.val := by + have h_bound : bnd + 3328 ≤ 32767 := by omega + omega + have h_ub : a.val - t.val < (2 ^ 15 : Int) := by + have h_bound : bnd + 3328 ≤ 32767 := by omega + omega + have h_bmod : Int.bmod (a.val - t.val) (2 ^ 16) = a.val - t.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + have h_val := Std.I16.wrapping_sub_val_eq a t + rw [h_val, h_bmod] + +/-- Helper: `(lift_fe_mont x).val.val = (i16_to_spec_fe_mont x).val`. -/ +theorem lift_fe_mont_val_val (x : Std.I16) : + (lift_fe_mont x).val.val = (i16_to_spec_fe_mont x).val := by + unfold lift_fe_mont feOfZMod + show (BitVec.ofNat 16 (i16_to_spec_fe_mont x).val).toNat = (i16_to_spec_fe_mont x).val + rw [BitVec.toNat_ofNat] + have h_lt : (i16_to_spec_fe_mont x).val < 2 ^ 16 := + Nat.lt_of_lt_of_le (ZMod.val_lt _) (by decide) + exact Nat.mod_eq_of_lt h_lt + +/-- Bridge lemma for the L0.4 Mont-domain output: from the legacy modq + form `r.val ≡ b.val * zeta.val * 169 (mod 3329)`, derive the FE-level + `lift_fe r = mul_pure (lift_fe b) (lift_fe_mont zeta)`. + + Algebra: both sides are canonical FEs (left by `Canonical_lift_fe`, + right by `Canonical_mul_pure`). Equality reduces (via the canonical + round-trip `feOfZMod_zmodOfFE_of_canonical`) to a `ZMod 3329` equation + on their `zmodOfFE`-projections, closed by the legacy modq cast + `modq_eq_cast_zmod` plus `ring`. -/ +theorem lift_fe_mul_pure_mont_eq + (b zeta r : Std.I16) + (h : libcrux_iot_ml_kem.Spec.ModularArith.modq_eq r.val (b.val * zeta.val * 169) 3329) : + lift_fe r + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe b) (lift_fe_mont zeta) := by + -- LHS: lift_fe r = feOfZMod ((r.val : Int) : ZMod 3329). + have h_lhs : lift_fe r = feOfZMod (((r.val : Int)) : ZMod 3329) := by + unfold lift_fe i16_to_spec_fe_plain + rfl + -- RHS: mul_pure is canonical; reduce via round-trip. + set s : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe b) (lift_fe_mont zeta) with hs_def + have h_canon : s.val.val < 3329 := by + have h_cs := libcrux_iot_ml_kem.Spec.Pure.Canonical_mul_pure + (lift_fe b) (lift_fe_mont zeta) + unfold libcrux_iot_ml_kem.Spec.Pure.Canonical at h_cs + have hq : hacspec_ml_kem.parameters.FIELD_MODULUS.val = 3329 := by + unfold hacspec_ml_kem.parameters.FIELD_MODULUS; rfl + rw [hq] at h_cs + exact h_cs + have h_round_trip : feOfZMod (zmodOfFE s) = s := + feOfZMod_zmodOfFE_of_canonical s h_canon + -- zmodOfFE s = (b.val * zeta.val * 169 : ZMod 3329). + have h_zmod_s : zmodOfFE s = (((b.val * zeta.val * 169 : Int)) : ZMod 3329) := by + unfold zmodOfFE + rw [mul_pure_val_eq] + rw [ZMod.natCast_mod] + push_cast + have h_lb : ((lift_fe b).val.val : ZMod 3329) = (((b.val : Int)) : ZMod 3329) := by + rw [lift_fe_val_val b]; exact ZMod.natCast_zmod_val _ + have h_lz : ((lift_fe_mont zeta).val.val : ZMod 3329) + = (((zeta.val : Int)) : ZMod 3329) * 169 := by + rw [lift_fe_mont_val_val zeta, ZMod.natCast_zmod_val] + rw [i16_to_spec_fe_mont_unfold] + rw [h_lb, h_lz] + ring + -- Cast the modq hypothesis to a ZMod equality. + have h_zmod_eq : (((r.val : Int)) : ZMod 3329) + = (((b.val * zeta.val * 169 : Int)) : ZMod 3329) := + modq_eq_cast_zmod _ _ h + rw [h_lhs, ← h_round_trip, h_zmod_s, h_zmod_eq] + +/-- L2.1 — `ntt_step`: per-pair butterfly. + + **Preconditions beyond locked statement**: + - `hne : i.val ≠ j.val` — mirrors L2 legacy `ntt_step_spec`. Without + this, the impl's two writes (`a[j] := a-t` then `a[i] := a+t`) at + the same index yield `a+t` while the spec would also yield `a+t` + (via `(a.set j (a-t)).set i (a+t)` with `i = j` → same), but the + lift-level reasoning bifurcates messily. Real callers in L2.2/3/4 + all use distinct `i, j`. + - `hvec : ∀ k, k < 16 → (vec.elements.val[k]!).val.natAbs ≤ 29439` — + ensures the I16 wrapping_{add,sub} at indices `i, j` do not overflow. + The bound `29439 = 32767 - 3328` is the tightest that keeps + `|vec[i] ± t| ≤ 32767` when `|t| ≤ 3328` (L0.4 output bound). + Universal form (not just at `i, j`) for callers' convenience — + they typically carry a per-lane bound. -/ +@[spec high] +theorem ntt_step_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta : Std.I16) (i j : Std.Usize) + (hi : i.val < 16) (hj : j.val < 16) + (hne : i.val ≠ j.val) + (hzeta : zeta.val.natAbs ≤ 1664) + (hvec : ∀ k : Nat, k < 16 → + (vec.elements.val[k]!).val.natAbs ≤ 29439) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.ntt_step vec zeta i j + ⦃ ⇓ r => ⌜ lift_chunk r + = Spec.chunk_ntt_step_pure (lift_chunk vec) (lift_fe_mont zeta) i j ⌝ ⦄ := by + -- Step 0: vector length facts. + have h_vec_len : vec.elements.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length vec + have h_vec_val_len : vec.elements.val.length = 16 := h_vec_len + -- Step 1: read vec[j]. + have h_idx_j : + Aeneas.Std.Array.index_usize vec.elements j = .ok (vec.elements.val[j.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq vec.elements j + (by rw [h_vec_len]; exact hj) + -- Step 2: classify ζ. + have h_classify : libcrux_secrets.traits.Classify.Blanket.classify zeta = .ok zeta := + ntt_step_fc.classify_ok_eq zeta + -- Step 3: L0.4 keystone on (vec[j], ζ). + set b : Std.I16 := vec.elements.val[j.val]! with hb_def + have h_b_bnd_29439 : b.val.natAbs ≤ 29439 := hvec j.val hj + have h_b_bnd : b.val.natAbs ≤ 32767 := by + have := h_b_bnd_29439 + omega + obtain ⟨t, h_t_eq_ok, h_t_bd, h_t_lift⟩ := + triple_exists_ok_fc (montgomery_multiply_fe_by_fer_fc b zeta h_b_bnd hzeta) + -- Recover the modq form via the legacy spec (needed for `lift_fe_mul_pure_mont_eq`). + obtain ⟨t', h_t'_eq, h_t'_bnd_tight, h_t_modq⟩ := + triple_exists_ok_fc + (libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement.montgomery_multiply_fe_by_fer_spec b zeta hzeta) + -- t' = t (same impl call, both `.ok`). + have h_tt' : t = t' := by + have : (Result.ok t : Result _) = Result.ok t' := by rw [← h_t_eq_ok, h_t'_eq] + cases this; rfl + -- Step 4: read vec[i]. + have h_idx_i : + Aeneas.Std.Array.index_usize vec.elements i = .ok (vec.elements.val[i.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq vec.elements i + (by rw [h_vec_len]; exact hi) + set a : Std.I16 := vec.elements.val[i.val]! with ha_def + have h_a_bnd : a.val.natAbs ≤ 29439 := hvec i.val hi + -- Step 5,6: wrapping_sub / wrapping_add. + have h_sub_eq : + CoreModels.core.num.I16.wrapping_sub a t = .ok (Std.I16.wrapping_sub a t) := + ntt_step_fc.cm_wrapping_sub_ok_eq a t + have h_add_eq : + CoreModels.core.num.I16.wrapping_add a t = .ok (Std.I16.wrapping_add a t) := + ntt_step_fc.cm_wrapping_add_ok_eq a t + set a_minus_t : Std.I16 := Std.I16.wrapping_sub a t with hamt_def + set a_plus_t : Std.I16 := Std.I16.wrapping_add a t with hapt_def + have h_t_bd' : t.val.natAbs ≤ 3328 := by + -- L0.4-FC's locked-post bound is ≤ 3328 + 1665 = 4993; the legacy + -- is the tighter ≤ 3328 (from `montgomery_multiply_fe_by_fer_spec`). + rw [h_tt']; exact h_t'_bnd_tight + have h_amt_val : a_minus_t.val = a.val - t.val := + ntt_step_fc.sub_no_overflow_value a t 29439 h_a_bnd h_t_bd' (by decide) + have h_apt_val : a_plus_t.val = a.val + t.val := + ntt_step_fc.add_no_overflow_value a t 29439 h_a_bnd h_t_bd' (by decide) + -- Step 7,8: writes. + have h_upd_j : + Aeneas.Std.Array.update vec.elements j a_minus_t + = .ok (vec.elements.set j a_minus_t) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_update_ok_eq vec.elements j a_minus_t + (by rw [h_vec_len]; exact hj) + have h_upd_i : + Aeneas.Std.Array.update (vec.elements.set j a_minus_t) i a_plus_t + = .ok ((vec.elements.set j a_minus_t).set i a_plus_t) := by + have h_len : (vec.elements.set j a_minus_t).length = 16 := by + rw [Std.Array.set_length]; exact h_vec_len + exact libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_update_ok_eq _ i a_plus_t + (by rw [h_len]; exact hi) + -- Compose: derive `.ok final_vec` form. + set final_elements : Std.Array Std.I16 16#usize := + (vec.elements.set j a_minus_t).set i a_plus_t with hfe_def + set final_vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + { elements := final_elements } with hfv_def + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.ntt_step vec zeta i j = .ok final_vec := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.ntt_step + rw [h_idx_j]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_classify]; simp only [Aeneas.Std.bind_tc_ok] + rw [← h_tt'] at h_t'_eq + rw [h_t'_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_i]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_sub_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_add_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_upd_j]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_upd_i]; simp only [Aeneas.Std.bind_tc_ok]; rfl + apply triple_of_ok_fc h_body + -- Now: prove the FC chunk equation. + -- Set up the abbreviations. + set s_t_fe : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe b) (lift_fe_mont zeta) with hs_t_fe_def + set s_minus : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (lift_fe a) s_t_fe with hs_minus_def + set s_plus : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe a) s_t_fe with hs_plus_def + -- Reduce both sides to underlying lists via Subtype.ext. + unfold lift_chunk Spec.chunk_ntt_step_pure + apply Subtype.ext + -- Both `lift_chunk` and `Spec.chunk_ntt_step_pure` produce Std.Array FE 16. + -- After Subtype.ext the goal is on `.val : List FE`. + -- Reduce: `(Std.Array.make 16 L _).val = L` and `Std.Array.set v i x .val = v.val.set i.val x`. + simp only [Std.Array.set_val_eq] + -- The Std.Array.make .val reduces by rfl (it's `⟨L, _⟩.val = L`). + -- And `.val[k]!` on a `Std.Array.make _ L _` equals `L[k]!`. + -- LHS: final_vec.elements.val.map lift_fe (final_vec.elements is set-set). + show ((vec.elements.val.set j.val a_minus_t).set i.val a_plus_t).map lift_fe + = ((vec.elements.val.map lift_fe).set j.val + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + ((vec.elements.val.map lift_fe)[i.val]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((vec.elements.val.map lift_fe)[j.val]!) (lift_fe_mont zeta)))).set i.val + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((vec.elements.val.map lift_fe)[i.val]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((vec.elements.val.map lift_fe)[j.val]!) (lift_fe_mont zeta))) + -- Bridge: `(vec.elements.val.map lift_fe)[k]! = lift_fe (vec.elements.val[k]!)` when k < 16. + have h_map_lift_at (k : Nat) (hk : k < 16) : + (vec.elements.val.map lift_fe)[k]! = lift_fe (vec.elements.val[k]!) := by + have hk_lhs : k < (vec.elements.val.map lift_fe).length := by + simp [List.length_map, h_vec_val_len]; exact hk + rw [getElem!_pos (vec.elements.val.map lift_fe) k hk_lhs] + rw [List.getElem_map] + have hk_vec : k < vec.elements.val.length := by rw [h_vec_val_len]; exact hk + rw [getElem!_pos vec.elements.val k hk_vec] + rw [h_map_lift_at i.val hi, h_map_lift_at j.val hj] + -- The RHS s_t_fe / s_plus / s_minus values match: + -- sub_pure (lift_fe a) (mul_pure (lift_fe b) (lift_fe_mont zeta)) = s_minus + -- add_pure (lift_fe a) (mul_pure (lift_fe b) (lift_fe_mont zeta)) = s_plus + change ((vec.elements.val.set j.val a_minus_t).set i.val a_plus_t).map lift_fe + = ((vec.elements.val.map lift_fe).set j.val s_minus).set i.val s_plus + -- Per-index proof. + apply List.ext_getElem + · simp [List.length_map, List.length_set] + · intro k hk1 hk2 + have hk : k < 16 := by + have hk' : k < (((vec.elements.val.set j.val a_minus_t).set i.val a_plus_t).map lift_fe).length := hk1 + simp [List.length_map, List.length_set, h_vec_val_len] at hk' + exact hk' + rw [List.getElem_map] + by_cases h_eq_i : k = i.val + · -- k = i.val: r[i] = a_plus_t, RHS = s_plus. + subst h_eq_i + rw [List.getElem_set_self] + rw [List.getElem_set_self] + -- Goal: lift_fe a_plus_t = s_plus + show lift_fe a_plus_t = s_plus + have h_step1 : + lift_fe a_plus_t + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe a) (lift_fe t) := + lift_fe_add_pure_eq a t a_plus_t h_apt_val + rw [h_step1] + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe a) (lift_fe t) = s_plus + simp only [hs_plus_def, hs_t_fe_def] + congr 1 + rw [h_tt'] + exact lift_fe_mul_pure_mont_eq b zeta t' h_t_modq + · -- k ≠ i.val. + rw [List.getElem_set_ne (Ne.symm h_eq_i)] + rw [List.getElem_set_ne (Ne.symm h_eq_i)] + by_cases h_eq_j : k = j.val + · -- k = j.val. + subst h_eq_j + rw [List.getElem_set_self] + rw [List.getElem_set_self] + show lift_fe a_minus_t = s_minus + have h_step1 : + lift_fe a_minus_t + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (lift_fe a) (lift_fe t) := + lift_fe_sub_pure_eq a t a_minus_t h_amt_val + rw [h_step1] + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (lift_fe a) (lift_fe t) = s_minus + simp only [hs_minus_def, hs_t_fe_def] + congr 1 + rw [h_tt'] + exact lift_fe_mul_pure_mont_eq b zeta t' h_t_modq + · -- k ≠ i.val, k ≠ j.val: r[k] = vec[k] under lift_fe. + rw [List.getElem_set_ne (Ne.symm h_eq_j)] + rw [List.getElem_set_ne (Ne.symm h_eq_j)] + rw [List.getElem_map] + +/-- Per-lane variant of `ntt_step_fc` for layer composition. Same body + as the keystone, but the precondition is split into the two lanes + actually read (`i`, `j`). This is needed for layer-N proofs where + after each ntt_step the touched lanes exceed the universal `≤ 29439` + bound; the pairs within a layer are disjoint, so only the + untouched-pair lanes need to satisfy `≤ 29439` at each step. + + Also exposes the per-lane output bound `≤ 32767` (i.e. all lanes + remain valid `I16`s), used to chain across steps. -/ +theorem ntt_step_pair_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta : Std.I16) (i j : Std.Usize) + (hi : i.val < 16) (hj : j.val < 16) + (hne : i.val ≠ j.val) + (hzeta : zeta.val.natAbs ≤ 1664) + (h_a_bnd : (vec.elements.val[i.val]!).val.natAbs ≤ 29439) + (h_b_bnd : (vec.elements.val[j.val]!).val.natAbs ≤ 29439) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.ntt_step vec zeta i j + ⦃ ⇓ r => ⌜ lift_chunk r + = Spec.chunk_ntt_step_pure (lift_chunk vec) (lift_fe_mont zeta) i j + ∧ (∀ k : Nat, k < 16 → k ≠ i.val → k ≠ j.val → + (r.elements.val[k]!) = (vec.elements.val[k]!)) + ∧ (r.elements.val[i.val]!).val.natAbs ≤ 32767 + ∧ (r.elements.val[j.val]!).val.natAbs ≤ 32767 ⌝ ⦄ := by + -- Step 0: vector length facts. + have h_vec_len : vec.elements.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length vec + have h_vec_val_len : vec.elements.val.length = 16 := h_vec_len + -- Step 1: read vec[j]. + have h_idx_j : + Aeneas.Std.Array.index_usize vec.elements j = .ok (vec.elements.val[j.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq vec.elements j + (by rw [h_vec_len]; exact hj) + -- Step 2: classify ζ. + have h_classify : libcrux_secrets.traits.Classify.Blanket.classify zeta = .ok zeta := + ntt_step_fc.classify_ok_eq zeta + -- Step 3: L0.4 keystone on (vec[j], ζ). + set b : Std.I16 := vec.elements.val[j.val]! with hb_def + have h_b_bnd_29439 : b.val.natAbs ≤ 29439 := h_b_bnd + have h_b_bnd_max : b.val.natAbs ≤ 32767 := by + have := h_b_bnd_29439; omega + obtain ⟨t, h_t_eq_ok, h_t_bd, h_t_lift⟩ := + triple_exists_ok_fc (montgomery_multiply_fe_by_fer_fc b zeta h_b_bnd_max hzeta) + obtain ⟨t', h_t'_eq, h_t'_bnd_tight, h_t_modq⟩ := + triple_exists_ok_fc + (libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement.montgomery_multiply_fe_by_fer_spec b zeta hzeta) + have h_tt' : t = t' := by + have : (Result.ok t : Result _) = Result.ok t' := by rw [← h_t_eq_ok, h_t'_eq] + cases this; rfl + -- Step 4: read vec[i]. + have h_idx_i : + Aeneas.Std.Array.index_usize vec.elements i = .ok (vec.elements.val[i.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq vec.elements i + (by rw [h_vec_len]; exact hi) + set a : Std.I16 := vec.elements.val[i.val]! with ha_def + have h_a_bnd_29439 : a.val.natAbs ≤ 29439 := h_a_bnd + -- Step 5,6: wrapping_sub / wrapping_add. + have h_sub_eq : + CoreModels.core.num.I16.wrapping_sub a t = .ok (Std.I16.wrapping_sub a t) := + ntt_step_fc.cm_wrapping_sub_ok_eq a t + have h_add_eq : + CoreModels.core.num.I16.wrapping_add a t = .ok (Std.I16.wrapping_add a t) := + ntt_step_fc.cm_wrapping_add_ok_eq a t + set a_minus_t : Std.I16 := Std.I16.wrapping_sub a t with hamt_def + set a_plus_t : Std.I16 := Std.I16.wrapping_add a t with hapt_def + have h_t_bd' : t.val.natAbs ≤ 3328 := by + rw [h_tt']; exact h_t'_bnd_tight + have h_amt_val : a_minus_t.val = a.val - t.val := + ntt_step_fc.sub_no_overflow_value a t 29439 h_a_bnd_29439 h_t_bd' (by decide) + have h_apt_val : a_plus_t.val = a.val + t.val := + ntt_step_fc.add_no_overflow_value a t 29439 h_a_bnd_29439 h_t_bd' (by decide) + -- Step 7,8: writes. + have h_upd_j : + Aeneas.Std.Array.update vec.elements j a_minus_t + = .ok (vec.elements.set j a_minus_t) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_update_ok_eq vec.elements j a_minus_t + (by rw [h_vec_len]; exact hj) + have h_upd_i : + Aeneas.Std.Array.update (vec.elements.set j a_minus_t) i a_plus_t + = .ok ((vec.elements.set j a_minus_t).set i a_plus_t) := by + have h_len : (vec.elements.set j a_minus_t).length = 16 := by + rw [Std.Array.set_length]; exact h_vec_len + exact libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_update_ok_eq _ i a_plus_t + (by rw [h_len]; exact hi) + -- Compose. + set final_elements : Std.Array Std.I16 16#usize := + (vec.elements.set j a_minus_t).set i a_plus_t with hfe_def + set final_vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + { elements := final_elements } with hfv_def + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.ntt_step vec zeta i j = .ok final_vec := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.ntt_step + rw [h_idx_j]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_classify]; simp only [Aeneas.Std.bind_tc_ok] + rw [← h_tt'] at h_t'_eq + rw [h_t'_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_i]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_sub_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_add_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_upd_j]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_upd_i]; simp only [Aeneas.Std.bind_tc_ok]; rfl + apply triple_of_ok_fc h_body + -- Now: 4 conjuncts. + refine ⟨?_, ?_, ?_, ?_⟩ + · -- lift_chunk equation: identical to keystone proof. + set s_t_fe : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe b) (lift_fe_mont zeta) with hs_t_fe_def + set s_minus : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (lift_fe a) s_t_fe with hs_minus_def + set s_plus : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe a) s_t_fe with hs_plus_def + unfold lift_chunk Spec.chunk_ntt_step_pure + apply Subtype.ext + simp only [Std.Array.set_val_eq] + show ((vec.elements.val.set j.val a_minus_t).set i.val a_plus_t).map lift_fe + = ((vec.elements.val.map lift_fe).set j.val + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + ((vec.elements.val.map lift_fe)[i.val]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((vec.elements.val.map lift_fe)[j.val]!) (lift_fe_mont zeta)))).set i.val + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((vec.elements.val.map lift_fe)[i.val]!) + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + ((vec.elements.val.map lift_fe)[j.val]!) (lift_fe_mont zeta))) + have h_map_lift_at (k : Nat) (hk : k < 16) : + (vec.elements.val.map lift_fe)[k]! = lift_fe (vec.elements.val[k]!) := by + have hk_lhs : k < (vec.elements.val.map lift_fe).length := by + simp [List.length_map, h_vec_val_len]; exact hk + rw [getElem!_pos (vec.elements.val.map lift_fe) k hk_lhs] + rw [List.getElem_map] + have hk_vec : k < vec.elements.val.length := by rw [h_vec_val_len]; exact hk + rw [getElem!_pos vec.elements.val k hk_vec] + rw [h_map_lift_at i.val hi, h_map_lift_at j.val hj] + change ((vec.elements.val.set j.val a_minus_t).set i.val a_plus_t).map lift_fe + = ((vec.elements.val.map lift_fe).set j.val s_minus).set i.val s_plus + apply List.ext_getElem + · simp [List.length_map, List.length_set] + · intro k hk1 hk2 + have hk : k < 16 := by + have hk' : k < (((vec.elements.val.set j.val a_minus_t).set i.val a_plus_t).map lift_fe).length := hk1 + simp [List.length_map, List.length_set, h_vec_val_len] at hk' + exact hk' + rw [List.getElem_map] + by_cases h_eq_i : k = i.val + · subst h_eq_i + rw [List.getElem_set_self] + rw [List.getElem_set_self] + show lift_fe a_plus_t = s_plus + have h_step1 : + lift_fe a_plus_t + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe a) (lift_fe t) := + lift_fe_add_pure_eq a t a_plus_t h_apt_val + rw [h_step1] + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe a) (lift_fe t) = s_plus + simp only [hs_plus_def, hs_t_fe_def] + congr 1 + rw [h_tt'] + exact lift_fe_mul_pure_mont_eq b zeta t' h_t_modq + · rw [List.getElem_set_ne (Ne.symm h_eq_i)] + rw [List.getElem_set_ne (Ne.symm h_eq_i)] + by_cases h_eq_j : k = j.val + · subst h_eq_j + rw [List.getElem_set_self] + rw [List.getElem_set_self] + show lift_fe a_minus_t = s_minus + have h_step1 : + lift_fe a_minus_t + = libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (lift_fe a) (lift_fe t) := + lift_fe_sub_pure_eq a t a_minus_t h_amt_val + rw [h_step1] + show libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (lift_fe a) (lift_fe t) = s_minus + simp only [hs_minus_def, hs_t_fe_def] + congr 1 + rw [h_tt'] + exact lift_fe_mul_pure_mont_eq b zeta t' h_t_modq + · rw [List.getElem_set_ne (Ne.symm h_eq_j)] + rw [List.getElem_set_ne (Ne.symm h_eq_j)] + rw [List.getElem_map] + · -- Untouched-lane preservation: r[k] = vec[k] for k ≠ i, j. + intro k hk hki hkj + show ((vec.elements.set j a_minus_t).set i a_plus_t).val[k]! + = vec.elements.val[k]! + have h_set_val_eq : ((vec.elements.set j a_minus_t).set i a_plus_t).val + = (vec.elements.val.set j.val a_minus_t).set i.val a_plus_t := by + simp [Std.Array.set_val_eq] + rw [h_set_val_eq] + -- (list.set j _).set i _ at index k (k ≠ i, k ≠ j) = original list at k. + have hk_set_i : k < (vec.elements.val.set j.val a_minus_t).length := by + simp [List.length_set, h_vec_val_len]; exact hk + rw [getElem!_pos _ k (by simp [List.length_set, h_vec_val_len]; exact hk)] + rw [List.getElem_set_ne (Ne.symm hki)] + rw [List.getElem_set_ne (Ne.symm hkj)] + rw [getElem!_pos vec.elements.val k (by rw [h_vec_val_len]; exact hk)] + · -- Bound at i: r[i] = a_plus_t = a + t (no-overflow), |a| ≤ 29439, |t| ≤ 3328. + show ((vec.elements.set j a_minus_t).set i a_plus_t).val[i.val]!.val.natAbs ≤ 32767 + have h_set_val_eq : ((vec.elements.set j a_minus_t).set i a_plus_t).val + = (vec.elements.val.set j.val a_minus_t).set i.val a_plus_t := by + simp [Std.Array.set_val_eq] + rw [h_set_val_eq] + rw [getElem!_pos _ i.val (by simp [List.length_set, h_vec_val_len]; exact hi)] + rw [List.getElem_set_self] + -- a_plus_t.val = a.val + t.val, |a| ≤ 29439, |t| ≤ 3328 ⇒ |sum| ≤ 32767. + have h_sum_abs : ((a.val + t.val : Int)).natAbs ≤ 29439 + 3328 := by + have h_tri : (a.val + t.val).natAbs ≤ a.val.natAbs + t.val.natAbs := Int.natAbs_add_le _ _ + omega + rw [h_apt_val]; omega + · -- Bound at j: r[j] = a_minus_t = a - t (no-overflow), similar. + show ((vec.elements.set j a_minus_t).set i a_plus_t).val[j.val]!.val.natAbs ≤ 32767 + have h_set_val_eq : ((vec.elements.set j a_minus_t).set i a_plus_t).val + = (vec.elements.val.set j.val a_minus_t).set i.val a_plus_t := by + simp [Std.Array.set_val_eq] + rw [h_set_val_eq] + rw [getElem!_pos _ j.val (by simp [List.length_set, h_vec_val_len]; exact hj)] + rw [List.getElem_set_ne hne] + rw [List.getElem_set_self] + have h_diff_abs : ((a.val - t.val : Int)).natAbs ≤ 29439 + 3328 := by + have h_neg : (-t.val).natAbs = t.val.natAbs := Int.natAbs_neg _ + have h_eq : a.val - t.val = a.val + (-t.val) := by ring + rw [h_eq] + have h_tri : (a.val + (-t.val)).natAbs ≤ a.val.natAbs + (-t.val).natAbs := + Int.natAbs_add_le _ _ + rw [h_neg] at h_tri + omega + rw [h_amt_val]; omega + +/-- L2.2 — `ntt_layer_1_step`: 8 butterfly pairs (0,2)(1,3) with z0, + (4,6)(5,7) with z1, (8,10)(9,11) with z2, (12,14)(13,15) with z3. + + **Precondition adjustment** (beyond locked statement): + - `hvec : ∀ k < 16, |vec[k]| ≤ 29439` — same as layer_2/3 (disjoint pairs). -/ +@[spec] +theorem ntt_layer_1_step_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (z0 z1 z2 z3 : Std.I16) + (hz : z0.val.natAbs ≤ 1664 ∧ z1.val.natAbs ≤ 1664 + ∧ z2.val.natAbs ≤ 1664 ∧ z3.val.natAbs ≤ 1664) + (hvec : ∀ k : Nat, k < 16 → + (vec.elements.val[k]!).val.natAbs ≤ 29439) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_1_step vec z0 z1 z2 z3 + ⦃ ⇓ r => ⌜ lift_chunk r + = Spec.chunk_ntt_layer_1_step_pure (lift_chunk vec) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3) ⌝ ⦄ := by + obtain ⟨hz0, hz1, hz2, hz3⟩ := hz + have hi0 : (0 : Nat) < 16 := by decide + have hi1 : (1 : Nat) < 16 := by decide + have hi2 : (2 : Nat) < 16 := by decide + have hi3 : (3 : Nat) < 16 := by decide + have hi4 : (4 : Nat) < 16 := by decide + have hi5 : (5 : Nat) < 16 := by decide + have hi6 : (6 : Nat) < 16 := by decide + have hi7 : (7 : Nat) < 16 := by decide + have hi8 : (8 : Nat) < 16 := by decide + have hi9 : (9 : Nat) < 16 := by decide + have hi10 : (10 : Nat) < 16 := by decide + have hi11 : (11 : Nat) < 16 := by decide + have hi12 : (12 : Nat) < 16 := by decide + have hi13 : (13 : Nat) < 16 := by decide + have hi14 : (14 : Nat) < 16 := by decide + have hi15 : (15 : Nat) < 16 := by decide + -- Step 1: ntt_step vec z0 0 2. + obtain ⟨v1, h_v1_eq, h_v1_lift, h_v1_unc, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc vec z0 0#usize 2#usize hi0 hi2 + (by decide) hz0 (hvec 0 hi0) (hvec 2 hi2)) + -- Step 2: ntt_step v1 z0 1 3. + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ 29439 := by + rw [h_v1_unc 1 hi1 (by decide) (by decide)]; exact hvec 1 hi1 + have h_v1_3 : (v1.elements.val[3]!).val.natAbs ≤ 29439 := by + rw [h_v1_unc 3 hi3 (by decide) (by decide)]; exact hvec 3 hi3 + obtain ⟨v2, h_v2_eq, h_v2_lift, h_v2_unc, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v1 z0 1#usize 3#usize hi1 hi3 + (by decide) hz0 h_v1_1 h_v1_3) + -- Step 3: ntt_step v2 z1 4 6. + have h_v2_4 : (v2.elements.val[4]!).val.natAbs ≤ 29439 := by + rw [h_v2_unc 4 hi4 (by decide) (by decide), + h_v1_unc 4 hi4 (by decide) (by decide)]; exact hvec 4 hi4 + have h_v2_6 : (v2.elements.val[6]!).val.natAbs ≤ 29439 := by + rw [h_v2_unc 6 hi6 (by decide) (by decide), + h_v1_unc 6 hi6 (by decide) (by decide)]; exact hvec 6 hi6 + obtain ⟨v3, h_v3_eq, h_v3_lift, h_v3_unc, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v2 z1 4#usize 6#usize hi4 hi6 + (by decide) hz1 h_v2_4 h_v2_6) + -- Step 4: ntt_step v3 z1 5 7. + have h_v3_5 : (v3.elements.val[5]!).val.natAbs ≤ 29439 := by + rw [h_v3_unc 5 hi5 (by decide) (by decide), + h_v2_unc 5 hi5 (by decide) (by decide), + h_v1_unc 5 hi5 (by decide) (by decide)]; exact hvec 5 hi5 + have h_v3_7 : (v3.elements.val[7]!).val.natAbs ≤ 29439 := by + rw [h_v3_unc 7 hi7 (by decide) (by decide), + h_v2_unc 7 hi7 (by decide) (by decide), + h_v1_unc 7 hi7 (by decide) (by decide)]; exact hvec 7 hi7 + obtain ⟨v4, h_v4_eq, h_v4_lift, h_v4_unc, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v3 z1 5#usize 7#usize hi5 hi7 + (by decide) hz1 h_v3_5 h_v3_7) + -- Step 5: ntt_step v4 z2 8 10. + have h_v4_8 : (v4.elements.val[8]!).val.natAbs ≤ 29439 := by + rw [h_v4_unc 8 hi8 (by decide) (by decide), + h_v3_unc 8 hi8 (by decide) (by decide), + h_v2_unc 8 hi8 (by decide) (by decide), + h_v1_unc 8 hi8 (by decide) (by decide)]; exact hvec 8 hi8 + have h_v4_10 : (v4.elements.val[10]!).val.natAbs ≤ 29439 := by + rw [h_v4_unc 10 hi10 (by decide) (by decide), + h_v3_unc 10 hi10 (by decide) (by decide), + h_v2_unc 10 hi10 (by decide) (by decide), + h_v1_unc 10 hi10 (by decide) (by decide)]; exact hvec 10 hi10 + obtain ⟨v5, h_v5_eq, h_v5_lift, h_v5_unc, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v4 z2 8#usize 10#usize hi8 hi10 + (by decide) hz2 h_v4_8 h_v4_10) + -- Step 6: ntt_step v5 z2 9 11. + have h_v5_9 : (v5.elements.val[9]!).val.natAbs ≤ 29439 := by + rw [h_v5_unc 9 hi9 (by decide) (by decide), + h_v4_unc 9 hi9 (by decide) (by decide), + h_v3_unc 9 hi9 (by decide) (by decide), + h_v2_unc 9 hi9 (by decide) (by decide), + h_v1_unc 9 hi9 (by decide) (by decide)]; exact hvec 9 hi9 + have h_v5_11 : (v5.elements.val[11]!).val.natAbs ≤ 29439 := by + rw [h_v5_unc 11 hi11 (by decide) (by decide), + h_v4_unc 11 hi11 (by decide) (by decide), + h_v3_unc 11 hi11 (by decide) (by decide), + h_v2_unc 11 hi11 (by decide) (by decide), + h_v1_unc 11 hi11 (by decide) (by decide)]; exact hvec 11 hi11 + obtain ⟨v6, h_v6_eq, h_v6_lift, h_v6_unc, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v5 z2 9#usize 11#usize hi9 hi11 + (by decide) hz2 h_v5_9 h_v5_11) + -- Step 7: ntt_step v6 z3 12 14. + have h_v6_12 : (v6.elements.val[12]!).val.natAbs ≤ 29439 := by + rw [h_v6_unc 12 hi12 (by decide) (by decide), + h_v5_unc 12 hi12 (by decide) (by decide), + h_v4_unc 12 hi12 (by decide) (by decide), + h_v3_unc 12 hi12 (by decide) (by decide), + h_v2_unc 12 hi12 (by decide) (by decide), + h_v1_unc 12 hi12 (by decide) (by decide)]; exact hvec 12 hi12 + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ 29439 := by + rw [h_v6_unc 14 hi14 (by decide) (by decide), + h_v5_unc 14 hi14 (by decide) (by decide), + h_v4_unc 14 hi14 (by decide) (by decide), + h_v3_unc 14 hi14 (by decide) (by decide), + h_v2_unc 14 hi14 (by decide) (by decide), + h_v1_unc 14 hi14 (by decide) (by decide)]; exact hvec 14 hi14 + obtain ⟨v7, h_v7_eq, h_v7_lift, h_v7_unc, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v6 z3 12#usize 14#usize hi12 hi14 + (by decide) hz3 h_v6_12 h_v6_14) + -- Step 8: ntt_step v7 z3 13 15. + have h_v7_13 : (v7.elements.val[13]!).val.natAbs ≤ 29439 := by + rw [h_v7_unc 13 hi13 (by decide) (by decide), + h_v6_unc 13 hi13 (by decide) (by decide), + h_v5_unc 13 hi13 (by decide) (by decide), + h_v4_unc 13 hi13 (by decide) (by decide), + h_v3_unc 13 hi13 (by decide) (by decide), + h_v2_unc 13 hi13 (by decide) (by decide), + h_v1_unc 13 hi13 (by decide) (by decide)]; exact hvec 13 hi13 + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ 29439 := by + rw [h_v7_unc 15 hi15 (by decide) (by decide), + h_v6_unc 15 hi15 (by decide) (by decide), + h_v5_unc 15 hi15 (by decide) (by decide), + h_v4_unc 15 hi15 (by decide) (by decide), + h_v3_unc 15 hi15 (by decide) (by decide), + h_v2_unc 15 hi15 (by decide) (by decide), + h_v1_unc 15 hi15 (by decide) (by decide)]; exact hvec 15 hi15 + obtain ⟨v8, h_v8_eq, h_v8_lift, _, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v7 z3 13#usize 15#usize hi13 hi15 + (by decide) hz3 h_v7_13 h_v7_15) + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_1_step vec z0 z1 z2 z3 + = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_1_step + rw [h_v1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v3_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v4_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v5_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v6_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v7_eq]; simp only [Aeneas.Std.bind_tc_ok] + exact h_v8_eq + apply triple_of_ok_fc h_body + unfold Spec.chunk_ntt_layer_1_step_pure + rw [h_v8_lift, h_v7_lift, h_v6_lift, h_v5_lift, h_v4_lift, h_v3_lift, h_v2_lift, h_v1_lift] + +/-- L2.3 — `ntt_layer_2_step`: 8 butterfly pairs (0,4)…(3,7) with z0 then + (8,12)…(11,15) with z1. + + **Precondition adjustment** (beyond locked statement): + - `hvec : ∀ k < 16, |vec[k]| ≤ 29439` — same as layer_3 (disjoint pairs). -/ +@[spec] +theorem ntt_layer_2_step_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (z0 z1 : Std.I16) + (hz : z0.val.natAbs ≤ 1664 ∧ z1.val.natAbs ≤ 1664) + (hvec : ∀ k : Nat, k < 16 → + (vec.elements.val[k]!).val.natAbs ≤ 29439) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_2_step vec z0 z1 + ⦃ ⇓ r => ⌜ lift_chunk r + = Spec.chunk_ntt_layer_2_step_pure (lift_chunk vec) + (lift_fe_mont z0) (lift_fe_mont z1) ⌝ ⦄ := by + obtain ⟨hz0, hz1⟩ := hz + have hi0 : (0 : Nat) < 16 := by decide + have hi1 : (1 : Nat) < 16 := by decide + have hi2 : (2 : Nat) < 16 := by decide + have hi3 : (3 : Nat) < 16 := by decide + have hi4 : (4 : Nat) < 16 := by decide + have hi5 : (5 : Nat) < 16 := by decide + have hi6 : (6 : Nat) < 16 := by decide + have hi7 : (7 : Nat) < 16 := by decide + have hi8 : (8 : Nat) < 16 := by decide + have hi9 : (9 : Nat) < 16 := by decide + have hi10 : (10 : Nat) < 16 := by decide + have hi11 : (11 : Nat) < 16 := by decide + have hi12 : (12 : Nat) < 16 := by decide + have hi13 : (13 : Nat) < 16 := by decide + have hi14 : (14 : Nat) < 16 := by decide + have hi15 : (15 : Nat) < 16 := by decide + -- Step 1: ntt_step vec z0 0 4. + obtain ⟨v1, h_v1_eq, h_v1_lift, h_v1_unc, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc vec z0 0#usize 4#usize hi0 hi4 + (by decide) hz0 (hvec 0 hi0) (hvec 4 hi4)) + -- Step 2: ntt_step v1 z0 1 5. + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ 29439 := by + rw [h_v1_unc 1 hi1 (by decide) (by decide)]; exact hvec 1 hi1 + have h_v1_5 : (v1.elements.val[5]!).val.natAbs ≤ 29439 := by + rw [h_v1_unc 5 hi5 (by decide) (by decide)]; exact hvec 5 hi5 + obtain ⟨v2, h_v2_eq, h_v2_lift, h_v2_unc, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v1 z0 1#usize 5#usize hi1 hi5 + (by decide) hz0 h_v1_1 h_v1_5) + -- Step 3: ntt_step v2 z0 2 6. + have h_v2_2 : (v2.elements.val[2]!).val.natAbs ≤ 29439 := by + rw [h_v2_unc 2 hi2 (by decide) (by decide), + h_v1_unc 2 hi2 (by decide) (by decide)]; exact hvec 2 hi2 + have h_v2_6 : (v2.elements.val[6]!).val.natAbs ≤ 29439 := by + rw [h_v2_unc 6 hi6 (by decide) (by decide), + h_v1_unc 6 hi6 (by decide) (by decide)]; exact hvec 6 hi6 + obtain ⟨v3, h_v3_eq, h_v3_lift, h_v3_unc, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v2 z0 2#usize 6#usize hi2 hi6 + (by decide) hz0 h_v2_2 h_v2_6) + -- Step 4: ntt_step v3 z0 3 7. + have h_v3_3 : (v3.elements.val[3]!).val.natAbs ≤ 29439 := by + rw [h_v3_unc 3 hi3 (by decide) (by decide), + h_v2_unc 3 hi3 (by decide) (by decide), + h_v1_unc 3 hi3 (by decide) (by decide)]; exact hvec 3 hi3 + have h_v3_7 : (v3.elements.val[7]!).val.natAbs ≤ 29439 := by + rw [h_v3_unc 7 hi7 (by decide) (by decide), + h_v2_unc 7 hi7 (by decide) (by decide), + h_v1_unc 7 hi7 (by decide) (by decide)]; exact hvec 7 hi7 + obtain ⟨v4, h_v4_eq, h_v4_lift, h_v4_unc, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v3 z0 3#usize 7#usize hi3 hi7 + (by decide) hz0 h_v3_3 h_v3_7) + -- Step 5: ntt_step v4 z1 8 12. + have h_v4_8 : (v4.elements.val[8]!).val.natAbs ≤ 29439 := by + rw [h_v4_unc 8 hi8 (by decide) (by decide), + h_v3_unc 8 hi8 (by decide) (by decide), + h_v2_unc 8 hi8 (by decide) (by decide), + h_v1_unc 8 hi8 (by decide) (by decide)]; exact hvec 8 hi8 + have h_v4_12 : (v4.elements.val[12]!).val.natAbs ≤ 29439 := by + rw [h_v4_unc 12 hi12 (by decide) (by decide), + h_v3_unc 12 hi12 (by decide) (by decide), + h_v2_unc 12 hi12 (by decide) (by decide), + h_v1_unc 12 hi12 (by decide) (by decide)]; exact hvec 12 hi12 + obtain ⟨v5, h_v5_eq, h_v5_lift, h_v5_unc, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v4 z1 8#usize 12#usize hi8 hi12 + (by decide) hz1 h_v4_8 h_v4_12) + -- Step 6: ntt_step v5 z1 9 13. + have h_v5_9 : (v5.elements.val[9]!).val.natAbs ≤ 29439 := by + rw [h_v5_unc 9 hi9 (by decide) (by decide), + h_v4_unc 9 hi9 (by decide) (by decide), + h_v3_unc 9 hi9 (by decide) (by decide), + h_v2_unc 9 hi9 (by decide) (by decide), + h_v1_unc 9 hi9 (by decide) (by decide)]; exact hvec 9 hi9 + have h_v5_13 : (v5.elements.val[13]!).val.natAbs ≤ 29439 := by + rw [h_v5_unc 13 hi13 (by decide) (by decide), + h_v4_unc 13 hi13 (by decide) (by decide), + h_v3_unc 13 hi13 (by decide) (by decide), + h_v2_unc 13 hi13 (by decide) (by decide), + h_v1_unc 13 hi13 (by decide) (by decide)]; exact hvec 13 hi13 + obtain ⟨v6, h_v6_eq, h_v6_lift, h_v6_unc, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v5 z1 9#usize 13#usize hi9 hi13 + (by decide) hz1 h_v5_9 h_v5_13) + -- Step 7: ntt_step v6 z1 10 14. + have h_v6_10 : (v6.elements.val[10]!).val.natAbs ≤ 29439 := by + rw [h_v6_unc 10 hi10 (by decide) (by decide), + h_v5_unc 10 hi10 (by decide) (by decide), + h_v4_unc 10 hi10 (by decide) (by decide), + h_v3_unc 10 hi10 (by decide) (by decide), + h_v2_unc 10 hi10 (by decide) (by decide), + h_v1_unc 10 hi10 (by decide) (by decide)]; exact hvec 10 hi10 + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ 29439 := by + rw [h_v6_unc 14 hi14 (by decide) (by decide), + h_v5_unc 14 hi14 (by decide) (by decide), + h_v4_unc 14 hi14 (by decide) (by decide), + h_v3_unc 14 hi14 (by decide) (by decide), + h_v2_unc 14 hi14 (by decide) (by decide), + h_v1_unc 14 hi14 (by decide) (by decide)]; exact hvec 14 hi14 + obtain ⟨v7, h_v7_eq, h_v7_lift, h_v7_unc, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v6 z1 10#usize 14#usize hi10 hi14 + (by decide) hz1 h_v6_10 h_v6_14) + -- Step 8: ntt_step v7 z1 11 15. + have h_v7_11 : (v7.elements.val[11]!).val.natAbs ≤ 29439 := by + rw [h_v7_unc 11 hi11 (by decide) (by decide), + h_v6_unc 11 hi11 (by decide) (by decide), + h_v5_unc 11 hi11 (by decide) (by decide), + h_v4_unc 11 hi11 (by decide) (by decide), + h_v3_unc 11 hi11 (by decide) (by decide), + h_v2_unc 11 hi11 (by decide) (by decide), + h_v1_unc 11 hi11 (by decide) (by decide)]; exact hvec 11 hi11 + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ 29439 := by + rw [h_v7_unc 15 hi15 (by decide) (by decide), + h_v6_unc 15 hi15 (by decide) (by decide), + h_v5_unc 15 hi15 (by decide) (by decide), + h_v4_unc 15 hi15 (by decide) (by decide), + h_v3_unc 15 hi15 (by decide) (by decide), + h_v2_unc 15 hi15 (by decide) (by decide), + h_v1_unc 15 hi15 (by decide) (by decide)]; exact hvec 15 hi15 + obtain ⟨v8, h_v8_eq, h_v8_lift, _, _, _⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v7 z1 11#usize 15#usize hi11 hi15 + (by decide) hz1 h_v7_11 h_v7_15) + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_2_step vec z0 z1 = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_2_step + rw [h_v1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v3_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v4_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v5_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v6_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v7_eq]; simp only [Aeneas.Std.bind_tc_ok] + exact h_v8_eq + apply triple_of_ok_fc h_body + unfold Spec.chunk_ntt_layer_2_step_pure + rw [h_v8_lift, h_v7_lift, h_v6_lift, h_v5_lift, h_v4_lift, h_v3_lift, h_v2_lift, h_v1_lift] + +/-- L2.4 — `ntt_layer_3_step`: 8 butterfly pairs (0,8)…(7,15) with one zeta. + + **Precondition adjustment** (beyond locked statement): + - `hvec : ∀ k < 16, |vec[k]| ≤ 29439` — chained through the 8 + ntt_step calls. Pairs are disjoint (each lane touched exactly + once), so the keystone's `≤ 29439` precondition holds at each + step on the unchanged lanes. -/ +@[spec] +theorem ntt_layer_3_step_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (z : Std.I16) (hz : z.val.natAbs ≤ 1664) + (hvec : ∀ k : Nat, k < 16 → + (vec.elements.val[k]!).val.natAbs ≤ 29439) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_3_step vec z + ⦃ ⇓ r => ⌜ lift_chunk r + = Spec.chunk_ntt_layer_3_step_pure (lift_chunk vec) (lift_fe_mont z) ⌝ ⦄ := by + -- Initial-lane bounds (specialised from hvec). + have hi0 : (0 : Nat) < 16 := by decide + have hi1 : (1 : Nat) < 16 := by decide + have hi2 : (2 : Nat) < 16 := by decide + have hi3 : (3 : Nat) < 16 := by decide + have hi4 : (4 : Nat) < 16 := by decide + have hi5 : (5 : Nat) < 16 := by decide + have hi6 : (6 : Nat) < 16 := by decide + have hi7 : (7 : Nat) < 16 := by decide + have hi8 : (8 : Nat) < 16 := by decide + have hi9 : (9 : Nat) < 16 := by decide + have hi10 : (10 : Nat) < 16 := by decide + have hi11 : (11 : Nat) < 16 := by decide + have hi12 : (12 : Nat) < 16 := by decide + have hi13 : (13 : Nat) < 16 := by decide + have hi14 : (14 : Nat) < 16 := by decide + have hi15 : (15 : Nat) < 16 := by decide + -- Step 1: ntt_step vec z 0 8. + obtain ⟨v1, h_v1_eq, h_v1_lift, h_v1_unc, _h_v1_i_bd, _h_v1_j_bd⟩ := + triple_exists_ok_fc (ntt_step_pair_fc vec z 0#usize 8#usize hi0 hi8 + (by decide) hz (hvec 0 hi0) (hvec 8 hi8)) + -- Step 2: ntt_step v1 z 1 9 — needs v1[1], v1[9] ≤ 29439. Both via h_v1_unc. + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ 29439 := by + rw [h_v1_unc 1 hi1 (by decide) (by decide)]; exact hvec 1 hi1 + have h_v1_9 : (v1.elements.val[9]!).val.natAbs ≤ 29439 := by + rw [h_v1_unc 9 hi9 (by decide) (by decide)]; exact hvec 9 hi9 + obtain ⟨v2, h_v2_eq, h_v2_lift, h_v2_unc, _h_v2_i_bd, _h_v2_j_bd⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v1 z 1#usize 9#usize hi1 hi9 + (by decide) hz h_v1_1 h_v1_9) + -- Step 3: ntt_step v2 z 2 10. + have h_v2_2 : (v2.elements.val[2]!).val.natAbs ≤ 29439 := by + rw [h_v2_unc 2 hi2 (by decide) (by decide), + h_v1_unc 2 hi2 (by decide) (by decide)]; exact hvec 2 hi2 + have h_v2_10 : (v2.elements.val[10]!).val.natAbs ≤ 29439 := by + rw [h_v2_unc 10 hi10 (by decide) (by decide), + h_v1_unc 10 hi10 (by decide) (by decide)]; exact hvec 10 hi10 + obtain ⟨v3, h_v3_eq, h_v3_lift, h_v3_unc, _h_v3_i_bd, _h_v3_j_bd⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v2 z 2#usize 10#usize hi2 hi10 + (by decide) hz h_v2_2 h_v2_10) + -- Step 4: ntt_step v3 z 3 11. + have h_v3_3 : (v3.elements.val[3]!).val.natAbs ≤ 29439 := by + rw [h_v3_unc 3 hi3 (by decide) (by decide), + h_v2_unc 3 hi3 (by decide) (by decide), + h_v1_unc 3 hi3 (by decide) (by decide)]; exact hvec 3 hi3 + have h_v3_11 : (v3.elements.val[11]!).val.natAbs ≤ 29439 := by + rw [h_v3_unc 11 hi11 (by decide) (by decide), + h_v2_unc 11 hi11 (by decide) (by decide), + h_v1_unc 11 hi11 (by decide) (by decide)]; exact hvec 11 hi11 + obtain ⟨v4, h_v4_eq, h_v4_lift, h_v4_unc, _h_v4_i_bd, _h_v4_j_bd⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v3 z 3#usize 11#usize hi3 hi11 + (by decide) hz h_v3_3 h_v3_11) + -- Step 5: ntt_step v4 z 4 12. + have h_v4_4 : (v4.elements.val[4]!).val.natAbs ≤ 29439 := by + rw [h_v4_unc 4 hi4 (by decide) (by decide), + h_v3_unc 4 hi4 (by decide) (by decide), + h_v2_unc 4 hi4 (by decide) (by decide), + h_v1_unc 4 hi4 (by decide) (by decide)]; exact hvec 4 hi4 + have h_v4_12 : (v4.elements.val[12]!).val.natAbs ≤ 29439 := by + rw [h_v4_unc 12 hi12 (by decide) (by decide), + h_v3_unc 12 hi12 (by decide) (by decide), + h_v2_unc 12 hi12 (by decide) (by decide), + h_v1_unc 12 hi12 (by decide) (by decide)]; exact hvec 12 hi12 + obtain ⟨v5, h_v5_eq, h_v5_lift, h_v5_unc, _h_v5_i_bd, _h_v5_j_bd⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v4 z 4#usize 12#usize hi4 hi12 + (by decide) hz h_v4_4 h_v4_12) + -- Step 6: ntt_step v5 z 5 13. + have h_v5_5 : (v5.elements.val[5]!).val.natAbs ≤ 29439 := by + rw [h_v5_unc 5 hi5 (by decide) (by decide), + h_v4_unc 5 hi5 (by decide) (by decide), + h_v3_unc 5 hi5 (by decide) (by decide), + h_v2_unc 5 hi5 (by decide) (by decide), + h_v1_unc 5 hi5 (by decide) (by decide)]; exact hvec 5 hi5 + have h_v5_13 : (v5.elements.val[13]!).val.natAbs ≤ 29439 := by + rw [h_v5_unc 13 hi13 (by decide) (by decide), + h_v4_unc 13 hi13 (by decide) (by decide), + h_v3_unc 13 hi13 (by decide) (by decide), + h_v2_unc 13 hi13 (by decide) (by decide), + h_v1_unc 13 hi13 (by decide) (by decide)]; exact hvec 13 hi13 + obtain ⟨v6, h_v6_eq, h_v6_lift, h_v6_unc, _h_v6_i_bd, _h_v6_j_bd⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v5 z 5#usize 13#usize hi5 hi13 + (by decide) hz h_v5_5 h_v5_13) + -- Step 7: ntt_step v6 z 6 14. + have h_v6_6 : (v6.elements.val[6]!).val.natAbs ≤ 29439 := by + rw [h_v6_unc 6 hi6 (by decide) (by decide), + h_v5_unc 6 hi6 (by decide) (by decide), + h_v4_unc 6 hi6 (by decide) (by decide), + h_v3_unc 6 hi6 (by decide) (by decide), + h_v2_unc 6 hi6 (by decide) (by decide), + h_v1_unc 6 hi6 (by decide) (by decide)]; exact hvec 6 hi6 + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ 29439 := by + rw [h_v6_unc 14 hi14 (by decide) (by decide), + h_v5_unc 14 hi14 (by decide) (by decide), + h_v4_unc 14 hi14 (by decide) (by decide), + h_v3_unc 14 hi14 (by decide) (by decide), + h_v2_unc 14 hi14 (by decide) (by decide), + h_v1_unc 14 hi14 (by decide) (by decide)]; exact hvec 14 hi14 + obtain ⟨v7, h_v7_eq, h_v7_lift, h_v7_unc, _h_v7_i_bd, _h_v7_j_bd⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v6 z 6#usize 14#usize hi6 hi14 + (by decide) hz h_v6_6 h_v6_14) + -- Step 8: ntt_step v7 z 7 15. + have h_v7_7 : (v7.elements.val[7]!).val.natAbs ≤ 29439 := by + rw [h_v7_unc 7 hi7 (by decide) (by decide), + h_v6_unc 7 hi7 (by decide) (by decide), + h_v5_unc 7 hi7 (by decide) (by decide), + h_v4_unc 7 hi7 (by decide) (by decide), + h_v3_unc 7 hi7 (by decide) (by decide), + h_v2_unc 7 hi7 (by decide) (by decide), + h_v1_unc 7 hi7 (by decide) (by decide)]; exact hvec 7 hi7 + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ 29439 := by + rw [h_v7_unc 15 hi15 (by decide) (by decide), + h_v6_unc 15 hi15 (by decide) (by decide), + h_v5_unc 15 hi15 (by decide) (by decide), + h_v4_unc 15 hi15 (by decide) (by decide), + h_v3_unc 15 hi15 (by decide) (by decide), + h_v2_unc 15 hi15 (by decide) (by decide), + h_v1_unc 15 hi15 (by decide) (by decide)]; exact hvec 15 hi15 + obtain ⟨v8, h_v8_eq, h_v8_lift, _h_v8_unc, _h_v8_i_bd, _h_v8_j_bd⟩ := + triple_exists_ok_fc (ntt_step_pair_fc v7 z 7#usize 15#usize hi7 hi15 + (by decide) hz h_v7_7 h_v7_15) + -- Compose into a single `.ok v8` for the layer body. + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_3_step vec z = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.ntt_layer_3_step + rw [h_v1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v3_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v4_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v5_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v6_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v7_eq]; simp only [Aeneas.Std.bind_tc_ok] + exact h_v8_eq + apply triple_of_ok_fc h_body + -- Chain the 8 lift equations into the spec composition. + unfold Spec.chunk_ntt_layer_3_step_pure + rw [h_v8_lift, h_v7_lift, h_v6_lift, h_v5_lift, h_v4_lift, h_v3_lift, h_v2_lift, h_v1_lift] + +/-- L2.5 — `inv_ntt_step`: per-pair inverse butterfly. + + **Preconditions beyond locked statement** (precondition adjustment): + - `hne : i.val ≠ j.val` — without this the impl's two writes + (`vec[i] := o0` then `vec[j] := o1`) at the same index yield `o1` + while the spec's `(a.set i new_i).set j new_j` with `i = j` also + yields `new_j`, but the lift-level proof bifurcates messily. Real + callers (inv_ntt_layer_{1,2,3}_step) all use distinct `i, j`. + - `hvec : ∀ k < 16, |vec[k]| ≤ 13312` (= 4·3328) — needed so that + `wrapping_add (vec[j], vec[i])` and `wrapping_sub (vec[j], vec[i])` + don't overflow at the I16 level. Since `|vec[j]| + |vec[i]| ≤ + 26624 < 32768`, both ops have `.val = b + a` and `b - a` exactly. + This mirrors the legacy `inv_ntt_step_spec_B` with `B = 4`. -/ +@[spec] +theorem inv_ntt_step_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta : Std.I16) (i j : Std.Usize) + (hi : i.val < 16) (hj : j.val < 16) + (hne : i.val ≠ j.val) + (hzeta : zeta.val.natAbs ≤ 1664) + (hvec : ∀ k : Nat, k < 16 → + (vec.elements.val[k]!).val.natAbs ≤ 13312) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_step vec zeta i j + ⦃ ⇓ r => ⌜ lift_chunk r + = Spec.chunk_inv_ntt_step_pure (lift_chunk vec) (lift_fe_mont zeta) i j ⌝ ⦄ := by + -- Step 0: vector length facts. + have h_vec_len : vec.elements.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length vec + have h_vec_val_len : vec.elements.val.length = 16 := h_vec_len + -- Step 1: read vec[j] (= i1 in impl, called "b"). + have h_idx_j : + Aeneas.Std.Array.index_usize vec.elements j = .ok (vec.elements.val[j.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq vec.elements j + (by rw [h_vec_len]; exact hj) + -- Step 2: read vec[i] (= i2 in impl, called "a"). + have h_idx_i : + Aeneas.Std.Array.index_usize vec.elements i = .ok (vec.elements.val[i.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq vec.elements i + (by rw [h_vec_len]; exact hi) + set a : Std.I16 := vec.elements.val[i.val]! with ha_def + set b : Std.I16 := vec.elements.val[j.val]! with hb_def + have h_a_bnd : a.val.natAbs ≤ 13312 := hvec i.val hi + have h_b_bnd : b.val.natAbs ≤ 13312 := hvec j.val hj + -- Step 3,4: wrapping_sub b a and wrapping_add b a. + have h_sub_eq : + CoreModels.core.num.I16.wrapping_sub b a = .ok (Std.I16.wrapping_sub b a) := + ntt_step_fc.cm_wrapping_sub_ok_eq b a + have h_add_eq : + CoreModels.core.num.I16.wrapping_add b a = .ok (Std.I16.wrapping_add b a) := + ntt_step_fc.cm_wrapping_add_ok_eq b a + set a_minus_b : Std.I16 := Std.I16.wrapping_sub b a with hamb_def + set a_plus_b : Std.I16 := Std.I16.wrapping_add b a with hapb_def + -- No-overflow for wrapping_add b a: |b.val + a.val| ≤ 2·13312 = 26624 < 32768. + have h_apb_val : a_plus_b.val = b.val + a.val := by + have h_sum_abs : ((b.val + a.val : Int)).natAbs ≤ 26624 := by + have h_tri : (b.val + a.val).natAbs ≤ b.val.natAbs + a.val.natAbs := + Int.natAbs_add_le _ _ + omega + have h_lb : -(2 ^ 15 : Int) ≤ b.val + a.val := by omega + have h_ub : b.val + a.val < (2 ^ 15 : Int) := by omega + have h_bmod : Int.bmod (b.val + a.val) (2 ^ 16) = b.val + a.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + have h_val := Std.I16.wrapping_add_val_eq b a + rw [hapb_def, h_val, h_bmod] + have h_amb_val : a_minus_b.val = b.val - a.val := by + have h_diff_abs : ((b.val - a.val : Int)).natAbs ≤ 26624 := by + have h_neg_natAbs : (-a.val).natAbs = a.val.natAbs := Int.natAbs_neg _ + have h_eq : b.val - a.val = b.val + (-a.val) := by ring + rw [h_eq] + have h_tri : (b.val + (-a.val)).natAbs ≤ b.val.natAbs + (-a.val).natAbs := + Int.natAbs_add_le _ _ + rw [h_neg_natAbs] at h_tri + omega + have h_lb : -(2 ^ 15 : Int) ≤ b.val - a.val := by omega + have h_ub : b.val - a.val < (2 ^ 15 : Int) := by omega + have h_bmod : Int.bmod (b.val - a.val) (2 ^ 16) = b.val - a.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + have h_val := Std.I16.wrapping_sub_val_eq b a + rw [hamb_def, h_val, h_bmod] + -- Bound on a_plus_b for L0.2 (≤ 26624 ≤ 32767). + have h_apb_bd : a_plus_b.val.natAbs ≤ 32767 := by + rw [h_apb_val] + have h_tri : (b.val + a.val).natAbs ≤ b.val.natAbs + a.val.natAbs := + Int.natAbs_add_le _ _ + omega + -- Step 5: L0.2 barrett_reduce_element on a_plus_b. + obtain ⟨o0, h_o0_eq_ok, h_o0_bd, h_o0_lift⟩ := + triple_exists_ok_fc (barrett_reduce_element_fc a_plus_b h_apb_bd) + -- Recover modq form via legacy (needed since L0.2-FC delivers `lift_fe o0 = + -- barrett_pure (lift_fe a_plus_b)` but we need `lift_fe o0 = add_pure (lift_fe b) + -- (lift_fe a)`; the bridge needs the modq equation on `.val`s). + obtain ⟨o0', h_o0'_eq, h_o0'_modq, _h_o0'_bd⟩ := + triple_exists_ok_fc + (libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement.barrett_reduce_element_spec a_plus_b h_apb_bd) + have h_oo' : o0 = o0' := by + have : (Result.ok o0 : Result _) = Result.ok o0' := by + rw [← h_o0_eq_ok, h_o0'_eq] + cases this; rfl + -- Step 6: classify zeta = zeta. + have h_classify : libcrux_secrets.traits.Classify.Blanket.classify zeta = .ok zeta := + ntt_step_fc.classify_ok_eq zeta + -- Step 7: L0.4 montgomery_multiply on (a_minus_b, zeta). + obtain ⟨o1, h_o1_eq_ok, h_o1_bd, h_o1_lift⟩ := + triple_exists_ok_fc (montgomery_multiply_fe_by_fer_fc a_minus_b zeta + (by have := a_minus_b.hBounds; omega) hzeta) + obtain ⟨o1', h_o1'_eq, h_o1'_bd_tight, h_o1'_modq⟩ := + triple_exists_ok_fc + (libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement.montgomery_multiply_fe_by_fer_spec a_minus_b zeta hzeta) + have h_oo1' : o1 = o1' := by + have : (Result.ok o1 : Result _) = Result.ok o1' := by + rw [← h_o1_eq_ok, h_o1'_eq] + cases this; rfl + -- Step 8: write vec[i] := o0. + have h_upd_i : + Aeneas.Std.Array.update vec.elements i o0 + = .ok (vec.elements.set i o0) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_update_ok_eq vec.elements i o0 + (by rw [h_vec_len]; exact hi) + -- Step 9: write vec[j] := o1. + have h_upd_j : + Aeneas.Std.Array.update (vec.elements.set i o0) j o1 + = .ok ((vec.elements.set i o0).set j o1) := by + have h_len : (vec.elements.set i o0).length = 16 := by + rw [Std.Array.set_length]; exact h_vec_len + exact libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_update_ok_eq _ j o1 + (by rw [h_len]; exact hj) + -- Compose into `.ok final_vec`. + set final_elements : Std.Array Std.I16 16#usize := + (vec.elements.set i o0).set j o1 with hfe_def + set final_vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + { elements := final_elements } with hfv_def + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_step vec zeta i j + = .ok final_vec := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_step + rw [h_idx_j]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_i]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_sub_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_add_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [← h_oo'] at h_o0'_eq + rw [h_o0_eq_ok]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_classify]; simp only [Aeneas.Std.bind_tc_ok] + rw [← h_oo1'] at h_o1'_eq + rw [h_o1_eq_ok]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_upd_i]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_upd_j]; simp only [Aeneas.Std.bind_tc_ok]; rfl + apply triple_of_ok_fc h_body + -- Now: prove the FC chunk equation. + -- spec new_i := add_pure (lift_fe b) (lift_fe a) + -- spec new_j := mul_pure (sub_pure (lift_fe b) (lift_fe a)) (lift_fe_mont zeta) + set s_new_i : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe b) (lift_fe a) with hs_new_i_def + set s_diff : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (lift_fe b) (lift_fe a) with hs_diff_def + set s_new_j : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + s_diff (lift_fe_mont zeta) with hs_new_j_def + unfold lift_chunk Spec.chunk_inv_ntt_step_pure + apply Subtype.ext + simp only [Std.Array.set_val_eq] + -- Bridge: (vec.elements.val.map lift_fe)[k]! = lift_fe (vec.elements.val[k]!) when k < 16. + have h_map_lift_at (k : Nat) (hk : k < 16) : + (vec.elements.val.map lift_fe)[k]! = lift_fe (vec.elements.val[k]!) := by + have hk_lhs : k < (vec.elements.val.map lift_fe).length := by + simp [List.length_map, h_vec_val_len]; exact hk + rw [getElem!_pos (vec.elements.val.map lift_fe) k hk_lhs] + rw [List.getElem_map] + have hk_vec : k < vec.elements.val.length := by rw [h_vec_val_len]; exact hk + rw [getElem!_pos vec.elements.val k hk_vec] + show ((vec.elements.val.set i.val o0).set j.val o1).map lift_fe + = ((vec.elements.val.map lift_fe).set i.val + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((vec.elements.val.map lift_fe)[j.val]!) + ((vec.elements.val.map lift_fe)[i.val]!))).set j.val + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + ((vec.elements.val.map lift_fe)[j.val]!) + ((vec.elements.val.map lift_fe)[i.val]!)) + (lift_fe_mont zeta)) + rw [h_map_lift_at i.val hi, h_map_lift_at j.val hj] + change ((vec.elements.val.set i.val o0).set j.val o1).map lift_fe + = ((vec.elements.val.map lift_fe).set i.val s_new_i).set j.val s_new_j + apply List.ext_getElem + · simp [List.length_map, List.length_set] + · intro k hk1 hk2 + have hk : k < 16 := by + have hk' : k < (((vec.elements.val.set i.val o0).set j.val o1).map lift_fe).length := hk1 + simp [List.length_map, List.length_set, h_vec_val_len] at hk' + exact hk' + rw [List.getElem_map] + by_cases h_eq_j : k = j.val + · -- k = j.val: r[j] = o1 = mont_mul(b-a, zeta). + subst h_eq_j + rw [List.getElem_set_self] + rw [List.getElem_set_self] + show lift_fe o1 = s_new_j + -- mont_mul a_minus_b zeta produced o1. We have h_o1'_modq: + -- modq_eq o1'.val (a_minus_b.val * zeta.val * 169) 3329. + -- lift_fe o1 = lift_fe o1' (h_oo1') = mul_pure (lift_fe a_minus_b) (lift_fe_mont zeta). + have h_step1 : + lift_fe o1 = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe a_minus_b) (lift_fe_mont zeta) := by + rw [h_oo1'] + exact lift_fe_mul_pure_mont_eq a_minus_b zeta o1' h_o1'_modq + rw [h_step1] + -- Now: mul_pure (lift_fe a_minus_b) (lift_fe_mont zeta) = s_new_j + -- = mul_pure s_diff (lift_fe_mont zeta) where s_diff = sub_pure (lift_fe b) (lift_fe a). + -- Reduce by congr 1 to: lift_fe a_minus_b = s_diff. + simp only [hs_new_j_def] + congr 1 + -- lift_fe a_minus_b = sub_pure (lift_fe b) (lift_fe a). + exact lift_fe_sub_pure_eq b a a_minus_b h_amb_val + · rw [List.getElem_set_ne (Ne.symm h_eq_j)] + rw [List.getElem_set_ne (Ne.symm h_eq_j)] + by_cases h_eq_i : k = i.val + · -- k = i.val: r[i] = o0 = barrett(b+a). + subst h_eq_i + rw [List.getElem_set_self] + rw [List.getElem_set_self] + show lift_fe o0 = s_new_i + -- lift_fe o0 = lift_fe o0' (h_oo') from h_o0'_modq: + -- modq_eq o0'.val a_plus_b.val 3329. + -- Then lift_fe o0' = lift_fe a_plus_b = add_pure (lift_fe b) (lift_fe a). + have h_step1 : lift_fe o0 = lift_fe a_plus_b := by + rw [h_oo'] + exact lift_fe_eq_of_modq o0' a_plus_b h_o0'_modq + rw [h_step1] + -- lift_fe a_plus_b = add_pure (lift_fe b) (lift_fe a) via h_apb_val. + simp only [hs_new_i_def] + exact lift_fe_add_pure_eq b a a_plus_b h_apb_val + · -- k ≠ i.val, k ≠ j.val. + rw [List.getElem_set_ne (Ne.symm h_eq_i)] + rw [List.getElem_set_ne (Ne.symm h_eq_i)] + rw [List.getElem_map] + +/-! ### L2.9 — `inv_ntt_layer_1_step` (FC). + + layer-1 vector-level step. Mirrors `ntt_layer_1_step_fc` on the same lane-pair sequence + `(0,2)(1,3)(4,6)(5,7)(8,10)(9,11)(12,14)(13,15)` with zetas + `z0,z0,z1,z1,z2,z2,z3,z3` — only the butterfly direction differs. + Chains 8 `inv_ntt_step` calls via the private `inv_ntt_step_pair_fc` + helper (mirror of `ntt_step_pair_fc`) which exposes + both the `lift_chunk` equation AND the unchanged-lane preservation, + plus the tight per-output bound `≤ 3328` so the bound is preserved + across the 8-step chain on disjoint lane pairs. -/ + +/-- Per-lane variant of `inv_ntt_step_fc` for layer composition. Splits + the universal precondition into per-lane bounds on `i` and `j` (only + the two lanes actually read), and exposes: + 1. The `lift_chunk` equation (the spec-bridge). + 2. **Unchanged-lane preservation**: `r[k] = vec[k]` for `k ≠ i, j`. + 3. **Tight per-output bound `≤ 3328`** at both `i` and `j`. `r[i]` is + `barrett(vec[j] + vec[i])` (post-barrett bound is `≤ 3328`); `r[j]` + is `montgomery_multiply(vec[j] - vec[i], zeta)` whose Equivalence + spec also gives the tight `≤ 3328` bound. -/ +theorem inv_ntt_step_pair_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (zeta : Std.I16) (i j : Std.Usize) + (hi : i.val < 16) (hj : j.val < 16) + (hne : i.val ≠ j.val) + (hzeta : zeta.val.natAbs ≤ 1664) + (h_a_bnd : (vec.elements.val[i.val]!).val.natAbs ≤ 13312) + (h_b_bnd : (vec.elements.val[j.val]!).val.natAbs ≤ 13312) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_step vec zeta i j + ⦃ ⇓ r => ⌜ lift_chunk r + = Spec.chunk_inv_ntt_step_pure (lift_chunk vec) (lift_fe_mont zeta) i j + ∧ (∀ k : Nat, k < 16 → k ≠ i.val → k ≠ j.val → + (r.elements.val[k]!) = (vec.elements.val[k]!)) + ∧ (r.elements.val[i.val]!).val.natAbs ≤ 3328 + ∧ (r.elements.val[j.val]!).val.natAbs ≤ 3328 ⌝ ⦄ := by + -- Step 0: vector length facts. + have h_vec_len : vec.elements.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length vec + have h_vec_val_len : vec.elements.val.length = 16 := h_vec_len + -- Step 1: read vec[j] (= i1 in impl, called "b"). + have h_idx_j : + Aeneas.Std.Array.index_usize vec.elements j = .ok (vec.elements.val[j.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq vec.elements j + (by rw [h_vec_len]; exact hj) + -- Step 2: read vec[i] (= i2 in impl, called "a"). + have h_idx_i : + Aeneas.Std.Array.index_usize vec.elements i = .ok (vec.elements.val[i.val]!) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_index_usize_ok_eq vec.elements i + (by rw [h_vec_len]; exact hi) + set a : Std.I16 := vec.elements.val[i.val]! with ha_def + set b : Std.I16 := vec.elements.val[j.val]! with hb_def + -- Step 3,4: wrapping_sub b a and wrapping_add b a. + have h_sub_eq : + CoreModels.core.num.I16.wrapping_sub b a = .ok (Std.I16.wrapping_sub b a) := + ntt_step_fc.cm_wrapping_sub_ok_eq b a + have h_add_eq : + CoreModels.core.num.I16.wrapping_add b a = .ok (Std.I16.wrapping_add b a) := + ntt_step_fc.cm_wrapping_add_ok_eq b a + set a_minus_b : Std.I16 := Std.I16.wrapping_sub b a with hamb_def + set a_plus_b : Std.I16 := Std.I16.wrapping_add b a with hapb_def + -- No-overflow for wrapping_add b a: |b.val + a.val| ≤ 2·13312 = 26624 < 32768. + have h_apb_val : a_plus_b.val = b.val + a.val := by + have h_sum_abs : ((b.val + a.val : Int)).natAbs ≤ 26624 := by + have h_tri : (b.val + a.val).natAbs ≤ b.val.natAbs + a.val.natAbs := + Int.natAbs_add_le _ _ + omega + have h_lb : -(2 ^ 15 : Int) ≤ b.val + a.val := by omega + have h_ub : b.val + a.val < (2 ^ 15 : Int) := by omega + have h_bmod : Int.bmod (b.val + a.val) (2 ^ 16) = b.val + a.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + have h_val := Std.I16.wrapping_add_val_eq b a + rw [hapb_def, h_val, h_bmod] + have h_amb_val : a_minus_b.val = b.val - a.val := by + have h_diff_abs : ((b.val - a.val : Int)).natAbs ≤ 26624 := by + have h_neg_natAbs : (-a.val).natAbs = a.val.natAbs := Int.natAbs_neg _ + have h_eq : b.val - a.val = b.val + (-a.val) := by ring + rw [h_eq] + have h_tri : (b.val + (-a.val)).natAbs ≤ b.val.natAbs + (-a.val).natAbs := + Int.natAbs_add_le _ _ + rw [h_neg_natAbs] at h_tri + omega + have h_lb : -(2 ^ 15 : Int) ≤ b.val - a.val := by omega + have h_ub : b.val - a.val < (2 ^ 15 : Int) := by omega + have h_bmod : Int.bmod (b.val - a.val) (2 ^ 16) = b.val - a.val := by + apply Aeneas.Arith.Int.bmod_pow2_eq_of_inBounds' 16 _ (by decide) + · have h_const : -((2 : Int) ^ (16 - 1)) ≤ -(2 ^ 15 : Int) := by decide + exact le_trans h_const h_lb + · have h_const : (2 ^ 15 : Int) ≤ (2 : Int) ^ (16 - 1) := by decide + exact lt_of_lt_of_le h_ub h_const + have h_val := Std.I16.wrapping_sub_val_eq b a + rw [hamb_def, h_val, h_bmod] + -- Bound on a_plus_b for L0.2 (≤ 26624 ≤ 32767). + have h_apb_bd : a_plus_b.val.natAbs ≤ 32767 := by + rw [h_apb_val] + have h_tri : (b.val + a.val).natAbs ≤ b.val.natAbs + a.val.natAbs := + Int.natAbs_add_le _ _ + omega + -- Step 5: L0.2 barrett_reduce_element on a_plus_b. Bound: |o0| ≤ 3328. + obtain ⟨o0, h_o0_eq_ok, h_o0_bd, _h_o0_lift⟩ := + triple_exists_ok_fc (barrett_reduce_element_fc a_plus_b h_apb_bd) + obtain ⟨o0', h_o0'_eq, h_o0'_modq, _h_o0'_bd⟩ := + triple_exists_ok_fc + (libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement.barrett_reduce_element_spec a_plus_b h_apb_bd) + have h_oo' : o0 = o0' := by + have : (Result.ok o0 : Result _) = Result.ok o0' := by + rw [← h_o0_eq_ok, h_o0'_eq] + cases this; rfl + -- Step 6: classify zeta = zeta. + have h_classify : libcrux_secrets.traits.Classify.Blanket.classify zeta = .ok zeta := + ntt_step_fc.classify_ok_eq zeta + -- Step 7: L0.4 montgomery_multiply on (a_minus_b, zeta). Bound: |o1| ≤ 3328+1665 = 4993. + obtain ⟨o1, h_o1_eq_ok, h_o1_bd, _h_o1_lift⟩ := + triple_exists_ok_fc (montgomery_multiply_fe_by_fer_fc a_minus_b zeta + (by have := a_minus_b.hBounds; omega) hzeta) + obtain ⟨o1', h_o1'_eq, h_o1'_bd_tight, h_o1'_modq⟩ := + triple_exists_ok_fc + (libcrux_iot_ml_kem.Vector.Portable.Arithmetic.PerElement.montgomery_multiply_fe_by_fer_spec a_minus_b zeta hzeta) + have h_oo1' : o1 = o1' := by + have : (Result.ok o1 : Result _) = Result.ok o1' := by + rw [← h_o1_eq_ok, h_o1'_eq] + cases this; rfl + -- Step 8: write vec[i] := o0. + have h_upd_i : + Aeneas.Std.Array.update vec.elements i o0 + = .ok (vec.elements.set i o0) := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_update_ok_eq vec.elements i o0 + (by rw [h_vec_len]; exact hi) + -- Step 9: write vec[j] := o1. + have h_upd_j : + Aeneas.Std.Array.update (vec.elements.set i o0) j o1 + = .ok ((vec.elements.set i o0).set j o1) := by + have h_len : (vec.elements.set i o0).length = 16 := by + rw [Std.Array.set_length]; exact h_vec_len + exact libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.array_update_ok_eq _ j o1 + (by rw [h_len]; exact hj) + -- Compose into `.ok final_vec`. + set final_elements : Std.Array Std.I16 16#usize := + (vec.elements.set i o0).set j o1 with hfe_def + set final_vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector := + { elements := final_elements } with hfv_def + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_step vec zeta i j + = .ok final_vec := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_step + rw [h_idx_j]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_idx_i]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_sub_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_add_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [← h_oo'] at h_o0'_eq + rw [h_o0_eq_ok]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_classify]; simp only [Aeneas.Std.bind_tc_ok] + rw [← h_oo1'] at h_o1'_eq + rw [h_o1_eq_ok]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_upd_i]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_upd_j]; simp only [Aeneas.Std.bind_tc_ok]; rfl + apply triple_of_ok_fc h_body + -- Helper: bridge for `(vec.elements.val.map lift_fe)[k]! = lift_fe (vec.elements.val[k]!)`. + have h_map_lift_at (k : Nat) (hk : k < 16) : + (vec.elements.val.map lift_fe)[k]! = lift_fe (vec.elements.val[k]!) := by + have hk_lhs : k < (vec.elements.val.map lift_fe).length := by + simp [List.length_map, h_vec_val_len]; exact hk + rw [getElem!_pos (vec.elements.val.map lift_fe) k hk_lhs] + rw [List.getElem_map] + have hk_vec : k < vec.elements.val.length := by rw [h_vec_val_len]; exact hk + rw [getElem!_pos vec.elements.val k hk_vec] + -- Now: 4 conjuncts. + refine ⟨?_, ?_, ?_, ?_⟩ + · -- lift_chunk equation: same as keystone proof. + set s_new_i : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + (lift_fe b) (lift_fe a) with hs_new_i_def + set s_diff : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + (lift_fe b) (lift_fe a) with hs_diff_def + set s_new_j : hacspec_ml_kem.parameters.FieldElement := + libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + s_diff (lift_fe_mont zeta) with hs_new_j_def + unfold lift_chunk Spec.chunk_inv_ntt_step_pure + apply Subtype.ext + simp only [Std.Array.set_val_eq] + show ((vec.elements.val.set i.val o0).set j.val o1).map lift_fe + = ((vec.elements.val.map lift_fe).set i.val + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.add_pure + ((vec.elements.val.map lift_fe)[j.val]!) + ((vec.elements.val.map lift_fe)[i.val]!))).set j.val + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (libcrux_iot_ml_kem.Spec.Pure.FieldElement.sub_pure + ((vec.elements.val.map lift_fe)[j.val]!) + ((vec.elements.val.map lift_fe)[i.val]!)) + (lift_fe_mont zeta)) + rw [h_map_lift_at i.val hi, h_map_lift_at j.val hj] + change ((vec.elements.val.set i.val o0).set j.val o1).map lift_fe + = ((vec.elements.val.map lift_fe).set i.val s_new_i).set j.val s_new_j + apply List.ext_getElem + · simp [List.length_map, List.length_set] + · intro k hk1 hk2 + have hk : k < 16 := by + have hk' : k < (((vec.elements.val.set i.val o0).set j.val o1).map lift_fe).length := hk1 + simp [List.length_map, List.length_set, h_vec_val_len] at hk' + exact hk' + rw [List.getElem_map] + by_cases h_eq_j : k = j.val + · subst h_eq_j + rw [List.getElem_set_self] + rw [List.getElem_set_self] + show lift_fe o1 = s_new_j + have h_step1 : + lift_fe o1 = libcrux_iot_ml_kem.Spec.Pure.FieldElement.mul_pure + (lift_fe a_minus_b) (lift_fe_mont zeta) := by + rw [h_oo1'] + exact lift_fe_mul_pure_mont_eq a_minus_b zeta o1' h_o1'_modq + rw [h_step1] + simp only [hs_new_j_def] + congr 1 + exact lift_fe_sub_pure_eq b a a_minus_b h_amb_val + · rw [List.getElem_set_ne (Ne.symm h_eq_j)] + rw [List.getElem_set_ne (Ne.symm h_eq_j)] + by_cases h_eq_i : k = i.val + · subst h_eq_i + rw [List.getElem_set_self] + rw [List.getElem_set_self] + show lift_fe o0 = s_new_i + have h_step1 : lift_fe o0 = lift_fe a_plus_b := by + rw [h_oo'] + exact lift_fe_eq_of_modq o0' a_plus_b h_o0'_modq + rw [h_step1] + simp only [hs_new_i_def] + exact lift_fe_add_pure_eq b a a_plus_b h_apb_val + · rw [List.getElem_set_ne (Ne.symm h_eq_i)] + rw [List.getElem_set_ne (Ne.symm h_eq_i)] + rw [List.getElem_map] + · -- Untouched-lane preservation: r[k] = vec[k] for k ≠ i, j. + -- final_vec.elements.val = (vec.elements.set i o0).set j o1 .val + intro k hk hki hkj + show ((vec.elements.set i o0).set j o1).val[k]! + = vec.elements.val[k]! + have h_set_val_eq : ((vec.elements.set i o0).set j o1).val + = (vec.elements.val.set i.val o0).set j.val o1 := by + simp [Std.Array.set_val_eq] + rw [h_set_val_eq] + rw [getElem!_pos _ k (by simp [List.length_set, h_vec_val_len]; exact hk)] + rw [List.getElem_set_ne (Ne.symm hkj)] + rw [List.getElem_set_ne (Ne.symm hki)] + rw [getElem!_pos vec.elements.val k (by rw [h_vec_val_len]; exact hk)] + · -- Bound at i: r[i] = o0 (set last) — since i ≠ j (hne), second set doesn't touch i, + -- so r[i] = (set i o0)[i] = o0. |o0| ≤ 3328 (tight, from barrett_reduce). + show ((vec.elements.set i o0).set j o1).val[i.val]!.val.natAbs ≤ 3328 + have h_set_val_eq : ((vec.elements.set i o0).set j o1).val + = (vec.elements.val.set i.val o0).set j.val o1 := by + simp [Std.Array.set_val_eq] + rw [h_set_val_eq] + rw [getElem!_pos _ i.val (by simp [List.length_set, h_vec_val_len]; exact hi)] + rw [List.getElem_set_ne (Ne.symm hne)] + rw [List.getElem_set_self] + -- now goal: o0.val.natAbs ≤ 3328. h_o0_bd : o0.val.natAbs ≤ 3328. + exact h_o0_bd + · -- Bound at j: r[j] = o1 (second set wins). |o1| ≤ 3328 (tight, via + -- `montgomery_multiply_fe_by_fer_spec`). + show ((vec.elements.set i o0).set j o1).val[j.val]!.val.natAbs ≤ 3328 + have h_set_val_eq : ((vec.elements.set i o0).set j o1).val + = (vec.elements.val.set i.val o0).set j.val o1 := by + simp [Std.Array.set_val_eq] + rw [h_set_val_eq] + rw [getElem!_pos _ j.val (by simp [List.length_set, h_vec_val_len]; exact hj)] + rw [List.getElem_set_self] + -- goal: o1.val.natAbs ≤ 3328. h_o1'_bd_tight : o1'.val.natAbs ≤ 3328 with o1 = o1'. + rw [h_oo1']; exact h_o1'_bd_tight + +/-- L2.9 — `inv_ntt_layer_1_step`: vector-level layer-1 inverse step. + Maps `lift_chunk` of the impl output to `Spec.chunk_inv_ntt_layer_1_step_pure` + applied to `lift_chunk` of the input and the canonical-domain zetas. + + **Precondition adjustment** (beyond locked statement): + - `hz : |z_k| ≤ 1664` for each zeta — Mont-domain zeta from `polynomial.zeta`. + - `hvec : ∀ k < 16, |vec[k]| ≤ 13312` — preserved across 8 sequential + `inv_ntt_step` invocations on disjoint pairs (each lane after a step is + either ≤ 3328 from barrett, ≤ 4993 from mont-mul, or unchanged ≤ 13312). -/ +@[spec] +theorem inv_ntt_layer_1_step_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (z0 z1 z2 z3 : Std.I16) + (hz : z0.val.natAbs ≤ 1664 ∧ z1.val.natAbs ≤ 1664 + ∧ z2.val.natAbs ≤ 1664 ∧ z3.val.natAbs ≤ 1664) + (hvec : ∀ k : Nat, k < 16 → + (vec.elements.val[k]!).val.natAbs ≤ 13312) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_1_step vec z0 z1 z2 z3 + ⦃ ⇓ r => ⌜ lift_chunk r + = Spec.chunk_inv_ntt_layer_1_step_pure (lift_chunk vec) + (lift_fe_mont z0) (lift_fe_mont z1) + (lift_fe_mont z2) (lift_fe_mont z3) + ∧ (∀ k : Nat, k < 16 → + (r.elements.val[k]!).val.natAbs ≤ 3328) ⌝ ⦄ := by + obtain ⟨hz0, hz1, hz2, hz3⟩ := hz + have hi0 : (0 : Nat) < 16 := by decide + have hi1 : (1 : Nat) < 16 := by decide + have hi2 : (2 : Nat) < 16 := by decide + have hi3 : (3 : Nat) < 16 := by decide + have hi4 : (4 : Nat) < 16 := by decide + have hi5 : (5 : Nat) < 16 := by decide + have hi6 : (6 : Nat) < 16 := by decide + have hi7 : (7 : Nat) < 16 := by decide + have hi8 : (8 : Nat) < 16 := by decide + have hi9 : (9 : Nat) < 16 := by decide + have hi10 : (10 : Nat) < 16 := by decide + have hi11 : (11 : Nat) < 16 := by decide + have hi12 : (12 : Nat) < 16 := by decide + have hi13 : (13 : Nat) < 16 := by decide + have hi14 : (14 : Nat) < 16 := by decide + have hi15 : (15 : Nat) < 16 := by decide + -- Step 1: inv_ntt_step vec z0 0 2. + obtain ⟨v1, h_v1_eq, h_v1_lift, h_v1_unc, h_v1_bnd_0, h_v1_bnd_2⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc vec z0 0#usize 2#usize hi0 hi2 + (by decide) hz0 (hvec 0 hi0) (hvec 2 hi2)) + -- Step 2: inv_ntt_step v1 z0 1 3. + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ 13312 := by + rw [h_v1_unc 1 hi1 (by decide) (by decide)]; exact hvec 1 hi1 + have h_v1_3 : (v1.elements.val[3]!).val.natAbs ≤ 13312 := by + rw [h_v1_unc 3 hi3 (by decide) (by decide)]; exact hvec 3 hi3 + obtain ⟨v2, h_v2_eq, h_v2_lift, h_v2_unc, h_v2_bnd_1, h_v2_bnd_3⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v1 z0 1#usize 3#usize hi1 hi3 + (by decide) hz0 h_v1_1 h_v1_3) + -- Step 3: inv_ntt_step v2 z1 4 6. + have h_v2_4 : (v2.elements.val[4]!).val.natAbs ≤ 13312 := by + rw [h_v2_unc 4 hi4 (by decide) (by decide), + h_v1_unc 4 hi4 (by decide) (by decide)]; exact hvec 4 hi4 + have h_v2_6 : (v2.elements.val[6]!).val.natAbs ≤ 13312 := by + rw [h_v2_unc 6 hi6 (by decide) (by decide), + h_v1_unc 6 hi6 (by decide) (by decide)]; exact hvec 6 hi6 + obtain ⟨v3, h_v3_eq, h_v3_lift, h_v3_unc, h_v3_bnd_4, h_v3_bnd_6⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v2 z1 4#usize 6#usize hi4 hi6 + (by decide) hz1 h_v2_4 h_v2_6) + -- Step 4: inv_ntt_step v3 z1 5 7. + have h_v3_5 : (v3.elements.val[5]!).val.natAbs ≤ 13312 := by + rw [h_v3_unc 5 hi5 (by decide) (by decide), + h_v2_unc 5 hi5 (by decide) (by decide), + h_v1_unc 5 hi5 (by decide) (by decide)]; exact hvec 5 hi5 + have h_v3_7 : (v3.elements.val[7]!).val.natAbs ≤ 13312 := by + rw [h_v3_unc 7 hi7 (by decide) (by decide), + h_v2_unc 7 hi7 (by decide) (by decide), + h_v1_unc 7 hi7 (by decide) (by decide)]; exact hvec 7 hi7 + obtain ⟨v4, h_v4_eq, h_v4_lift, h_v4_unc, h_v4_bnd_5, h_v4_bnd_7⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v3 z1 5#usize 7#usize hi5 hi7 + (by decide) hz1 h_v3_5 h_v3_7) + -- Step 5: inv_ntt_step v4 z2 8 10. + have h_v4_8 : (v4.elements.val[8]!).val.natAbs ≤ 13312 := by + rw [h_v4_unc 8 hi8 (by decide) (by decide), + h_v3_unc 8 hi8 (by decide) (by decide), + h_v2_unc 8 hi8 (by decide) (by decide), + h_v1_unc 8 hi8 (by decide) (by decide)]; exact hvec 8 hi8 + have h_v4_10 : (v4.elements.val[10]!).val.natAbs ≤ 13312 := by + rw [h_v4_unc 10 hi10 (by decide) (by decide), + h_v3_unc 10 hi10 (by decide) (by decide), + h_v2_unc 10 hi10 (by decide) (by decide), + h_v1_unc 10 hi10 (by decide) (by decide)]; exact hvec 10 hi10 + obtain ⟨v5, h_v5_eq, h_v5_lift, h_v5_unc, h_v5_bnd_8, h_v5_bnd_10⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v4 z2 8#usize 10#usize hi8 hi10 + (by decide) hz2 h_v4_8 h_v4_10) + -- Step 6: inv_ntt_step v5 z2 9 11. + have h_v5_9 : (v5.elements.val[9]!).val.natAbs ≤ 13312 := by + rw [h_v5_unc 9 hi9 (by decide) (by decide), + h_v4_unc 9 hi9 (by decide) (by decide), + h_v3_unc 9 hi9 (by decide) (by decide), + h_v2_unc 9 hi9 (by decide) (by decide), + h_v1_unc 9 hi9 (by decide) (by decide)]; exact hvec 9 hi9 + have h_v5_11 : (v5.elements.val[11]!).val.natAbs ≤ 13312 := by + rw [h_v5_unc 11 hi11 (by decide) (by decide), + h_v4_unc 11 hi11 (by decide) (by decide), + h_v3_unc 11 hi11 (by decide) (by decide), + h_v2_unc 11 hi11 (by decide) (by decide), + h_v1_unc 11 hi11 (by decide) (by decide)]; exact hvec 11 hi11 + obtain ⟨v6, h_v6_eq, h_v6_lift, h_v6_unc, h_v6_bnd_9, h_v6_bnd_11⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v5 z2 9#usize 11#usize hi9 hi11 + (by decide) hz2 h_v5_9 h_v5_11) + -- Step 7: inv_ntt_step v6 z3 12 14. + have h_v6_12 : (v6.elements.val[12]!).val.natAbs ≤ 13312 := by + rw [h_v6_unc 12 hi12 (by decide) (by decide), + h_v5_unc 12 hi12 (by decide) (by decide), + h_v4_unc 12 hi12 (by decide) (by decide), + h_v3_unc 12 hi12 (by decide) (by decide), + h_v2_unc 12 hi12 (by decide) (by decide), + h_v1_unc 12 hi12 (by decide) (by decide)]; exact hvec 12 hi12 + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ 13312 := by + rw [h_v6_unc 14 hi14 (by decide) (by decide), + h_v5_unc 14 hi14 (by decide) (by decide), + h_v4_unc 14 hi14 (by decide) (by decide), + h_v3_unc 14 hi14 (by decide) (by decide), + h_v2_unc 14 hi14 (by decide) (by decide), + h_v1_unc 14 hi14 (by decide) (by decide)]; exact hvec 14 hi14 + obtain ⟨v7, h_v7_eq, h_v7_lift, h_v7_unc, h_v7_bnd_12, h_v7_bnd_14⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v6 z3 12#usize 14#usize hi12 hi14 + (by decide) hz3 h_v6_12 h_v6_14) + -- Step 8: inv_ntt_step v7 z3 13 15. + have h_v7_13 : (v7.elements.val[13]!).val.natAbs ≤ 13312 := by + rw [h_v7_unc 13 hi13 (by decide) (by decide), + h_v6_unc 13 hi13 (by decide) (by decide), + h_v5_unc 13 hi13 (by decide) (by decide), + h_v4_unc 13 hi13 (by decide) (by decide), + h_v3_unc 13 hi13 (by decide) (by decide), + h_v2_unc 13 hi13 (by decide) (by decide), + h_v1_unc 13 hi13 (by decide) (by decide)]; exact hvec 13 hi13 + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ 13312 := by + rw [h_v7_unc 15 hi15 (by decide) (by decide), + h_v6_unc 15 hi15 (by decide) (by decide), + h_v5_unc 15 hi15 (by decide) (by decide), + h_v4_unc 15 hi15 (by decide) (by decide), + h_v3_unc 15 hi15 (by decide) (by decide), + h_v2_unc 15 hi15 (by decide) (by decide), + h_v1_unc 15 hi15 (by decide) (by decide)]; exact hvec 15 hi15 + obtain ⟨v8, h_v8_eq, h_v8_lift, h_v8_unc, h_v8_bnd_13, h_v8_bnd_15⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v7 z3 13#usize 15#usize hi13 hi15 + (by decide) hz3 h_v7_13 h_v7_15) + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_1_step vec z0 z1 z2 z3 + = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_1_step + rw [h_v1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v3_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v4_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v5_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v6_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v7_eq]; simp only [Aeneas.Std.bind_tc_ok] + exact h_v8_eq + apply triple_of_ok_fc h_body + refine ⟨?_, ?_⟩ + · -- Spec equation. + unfold Spec.chunk_inv_ntt_layer_1_step_pure + rw [h_v8_lift, h_v7_lift, h_v6_lift, h_v5_lift, h_v4_lift, h_v3_lift, h_v2_lift, h_v1_lift] + · -- Per-lane output bound ≤ 3328. + -- Each lane is touched by exactly one of the 8 steps. After that step, + -- the lane's natAbs is ≤ 3328 (per inv_ntt_step_pair_fc's strengthened + -- POST). Later steps don't touch it, so v8.elements.val[k]! preserves + -- that value via the unchanged-lane chain. + intro k hk + -- Lane → touching step: 0,2→1; 1,3→2; 4,6→3; 5,7→4; 8,10→5; 9,11→6; + -- 12,14→7; 13,15→8. + interval_cases k + · -- k=0: touched at step 1. Unchanged in steps 2..8. + rw [h_v8_unc 0 hi0 (by decide) (by decide), + h_v7_unc 0 hi0 (by decide) (by decide), + h_v6_unc 0 hi0 (by decide) (by decide), + h_v5_unc 0 hi0 (by decide) (by decide), + h_v4_unc 0 hi0 (by decide) (by decide), + h_v3_unc 0 hi0 (by decide) (by decide), + h_v2_unc 0 hi0 (by decide) (by decide)] + exact h_v1_bnd_0 + · -- k=1: touched at step 2. + rw [h_v8_unc 1 hi1 (by decide) (by decide), + h_v7_unc 1 hi1 (by decide) (by decide), + h_v6_unc 1 hi1 (by decide) (by decide), + h_v5_unc 1 hi1 (by decide) (by decide), + h_v4_unc 1 hi1 (by decide) (by decide), + h_v3_unc 1 hi1 (by decide) (by decide)] + exact h_v2_bnd_1 + · -- k=2: touched at step 1. + rw [h_v8_unc 2 hi2 (by decide) (by decide), + h_v7_unc 2 hi2 (by decide) (by decide), + h_v6_unc 2 hi2 (by decide) (by decide), + h_v5_unc 2 hi2 (by decide) (by decide), + h_v4_unc 2 hi2 (by decide) (by decide), + h_v3_unc 2 hi2 (by decide) (by decide), + h_v2_unc 2 hi2 (by decide) (by decide)] + exact h_v1_bnd_2 + · -- k=3: touched at step 2. + rw [h_v8_unc 3 hi3 (by decide) (by decide), + h_v7_unc 3 hi3 (by decide) (by decide), + h_v6_unc 3 hi3 (by decide) (by decide), + h_v5_unc 3 hi3 (by decide) (by decide), + h_v4_unc 3 hi3 (by decide) (by decide), + h_v3_unc 3 hi3 (by decide) (by decide)] + exact h_v2_bnd_3 + · -- k=4: touched at step 3. + rw [h_v8_unc 4 hi4 (by decide) (by decide), + h_v7_unc 4 hi4 (by decide) (by decide), + h_v6_unc 4 hi4 (by decide) (by decide), + h_v5_unc 4 hi4 (by decide) (by decide), + h_v4_unc 4 hi4 (by decide) (by decide)] + exact h_v3_bnd_4 + · -- k=5: touched at step 4. + rw [h_v8_unc 5 hi5 (by decide) (by decide), + h_v7_unc 5 hi5 (by decide) (by decide), + h_v6_unc 5 hi5 (by decide) (by decide), + h_v5_unc 5 hi5 (by decide) (by decide)] + exact h_v4_bnd_5 + · -- k=6: touched at step 3. + rw [h_v8_unc 6 hi6 (by decide) (by decide), + h_v7_unc 6 hi6 (by decide) (by decide), + h_v6_unc 6 hi6 (by decide) (by decide), + h_v5_unc 6 hi6 (by decide) (by decide), + h_v4_unc 6 hi6 (by decide) (by decide)] + exact h_v3_bnd_6 + · -- k=7: touched at step 4. + rw [h_v8_unc 7 hi7 (by decide) (by decide), + h_v7_unc 7 hi7 (by decide) (by decide), + h_v6_unc 7 hi7 (by decide) (by decide), + h_v5_unc 7 hi7 (by decide) (by decide)] + exact h_v4_bnd_7 + · -- k=8: touched at step 5. + rw [h_v8_unc 8 hi8 (by decide) (by decide), + h_v7_unc 8 hi8 (by decide) (by decide), + h_v6_unc 8 hi8 (by decide) (by decide)] + exact h_v5_bnd_8 + · -- k=9: touched at step 6. + rw [h_v8_unc 9 hi9 (by decide) (by decide), + h_v7_unc 9 hi9 (by decide) (by decide)] + exact h_v6_bnd_9 + · -- k=10: touched at step 5. + rw [h_v8_unc 10 hi10 (by decide) (by decide), + h_v7_unc 10 hi10 (by decide) (by decide), + h_v6_unc 10 hi10 (by decide) (by decide)] + exact h_v5_bnd_10 + · -- k=11: touched at step 6. + rw [h_v8_unc 11 hi11 (by decide) (by decide), + h_v7_unc 11 hi11 (by decide) (by decide)] + exact h_v6_bnd_11 + · -- k=12: touched at step 7. + rw [h_v8_unc 12 hi12 (by decide) (by decide)] + exact h_v7_bnd_12 + · -- k=13: touched at step 8. + exact h_v8_bnd_13 + · -- k=14: touched at step 7. + rw [h_v8_unc 14 hi14 (by decide) (by decide)] + exact h_v7_bnd_14 + · -- k=15: touched at step 8. + exact h_v8_bnd_15 + +/-- L2.10 — `inv_ntt_layer_2_step`: 8 inverse butterfly pairs (0,4)…(3,7) + with z0 then (8,12)…(11,15) with z1. Mirror of `ntt_layer_2_step_fc` on the same lane-pair sequence. + + **Precondition adjustment** (beyond locked statement): + - `hvec : ∀ k < 16, |vec[k]| ≤ 13312` — preserved across 8 sequential + `inv_ntt_step_pair_fc` invocations on disjoint pairs (post-barrett + ≤ 3328, post-mont-mul ≤ 4993, untouched lanes preserve input). -/ +@[spec] +theorem inv_ntt_layer_2_step_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (z0 z1 : Std.I16) + (hz : z0.val.natAbs ≤ 1664 ∧ z1.val.natAbs ≤ 1664) + (hvec : ∀ k : Nat, k < 16 → + (vec.elements.val[k]!).val.natAbs ≤ 13312) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_2_step vec z0 z1 + ⦃ ⇓ r => ⌜ lift_chunk r + = Spec.chunk_inv_ntt_layer_2_step_pure (lift_chunk vec) + (lift_fe_mont z0) (lift_fe_mont z1) + ∧ (∀ k : Nat, k < 16 → + (r.elements.val[k]!).val.natAbs ≤ 3328) ⌝ ⦄ := by + obtain ⟨hz0, hz1⟩ := hz + have hi0 : (0 : Nat) < 16 := by decide + have hi1 : (1 : Nat) < 16 := by decide + have hi2 : (2 : Nat) < 16 := by decide + have hi3 : (3 : Nat) < 16 := by decide + have hi4 : (4 : Nat) < 16 := by decide + have hi5 : (5 : Nat) < 16 := by decide + have hi6 : (6 : Nat) < 16 := by decide + have hi7 : (7 : Nat) < 16 := by decide + have hi8 : (8 : Nat) < 16 := by decide + have hi9 : (9 : Nat) < 16 := by decide + have hi10 : (10 : Nat) < 16 := by decide + have hi11 : (11 : Nat) < 16 := by decide + have hi12 : (12 : Nat) < 16 := by decide + have hi13 : (13 : Nat) < 16 := by decide + have hi14 : (14 : Nat) < 16 := by decide + have hi15 : (15 : Nat) < 16 := by decide + -- Step 1: inv_ntt_step vec z0 0 4. + obtain ⟨v1, h_v1_eq, h_v1_lift, h_v1_unc, h_v1_bnd_0, h_v1_bnd_4⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc vec z0 0#usize 4#usize hi0 hi4 + (by decide) hz0 (hvec 0 hi0) (hvec 4 hi4)) + -- Step 2: inv_ntt_step v1 z0 1 5. + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ 13312 := by + rw [h_v1_unc 1 hi1 (by decide) (by decide)]; exact hvec 1 hi1 + have h_v1_5 : (v1.elements.val[5]!).val.natAbs ≤ 13312 := by + rw [h_v1_unc 5 hi5 (by decide) (by decide)]; exact hvec 5 hi5 + obtain ⟨v2, h_v2_eq, h_v2_lift, h_v2_unc, h_v2_bnd_1, h_v2_bnd_5⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v1 z0 1#usize 5#usize hi1 hi5 + (by decide) hz0 h_v1_1 h_v1_5) + -- Step 3: inv_ntt_step v2 z0 2 6. + have h_v2_2 : (v2.elements.val[2]!).val.natAbs ≤ 13312 := by + rw [h_v2_unc 2 hi2 (by decide) (by decide), + h_v1_unc 2 hi2 (by decide) (by decide)]; exact hvec 2 hi2 + have h_v2_6 : (v2.elements.val[6]!).val.natAbs ≤ 13312 := by + rw [h_v2_unc 6 hi6 (by decide) (by decide), + h_v1_unc 6 hi6 (by decide) (by decide)]; exact hvec 6 hi6 + obtain ⟨v3, h_v3_eq, h_v3_lift, h_v3_unc, h_v3_bnd_2, h_v3_bnd_6⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v2 z0 2#usize 6#usize hi2 hi6 + (by decide) hz0 h_v2_2 h_v2_6) + -- Step 4: inv_ntt_step v3 z0 3 7. + have h_v3_3 : (v3.elements.val[3]!).val.natAbs ≤ 13312 := by + rw [h_v3_unc 3 hi3 (by decide) (by decide), + h_v2_unc 3 hi3 (by decide) (by decide), + h_v1_unc 3 hi3 (by decide) (by decide)]; exact hvec 3 hi3 + have h_v3_7 : (v3.elements.val[7]!).val.natAbs ≤ 13312 := by + rw [h_v3_unc 7 hi7 (by decide) (by decide), + h_v2_unc 7 hi7 (by decide) (by decide), + h_v1_unc 7 hi7 (by decide) (by decide)]; exact hvec 7 hi7 + obtain ⟨v4, h_v4_eq, h_v4_lift, h_v4_unc, h_v4_bnd_3, h_v4_bnd_7⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v3 z0 3#usize 7#usize hi3 hi7 + (by decide) hz0 h_v3_3 h_v3_7) + -- Step 5: inv_ntt_step v4 z1 8 12. + have h_v4_8 : (v4.elements.val[8]!).val.natAbs ≤ 13312 := by + rw [h_v4_unc 8 hi8 (by decide) (by decide), + h_v3_unc 8 hi8 (by decide) (by decide), + h_v2_unc 8 hi8 (by decide) (by decide), + h_v1_unc 8 hi8 (by decide) (by decide)]; exact hvec 8 hi8 + have h_v4_12 : (v4.elements.val[12]!).val.natAbs ≤ 13312 := by + rw [h_v4_unc 12 hi12 (by decide) (by decide), + h_v3_unc 12 hi12 (by decide) (by decide), + h_v2_unc 12 hi12 (by decide) (by decide), + h_v1_unc 12 hi12 (by decide) (by decide)]; exact hvec 12 hi12 + obtain ⟨v5, h_v5_eq, h_v5_lift, h_v5_unc, h_v5_bnd_8, h_v5_bnd_12⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v4 z1 8#usize 12#usize hi8 hi12 + (by decide) hz1 h_v4_8 h_v4_12) + -- Step 6: inv_ntt_step v5 z1 9 13. + have h_v5_9 : (v5.elements.val[9]!).val.natAbs ≤ 13312 := by + rw [h_v5_unc 9 hi9 (by decide) (by decide), + h_v4_unc 9 hi9 (by decide) (by decide), + h_v3_unc 9 hi9 (by decide) (by decide), + h_v2_unc 9 hi9 (by decide) (by decide), + h_v1_unc 9 hi9 (by decide) (by decide)]; exact hvec 9 hi9 + have h_v5_13 : (v5.elements.val[13]!).val.natAbs ≤ 13312 := by + rw [h_v5_unc 13 hi13 (by decide) (by decide), + h_v4_unc 13 hi13 (by decide) (by decide), + h_v3_unc 13 hi13 (by decide) (by decide), + h_v2_unc 13 hi13 (by decide) (by decide), + h_v1_unc 13 hi13 (by decide) (by decide)]; exact hvec 13 hi13 + obtain ⟨v6, h_v6_eq, h_v6_lift, h_v6_unc, h_v6_bnd_9, h_v6_bnd_13⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v5 z1 9#usize 13#usize hi9 hi13 + (by decide) hz1 h_v5_9 h_v5_13) + -- Step 7: inv_ntt_step v6 z1 10 14. + have h_v6_10 : (v6.elements.val[10]!).val.natAbs ≤ 13312 := by + rw [h_v6_unc 10 hi10 (by decide) (by decide), + h_v5_unc 10 hi10 (by decide) (by decide), + h_v4_unc 10 hi10 (by decide) (by decide), + h_v3_unc 10 hi10 (by decide) (by decide), + h_v2_unc 10 hi10 (by decide) (by decide), + h_v1_unc 10 hi10 (by decide) (by decide)]; exact hvec 10 hi10 + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ 13312 := by + rw [h_v6_unc 14 hi14 (by decide) (by decide), + h_v5_unc 14 hi14 (by decide) (by decide), + h_v4_unc 14 hi14 (by decide) (by decide), + h_v3_unc 14 hi14 (by decide) (by decide), + h_v2_unc 14 hi14 (by decide) (by decide), + h_v1_unc 14 hi14 (by decide) (by decide)]; exact hvec 14 hi14 + obtain ⟨v7, h_v7_eq, h_v7_lift, h_v7_unc, h_v7_bnd_10, h_v7_bnd_14⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v6 z1 10#usize 14#usize hi10 hi14 + (by decide) hz1 h_v6_10 h_v6_14) + -- Step 8: inv_ntt_step v7 z1 11 15. + have h_v7_11 : (v7.elements.val[11]!).val.natAbs ≤ 13312 := by + rw [h_v7_unc 11 hi11 (by decide) (by decide), + h_v6_unc 11 hi11 (by decide) (by decide), + h_v5_unc 11 hi11 (by decide) (by decide), + h_v4_unc 11 hi11 (by decide) (by decide), + h_v3_unc 11 hi11 (by decide) (by decide), + h_v2_unc 11 hi11 (by decide) (by decide), + h_v1_unc 11 hi11 (by decide) (by decide)]; exact hvec 11 hi11 + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ 13312 := by + rw [h_v7_unc 15 hi15 (by decide) (by decide), + h_v6_unc 15 hi15 (by decide) (by decide), + h_v5_unc 15 hi15 (by decide) (by decide), + h_v4_unc 15 hi15 (by decide) (by decide), + h_v3_unc 15 hi15 (by decide) (by decide), + h_v2_unc 15 hi15 (by decide) (by decide), + h_v1_unc 15 hi15 (by decide) (by decide)]; exact hvec 15 hi15 + obtain ⟨v8, h_v8_eq, h_v8_lift, h_v8_unc, h_v8_bnd_11, h_v8_bnd_15⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v7 z1 11#usize 15#usize hi11 hi15 + (by decide) hz1 h_v7_11 h_v7_15) + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_2_step vec z0 z1 + = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_2_step + rw [h_v1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v3_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v4_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v5_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v6_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v7_eq]; simp only [Aeneas.Std.bind_tc_ok] + exact h_v8_eq + apply triple_of_ok_fc h_body + refine ⟨?_, ?_⟩ + · -- Spec equation. + unfold Spec.chunk_inv_ntt_layer_2_step_pure + rw [h_v8_lift, h_v7_lift, h_v6_lift, h_v5_lift, h_v4_lift, h_v3_lift, h_v2_lift, h_v1_lift] + · -- Per-lane output bound ≤ 3328. + -- Step touches: 1:(0,4), 2:(1,5), 3:(2,6), 4:(3,7), 5:(8,12), 6:(9,13), + -- 7:(10,14), 8:(11,15). All 16 lanes touched exactly once. + intro k hk + interval_cases k + · -- k=0: touched at step 1. + rw [h_v8_unc 0 hi0 (by decide) (by decide), + h_v7_unc 0 hi0 (by decide) (by decide), + h_v6_unc 0 hi0 (by decide) (by decide), + h_v5_unc 0 hi0 (by decide) (by decide), + h_v4_unc 0 hi0 (by decide) (by decide), + h_v3_unc 0 hi0 (by decide) (by decide), + h_v2_unc 0 hi0 (by decide) (by decide)] + exact h_v1_bnd_0 + · -- k=1: touched at step 2. + rw [h_v8_unc 1 hi1 (by decide) (by decide), + h_v7_unc 1 hi1 (by decide) (by decide), + h_v6_unc 1 hi1 (by decide) (by decide), + h_v5_unc 1 hi1 (by decide) (by decide), + h_v4_unc 1 hi1 (by decide) (by decide), + h_v3_unc 1 hi1 (by decide) (by decide)] + exact h_v2_bnd_1 + · -- k=2: touched at step 3. + rw [h_v8_unc 2 hi2 (by decide) (by decide), + h_v7_unc 2 hi2 (by decide) (by decide), + h_v6_unc 2 hi2 (by decide) (by decide), + h_v5_unc 2 hi2 (by decide) (by decide), + h_v4_unc 2 hi2 (by decide) (by decide)] + exact h_v3_bnd_2 + · -- k=3: touched at step 4. + rw [h_v8_unc 3 hi3 (by decide) (by decide), + h_v7_unc 3 hi3 (by decide) (by decide), + h_v6_unc 3 hi3 (by decide) (by decide), + h_v5_unc 3 hi3 (by decide) (by decide)] + exact h_v4_bnd_3 + · -- k=4: touched at step 1. + rw [h_v8_unc 4 hi4 (by decide) (by decide), + h_v7_unc 4 hi4 (by decide) (by decide), + h_v6_unc 4 hi4 (by decide) (by decide), + h_v5_unc 4 hi4 (by decide) (by decide), + h_v4_unc 4 hi4 (by decide) (by decide), + h_v3_unc 4 hi4 (by decide) (by decide), + h_v2_unc 4 hi4 (by decide) (by decide)] + exact h_v1_bnd_4 + · -- k=5: touched at step 2. + rw [h_v8_unc 5 hi5 (by decide) (by decide), + h_v7_unc 5 hi5 (by decide) (by decide), + h_v6_unc 5 hi5 (by decide) (by decide), + h_v5_unc 5 hi5 (by decide) (by decide), + h_v4_unc 5 hi5 (by decide) (by decide), + h_v3_unc 5 hi5 (by decide) (by decide)] + exact h_v2_bnd_5 + · -- k=6: touched at step 3. + rw [h_v8_unc 6 hi6 (by decide) (by decide), + h_v7_unc 6 hi6 (by decide) (by decide), + h_v6_unc 6 hi6 (by decide) (by decide), + h_v5_unc 6 hi6 (by decide) (by decide), + h_v4_unc 6 hi6 (by decide) (by decide)] + exact h_v3_bnd_6 + · -- k=7: touched at step 4. + rw [h_v8_unc 7 hi7 (by decide) (by decide), + h_v7_unc 7 hi7 (by decide) (by decide), + h_v6_unc 7 hi7 (by decide) (by decide), + h_v5_unc 7 hi7 (by decide) (by decide)] + exact h_v4_bnd_7 + · -- k=8: touched at step 5. + rw [h_v8_unc 8 hi8 (by decide) (by decide), + h_v7_unc 8 hi8 (by decide) (by decide), + h_v6_unc 8 hi8 (by decide) (by decide)] + exact h_v5_bnd_8 + · -- k=9: touched at step 6. + rw [h_v8_unc 9 hi9 (by decide) (by decide), + h_v7_unc 9 hi9 (by decide) (by decide)] + exact h_v6_bnd_9 + · -- k=10: touched at step 7. + rw [h_v8_unc 10 hi10 (by decide) (by decide)] + exact h_v7_bnd_10 + · -- k=11: touched at step 8. + exact h_v8_bnd_11 + · -- k=12: touched at step 5. + rw [h_v8_unc 12 hi12 (by decide) (by decide), + h_v7_unc 12 hi12 (by decide) (by decide), + h_v6_unc 12 hi12 (by decide) (by decide)] + exact h_v5_bnd_12 + · -- k=13: touched at step 6. + rw [h_v8_unc 13 hi13 (by decide) (by decide), + h_v7_unc 13 hi13 (by decide) (by decide)] + exact h_v6_bnd_13 + · -- k=14: touched at step 7. + rw [h_v8_unc 14 hi14 (by decide) (by decide)] + exact h_v7_bnd_14 + · -- k=15: touched at step 8. + exact h_v8_bnd_15 + +/-- L2.11 — `inv_ntt_layer_3_step`: 8 inverse butterfly pairs + (0,8)…(7,15) with one zeta. Mirror of `ntt_layer_3_step_fc` on the same lane-pair sequence. + + **Precondition adjustment** (beyond locked statement): + - `hvec : ∀ k < 16, |vec[k]| ≤ 13312` — chained through the 8 + inv_ntt_step calls. Pairs are disjoint (each lane touched exactly + once), so the per-lane bound holds at each step on the unchanged + lanes (touched lanes themselves stay ≤ 13312 by inv_ntt_step's + post-output bounds). -/ +@[spec] +theorem inv_ntt_layer_3_step_fc + (vec : libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (z : Std.I16) (hz : z.val.natAbs ≤ 1664) + (hvec : ∀ k : Nat, k < 16 → + (vec.elements.val[k]!).val.natAbs ≤ 13312) : + ⦃ ⌜ True ⌝ ⦄ + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_3_step vec z + ⦃ ⇓ r => ⌜ lift_chunk r + = Spec.chunk_inv_ntt_layer_3_step_pure (lift_chunk vec) (lift_fe_mont z) + ∧ (∀ k : Nat, k < 16 → + (r.elements.val[k]!).val.natAbs ≤ 3328) ⌝ ⦄ := by + have hi0 : (0 : Nat) < 16 := by decide + have hi1 : (1 : Nat) < 16 := by decide + have hi2 : (2 : Nat) < 16 := by decide + have hi3 : (3 : Nat) < 16 := by decide + have hi4 : (4 : Nat) < 16 := by decide + have hi5 : (5 : Nat) < 16 := by decide + have hi6 : (6 : Nat) < 16 := by decide + have hi7 : (7 : Nat) < 16 := by decide + have hi8 : (8 : Nat) < 16 := by decide + have hi9 : (9 : Nat) < 16 := by decide + have hi10 : (10 : Nat) < 16 := by decide + have hi11 : (11 : Nat) < 16 := by decide + have hi12 : (12 : Nat) < 16 := by decide + have hi13 : (13 : Nat) < 16 := by decide + have hi14 : (14 : Nat) < 16 := by decide + have hi15 : (15 : Nat) < 16 := by decide + -- Step 1: inv_ntt_step vec z 0 8. + obtain ⟨v1, h_v1_eq, h_v1_lift, h_v1_unc, h_v1_bnd_0, h_v1_bnd_8⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc vec z 0#usize 8#usize hi0 hi8 + (by decide) hz (hvec 0 hi0) (hvec 8 hi8)) + -- Step 2: inv_ntt_step v1 z 1 9. + have h_v1_1 : (v1.elements.val[1]!).val.natAbs ≤ 13312 := by + rw [h_v1_unc 1 hi1 (by decide) (by decide)]; exact hvec 1 hi1 + have h_v1_9 : (v1.elements.val[9]!).val.natAbs ≤ 13312 := by + rw [h_v1_unc 9 hi9 (by decide) (by decide)]; exact hvec 9 hi9 + obtain ⟨v2, h_v2_eq, h_v2_lift, h_v2_unc, h_v2_bnd_1, h_v2_bnd_9⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v1 z 1#usize 9#usize hi1 hi9 + (by decide) hz h_v1_1 h_v1_9) + -- Step 3: inv_ntt_step v2 z 2 10. + have h_v2_2 : (v2.elements.val[2]!).val.natAbs ≤ 13312 := by + rw [h_v2_unc 2 hi2 (by decide) (by decide), + h_v1_unc 2 hi2 (by decide) (by decide)]; exact hvec 2 hi2 + have h_v2_10 : (v2.elements.val[10]!).val.natAbs ≤ 13312 := by + rw [h_v2_unc 10 hi10 (by decide) (by decide), + h_v1_unc 10 hi10 (by decide) (by decide)]; exact hvec 10 hi10 + obtain ⟨v3, h_v3_eq, h_v3_lift, h_v3_unc, h_v3_bnd_2, h_v3_bnd_10⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v2 z 2#usize 10#usize hi2 hi10 + (by decide) hz h_v2_2 h_v2_10) + -- Step 4: inv_ntt_step v3 z 3 11. + have h_v3_3 : (v3.elements.val[3]!).val.natAbs ≤ 13312 := by + rw [h_v3_unc 3 hi3 (by decide) (by decide), + h_v2_unc 3 hi3 (by decide) (by decide), + h_v1_unc 3 hi3 (by decide) (by decide)]; exact hvec 3 hi3 + have h_v3_11 : (v3.elements.val[11]!).val.natAbs ≤ 13312 := by + rw [h_v3_unc 11 hi11 (by decide) (by decide), + h_v2_unc 11 hi11 (by decide) (by decide), + h_v1_unc 11 hi11 (by decide) (by decide)]; exact hvec 11 hi11 + obtain ⟨v4, h_v4_eq, h_v4_lift, h_v4_unc, h_v4_bnd_3, h_v4_bnd_11⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v3 z 3#usize 11#usize hi3 hi11 + (by decide) hz h_v3_3 h_v3_11) + -- Step 5: inv_ntt_step v4 z 4 12. + have h_v4_4 : (v4.elements.val[4]!).val.natAbs ≤ 13312 := by + rw [h_v4_unc 4 hi4 (by decide) (by decide), + h_v3_unc 4 hi4 (by decide) (by decide), + h_v2_unc 4 hi4 (by decide) (by decide), + h_v1_unc 4 hi4 (by decide) (by decide)]; exact hvec 4 hi4 + have h_v4_12 : (v4.elements.val[12]!).val.natAbs ≤ 13312 := by + rw [h_v4_unc 12 hi12 (by decide) (by decide), + h_v3_unc 12 hi12 (by decide) (by decide), + h_v2_unc 12 hi12 (by decide) (by decide), + h_v1_unc 12 hi12 (by decide) (by decide)]; exact hvec 12 hi12 + obtain ⟨v5, h_v5_eq, h_v5_lift, h_v5_unc, h_v5_bnd_4, h_v5_bnd_12⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v4 z 4#usize 12#usize hi4 hi12 + (by decide) hz h_v4_4 h_v4_12) + -- Step 6: inv_ntt_step v5 z 5 13. + have h_v5_5 : (v5.elements.val[5]!).val.natAbs ≤ 13312 := by + rw [h_v5_unc 5 hi5 (by decide) (by decide), + h_v4_unc 5 hi5 (by decide) (by decide), + h_v3_unc 5 hi5 (by decide) (by decide), + h_v2_unc 5 hi5 (by decide) (by decide), + h_v1_unc 5 hi5 (by decide) (by decide)]; exact hvec 5 hi5 + have h_v5_13 : (v5.elements.val[13]!).val.natAbs ≤ 13312 := by + rw [h_v5_unc 13 hi13 (by decide) (by decide), + h_v4_unc 13 hi13 (by decide) (by decide), + h_v3_unc 13 hi13 (by decide) (by decide), + h_v2_unc 13 hi13 (by decide) (by decide), + h_v1_unc 13 hi13 (by decide) (by decide)]; exact hvec 13 hi13 + obtain ⟨v6, h_v6_eq, h_v6_lift, h_v6_unc, h_v6_bnd_5, h_v6_bnd_13⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v5 z 5#usize 13#usize hi5 hi13 + (by decide) hz h_v5_5 h_v5_13) + -- Step 7: inv_ntt_step v6 z 6 14. + have h_v6_6 : (v6.elements.val[6]!).val.natAbs ≤ 13312 := by + rw [h_v6_unc 6 hi6 (by decide) (by decide), + h_v5_unc 6 hi6 (by decide) (by decide), + h_v4_unc 6 hi6 (by decide) (by decide), + h_v3_unc 6 hi6 (by decide) (by decide), + h_v2_unc 6 hi6 (by decide) (by decide), + h_v1_unc 6 hi6 (by decide) (by decide)]; exact hvec 6 hi6 + have h_v6_14 : (v6.elements.val[14]!).val.natAbs ≤ 13312 := by + rw [h_v6_unc 14 hi14 (by decide) (by decide), + h_v5_unc 14 hi14 (by decide) (by decide), + h_v4_unc 14 hi14 (by decide) (by decide), + h_v3_unc 14 hi14 (by decide) (by decide), + h_v2_unc 14 hi14 (by decide) (by decide), + h_v1_unc 14 hi14 (by decide) (by decide)]; exact hvec 14 hi14 + obtain ⟨v7, h_v7_eq, h_v7_lift, h_v7_unc, h_v7_bnd_6, h_v7_bnd_14⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v6 z 6#usize 14#usize hi6 hi14 + (by decide) hz h_v6_6 h_v6_14) + -- Step 8: inv_ntt_step v7 z 7 15. + have h_v7_7 : (v7.elements.val[7]!).val.natAbs ≤ 13312 := by + rw [h_v7_unc 7 hi7 (by decide) (by decide), + h_v6_unc 7 hi7 (by decide) (by decide), + h_v5_unc 7 hi7 (by decide) (by decide), + h_v4_unc 7 hi7 (by decide) (by decide), + h_v3_unc 7 hi7 (by decide) (by decide), + h_v2_unc 7 hi7 (by decide) (by decide), + h_v1_unc 7 hi7 (by decide) (by decide)]; exact hvec 7 hi7 + have h_v7_15 : (v7.elements.val[15]!).val.natAbs ≤ 13312 := by + rw [h_v7_unc 15 hi15 (by decide) (by decide), + h_v6_unc 15 hi15 (by decide) (by decide), + h_v5_unc 15 hi15 (by decide) (by decide), + h_v4_unc 15 hi15 (by decide) (by decide), + h_v3_unc 15 hi15 (by decide) (by decide), + h_v2_unc 15 hi15 (by decide) (by decide), + h_v1_unc 15 hi15 (by decide) (by decide)]; exact hvec 15 hi15 + obtain ⟨v8, h_v8_eq, h_v8_lift, h_v8_unc, h_v8_bnd_7, h_v8_bnd_15⟩ := + triple_exists_ok_fc (inv_ntt_step_pair_fc v7 z 7#usize 15#usize hi7 hi15 + (by decide) hz h_v7_7 h_v7_15) + have h_body : + libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_3_step vec z = .ok v8 := by + unfold libcrux_iot_ml_kem.vector.portable.ntt.inv_ntt_layer_3_step + rw [h_v1_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v2_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v3_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v4_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v5_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v6_eq]; simp only [Aeneas.Std.bind_tc_ok] + rw [h_v7_eq]; simp only [Aeneas.Std.bind_tc_ok] + exact h_v8_eq + apply triple_of_ok_fc h_body + refine ⟨?_, ?_⟩ + · -- Spec equation. + unfold Spec.chunk_inv_ntt_layer_3_step_pure + rw [h_v8_lift, h_v7_lift, h_v6_lift, h_v5_lift, h_v4_lift, h_v3_lift, h_v2_lift, h_v1_lift] + · -- Per-lane output bound ≤ 3328. + -- Step touches: 1:(0,8), 2:(1,9), 3:(2,10), 4:(3,11), 5:(4,12), 6:(5,13), + -- 7:(6,14), 8:(7,15). All 16 lanes touched exactly once. + intro k hk + interval_cases k + · -- k=0: touched at step 1. + rw [h_v8_unc 0 hi0 (by decide) (by decide), + h_v7_unc 0 hi0 (by decide) (by decide), + h_v6_unc 0 hi0 (by decide) (by decide), + h_v5_unc 0 hi0 (by decide) (by decide), + h_v4_unc 0 hi0 (by decide) (by decide), + h_v3_unc 0 hi0 (by decide) (by decide), + h_v2_unc 0 hi0 (by decide) (by decide)] + exact h_v1_bnd_0 + · -- k=1: touched at step 2. + rw [h_v8_unc 1 hi1 (by decide) (by decide), + h_v7_unc 1 hi1 (by decide) (by decide), + h_v6_unc 1 hi1 (by decide) (by decide), + h_v5_unc 1 hi1 (by decide) (by decide), + h_v4_unc 1 hi1 (by decide) (by decide), + h_v3_unc 1 hi1 (by decide) (by decide)] + exact h_v2_bnd_1 + · -- k=2: touched at step 3. + rw [h_v8_unc 2 hi2 (by decide) (by decide), + h_v7_unc 2 hi2 (by decide) (by decide), + h_v6_unc 2 hi2 (by decide) (by decide), + h_v5_unc 2 hi2 (by decide) (by decide), + h_v4_unc 2 hi2 (by decide) (by decide)] + exact h_v3_bnd_2 + · -- k=3: touched at step 4. + rw [h_v8_unc 3 hi3 (by decide) (by decide), + h_v7_unc 3 hi3 (by decide) (by decide), + h_v6_unc 3 hi3 (by decide) (by decide), + h_v5_unc 3 hi3 (by decide) (by decide)] + exact h_v4_bnd_3 + · -- k=4: touched at step 5. + rw [h_v8_unc 4 hi4 (by decide) (by decide), + h_v7_unc 4 hi4 (by decide) (by decide), + h_v6_unc 4 hi4 (by decide) (by decide)] + exact h_v5_bnd_4 + · -- k=5: touched at step 6. + rw [h_v8_unc 5 hi5 (by decide) (by decide), + h_v7_unc 5 hi5 (by decide) (by decide)] + exact h_v6_bnd_5 + · -- k=6: touched at step 7. + rw [h_v8_unc 6 hi6 (by decide) (by decide)] + exact h_v7_bnd_6 + · -- k=7: touched at step 8. + exact h_v8_bnd_7 + · -- k=8: touched at step 1. + rw [h_v8_unc 8 hi8 (by decide) (by decide), + h_v7_unc 8 hi8 (by decide) (by decide), + h_v6_unc 8 hi8 (by decide) (by decide), + h_v5_unc 8 hi8 (by decide) (by decide), + h_v4_unc 8 hi8 (by decide) (by decide), + h_v3_unc 8 hi8 (by decide) (by decide), + h_v2_unc 8 hi8 (by decide) (by decide)] + exact h_v1_bnd_8 + · -- k=9: touched at step 2. + rw [h_v8_unc 9 hi9 (by decide) (by decide), + h_v7_unc 9 hi9 (by decide) (by decide), + h_v6_unc 9 hi9 (by decide) (by decide), + h_v5_unc 9 hi9 (by decide) (by decide), + h_v4_unc 9 hi9 (by decide) (by decide), + h_v3_unc 9 hi9 (by decide) (by decide)] + exact h_v2_bnd_9 + · -- k=10: touched at step 3. + rw [h_v8_unc 10 hi10 (by decide) (by decide), + h_v7_unc 10 hi10 (by decide) (by decide), + h_v6_unc 10 hi10 (by decide) (by decide), + h_v5_unc 10 hi10 (by decide) (by decide), + h_v4_unc 10 hi10 (by decide) (by decide)] + exact h_v3_bnd_10 + · -- k=11: touched at step 4. + rw [h_v8_unc 11 hi11 (by decide) (by decide), + h_v7_unc 11 hi11 (by decide) (by decide), + h_v6_unc 11 hi11 (by decide) (by decide), + h_v5_unc 11 hi11 (by decide) (by decide)] + exact h_v4_bnd_11 + · -- k=12: touched at step 5. + rw [h_v8_unc 12 hi12 (by decide) (by decide), + h_v7_unc 12 hi12 (by decide) (by decide), + h_v6_unc 12 hi12 (by decide) (by decide)] + exact h_v5_bnd_12 + · -- k=13: touched at step 6. + rw [h_v8_unc 13 hi13 (by decide) (by decide), + h_v7_unc 13 hi13 (by decide) (by decide)] + exact h_v6_bnd_13 + · -- k=14: touched at step 7. + rw [h_v8_unc 14 hi14 (by decide) (by decide)] + exact h_v7_bnd_14 + · -- k=15: touched at step 8. + exact h_v8_bnd_15 + + + +/-- Chunk projection identity: `Spec.chunk_at (lift_poly re) k = lift_chunk re.coefs[k]`. + + Pointwise equality at each lane: `lift_fe ((re.coefs[k]).elems[j])` on both sides. -/ +theorem chunk_at_lift_poly_fc + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (k : Nat) (hk : k < 16) : + Spec.chunk_at (lift_poly re) k = lift_chunk (re.coefficients.val[k]!) := by + unfold Spec.chunk_at lift_poly lift_chunk + apply Subtype.ext + have h_chunk_len : (re.coefficients.val[k]!).elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + -- Both sides .val unfold definitionally to a list; show list equality. + show (List.range 16).map + (fun j => ((List.range 256).map + (fun j' => lift_fe (re.coefficients.val[j' / 16]!).elements.val[j' % 16]!))[16 * k + j]!) + = (re.coefficients.val[k]!).elements.val.map lift_fe + apply List.ext_getElem + · simp + · intro i hi1 _hi2 + have hi : i < 16 := by + have : i < ((List.range 16).map _).length := hi1 + simpa using this + have h_idx_lt : 16 * k + i < 256 := by + have hk' : k ≤ 15 := by omega + have : 16 * k ≤ 16 * 15 := Nat.mul_le_mul_left _ hk' + omega + have h_list_len : ((List.range 256).map (fun j => + lift_fe ((re.coefficients.val[j / 16]!).elements.val[j % 16]!))).length = 256 := by + simp + have h_div : (16 * k + i) / 16 = k := by + have h_i_div : i / 16 = 0 := Nat.div_eq_zero_iff.mpr (Or.inr hi) + have := Nat.add_mul_div_left i k (by decide : 0 < 16) + omega + have h_mod : (16 * k + i) % 16 = i := by + have := Nat.add_mul_mod_self_left i k 16 + have h_i_mod : i % 16 = i := Nat.mod_eq_of_lt hi + omega + -- LHS chain: List.getElem_map (outer), List.getElem_range (outer), + -- getElem!_pos (big-list lookup), List.getElem_map (big-list), + -- List.getElem_range (big-list), h_div + h_mod. + rw [List.getElem_map, List.getElem_range, + getElem!_pos _ (16 * k + i) (by rw [h_list_len]; exact h_idx_lt), + List.getElem_map, List.getElem_range, h_div, h_mod] + -- Goal: lift_fe (val[i]!) = ((val).map lift_fe)[i]'_. + have h_getElem_pos : + (re.coefficients.val[k]!).elements.val[i]! + = (re.coefficients.val[k]!).elements.val[i]'(by rw [h_chunk_len]; exact hi) := + getElem!_pos (re.coefficients.val[k]!).elements.val i + (by rw [h_chunk_len]; exact hi) + rw [h_getElem_pos] + -- Goal: lift_fe (val[i]'_) = (val.map lift_fe)[i]'_. + exact (List.getElem_map (f := lift_fe) (l := (re.coefficients.val[k]!).elements.val)).symm + +/-- Flatten-chunks identity (the chunked image of a poly under `lift_poly`). + Mathematically: if `chunks[k] = lift_chunk re.coefs[k]` for all k < 16, + then `flatten_chunks chunks = lift_poly re`. + + Uses `getElem` (with proof) rather than `getElem!` in the hypothesis to + avoid a Lean elaborator issue with `Inhabited (Std.Array FE 16)` + synthesis in `∀` binders. -/ +theorem flatten_chunks_eq_lift_poly_fc + (re : libcrux_iot_ml_kem.polynomial.PolynomialRingElement + libcrux_iot_ml_kem.vector.portable.vector_type.PortableVector) + (chunks : Std.Array (Std.Array hacspec_ml_kem.parameters.FieldElement 16#usize) + 16#usize) + (h_chunks_len : chunks.val.length = 16) + (h_chunks : ∀ k : Nat, (hk : k < 16) → + chunks.val[k]'(by rw [h_chunks_len]; exact hk) + = lift_chunk (re.coefficients.val[k]!)) : + Spec.flatten_chunks chunks = lift_poly re := by + unfold Spec.flatten_chunks lift_poly + apply Subtype.ext + show (List.range 256).map (fun j => (chunks.val[j / 16]!).val[j % 16]!) + = (List.range 256).map (fun j => + lift_fe ((re.coefficients.val[j / 16]!).elements.val[j % 16]!)) + apply List.ext_getElem + · simp + · intro j hj1 _hj2 + have hj : j < 256 := by + have : j < ((List.range 256).map (fun j' => (chunks.val[j' / 16]!).val[j' % 16]!)).length := hj1 + simpa using this + have h_div_lt : j / 16 < 16 := Nat.div_lt_iff_lt_mul (by decide : 0 < 16) |>.mpr hj + have h_mod_lt : j % 16 < 16 := Nat.mod_lt _ (by decide : 0 < 16) + have h_chunk_len : (re.coefficients.val[j / 16]!).elements.val.length = 16 := + libcrux_iot_ml_kem.Vector.Portable.Arithmetic.LoopHelper.PortableVector_elements_length _ + have h_map_len : ((re.coefficients.val[j / 16]!).elements.val.map lift_fe).length = 16 := by + rw [List.length_map]; exact h_chunk_len + -- Pull `chunks.val[j/16]!` through `h_chunks` using `getElem!_pos`. + have h_chunks_getElem : + chunks.val[j / 16]! = lift_chunk (re.coefficients.val[j / 16]!) := by + rw [getElem!_pos chunks.val (j / 16) (by rw [h_chunks_len]; exact h_div_lt)] + exact h_chunks (j / 16) h_div_lt + rw [List.getElem_map, List.getElem_map, List.getElem_range, h_chunks_getElem] + -- Goal: (lift_chunk re.coefficients[j/16]).val[j % 16]! = lift_fe ((...).val[j % 16]!). + -- LHS unfolds to ((re.coefs[j/16]).elements.val.map lift_fe)[j % 16]!. + show ((re.coefficients.val[j / 16]!).elements.val.map lift_fe)[j % 16]! + = lift_fe ((re.coefficients.val[j / 16]!).elements.val[j % 16]!) + rw [getElem!_pos ((re.coefficients.val[j / 16]!).elements.val.map lift_fe) (j % 16) + (by rw [h_map_len]; exact h_mod_lt), + getElem!_pos (re.coefficients.val[j / 16]!).elements.val (j % 16) + (by rw [h_chunk_len]; exact h_mod_lt)] + exact List.getElem_map (f := lift_fe) + (l := (re.coefficients.val[j / 16]!).elements.val) + +end libcrux_iot_ml_kem.Vector.Portable.Ntt