From 4953abd36f211cfdab0f2cbfa79a779a77221361 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Tue, 20 Aug 2024 23:36:04 -0400 Subject: [PATCH] Create MLE eval component --- .../src/constraint_framework/component.rs | 32 +- .../constraint_framework/constant_columns.rs | 12 +- crates/prover/src/core/air/accumulation.rs | 1 + .../examples/xor/gkr_lookups/accumulation.rs | 13 +- .../src/examples/xor/gkr_lookups/mle_eval.rs | 585 +++++++++++++++--- 5 files changed, 543 insertions(+), 100 deletions(-) diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 99cabc937..2603e7c12 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -31,7 +31,7 @@ pub struct TraceLocationAllocator { } impl TraceLocationAllocator { - fn next_for_structure( + pub fn next_for_structure( &mut self, structure: &TreeVec>, ) -> TreeVec { @@ -75,14 +75,18 @@ pub struct FrameworkComponent { } impl FrameworkComponent { - pub fn new(provider: &mut TraceLocationAllocator, eval: E) -> Self { + pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E) -> Self { let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets; - let trace_locations = provider.next_for_structure(&eval_tree_structure); + let trace_locations = location_allocator.next_for_structure(&eval_tree_structure); Self { eval, trace_locations, } } + + pub fn trace_locations(&self) -> &[TreeColumnSpan] { + &self.trace_locations + } } impl Component for FrameworkComponent { @@ -95,26 +99,20 @@ impl Component for FrameworkComponent { } fn trace_log_degree_bounds(&self) -> TreeVec> { - TreeVec::new( - self.eval - .evaluate(InfoEvaluator::default()) - .mask_offsets - .iter() - .map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()]) - .collect(), - ) + let InfoEvaluator { mask_offsets, .. } = self.eval.evaluate(InfoEvaluator::default()); + mask_offsets.map(|tree_offsets| vec![self.eval.log_size(); tree_offsets.len()]) } fn mask_points( &self, point: CirclePoint, ) -> TreeVec>>> { - let info = self.eval.evaluate(InfoEvaluator::default()); let trace_step = CanonicCoset::new(self.eval.log_size()).step(); - info.mask_offsets.map_cols(|col_mask| { - col_mask + let InfoEvaluator { mask_offsets, .. } = self.eval.evaluate(InfoEvaluator::default()); + mask_offsets.map_cols(|col_offsets| { + col_offsets .iter() - .map(|off| point + trace_step.mul_signed(*off).into_ef()) + .map(|offset| point + trace_step.mul_signed(*offset).into_ef()) .collect() }) } @@ -139,6 +137,10 @@ impl ComponentProver for FrameworkComponent { trace: &Trace<'_, SimdBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, ) { + if self.n_constraints() == 0 { + return; + } + let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); let trace_domain = CanonicCoset::new(self.eval.log_size()); diff --git a/crates/prover/src/constraint_framework/constant_columns.rs b/crates/prover/src/constraint_framework/constant_columns.rs index e57df28ab..27ad6257e 100644 --- a/crates/prover/src/constraint_framework/constant_columns.rs +++ b/crates/prover/src/constraint_framework/constant_columns.rs @@ -8,8 +8,18 @@ use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; /// Generates a column with a single one at the first position, and zeros elsewhere. pub fn gen_is_first(log_size: u32) -> CircleEvaluation { + gen_is_offset(log_size, 0) +} + +/// Generates a column with a single one at the `offset`, and zeros elsewhere. +pub fn gen_is_offset( + log_size: u32, + offset: isize, +) -> CircleEvaluation { let mut col = Col::::zeros(1 << log_size); - col.set(0, BaseField::one()); + let offset = offset.rem_euclid(col.len() as isize) as usize; + let circle_domain_offset = coset_index_to_circle_domain_index(offset, log_size); + col.set(circle_domain_offset, BaseField::one()); CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col) } diff --git a/crates/prover/src/core/air/accumulation.rs b/crates/prover/src/core/air/accumulation.rs index 8fcf57549..5ed55ac56 100644 --- a/crates/prover/src/core/air/accumulation.rs +++ b/crates/prover/src/core/air/accumulation.rs @@ -18,6 +18,7 @@ use crate::core::utils::generate_secure_powers; /// Accumulates N evaluations of u_i(P0) at a single point. /// Computes f(P0), the combined polynomial at that point. /// For n accumulated evaluations, the i'th evaluation is multiplied by alpha^(N-1-i). +#[derive(Debug, Clone, Copy)] pub struct PointEvaluationAccumulator { random_coeff: SecureField, accumulation: SecureField, diff --git a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs b/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs index 53ae956e6..a63b62503 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs @@ -18,13 +18,13 @@ pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1; /// IOP for multilinear eval at point. pub const MAX_MLE_N_VARIABLES: u32 = M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR; -/// Accumulates [`Mle`]s grouped by their number of variables. +/// Collection of [`Mle`]s grouped by their number of variables. pub struct MleCollection { mles_by_n_variables: Vec>>>, } impl MleCollection { - /// Appends an [`Mle`] to the collection. + /// Appends an [`Mle`] to the back of the collection. pub fn push(&mut self, mle: impl Into>) { let mle = mle.into(); let mles = self.mles_by_n_variables[mle.n_variables()].get_or_insert(Vec::new()); @@ -35,6 +35,7 @@ impl MleCollection { impl MleCollection { /// Performs a random linear combination of all MLEs, grouped by their number of variables. /// + /// For `n` accumulated MLEs in a group, the `i`th MLE is multiplied by `alpha^(n-1-i)`. /// MLEs are returned in ascending order by number of variables. pub fn random_linear_combine_by_n_variables( self, @@ -53,13 +54,15 @@ impl MleCollection { /// Panics if `mles` is empty or all MLEs don't have the same number of variables. fn mle_random_linear_combination( mles: Vec>, - alpha: SecureField, + random_coeff: SecureField, ) -> Mle { assert!(!mles.is_empty()); let n_variables = mles[0].n_variables(); assert!(mles.iter().all(|mle| mle.n_variables() == n_variables)); - let alpha_powers = generate_secure_powers(alpha, mles.len()).into_iter().rev(); - let mut mle_and_coeff = zip(mles, alpha_powers); + let coeff_powers = generate_secure_powers(random_coeff, mles.len()) + .into_iter() + .rev(); + let mut mle_and_coeff = zip(mles, coeff_powers); // The last value can initialize the accumulator. let (mle, coeff) = mle_and_coeff.next_back().unwrap(); diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 72dc133a5..4def98c66 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -2,16 +2,25 @@ // TODO(andrew): Remove in downstream PR. #![allow(dead_code)] -use std::array; use std::iter::zip; use itertools::{chain, zip_eq, Itertools}; use num_traits::{One, Zero}; - -use crate::constraint_framework::EvalAtRow; -use crate::core::backend::simd::column::SecureColumn; +use tracing::{span, Level}; + +use crate::constraint_framework::constant_columns::{gen_is_first, gen_is_offset}; +use crate::constraint_framework::{ + EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, TraceLocationAllocator, +}; +use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; +use crate::core::air::{Component, ComponentProver, Trace}; +use crate::core::backend::simd::column::{ + SecureColumn, VeryPackedBaseColumn, VeryPackedSecureColumnByCoords, +}; +use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS}; use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Col, Column}; use crate::core::circle::{CirclePoint, Coset}; @@ -23,20 +32,283 @@ use crate::core::fields::{Field, FieldExpOps}; use crate::core::lookups::gkr_prover::GkrOps; use crate::core::lookups::mle::Mle; use crate::core::lookups::utils::eq; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation}; +use crate::core::pcs::{TreeColumnSpan, TreeVec}; +use crate::core::poly::circle::{ + CanonicCoset, CircleEvaluation, SecureCirclePoly, SecureEvaluation, +}; +use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::BitReversedOrder; -use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; +use crate::core::utils::{self, bit_reverse_index, coset_index_to_circle_domain_index}; +use crate::core::ColumnVec; + +/// Component that carries out a univariate IOP for multilinear eval at point. +/// +/// See (Section 5.1). +#[allow(dead_code)] +pub struct MleEvalProverComponent<'twiddles, 'oracle, O: MleCoeffColumnOracle> { + /// Polynomials encoding the multilinear Lagrange basis coefficients of the MLE. + mle_coeff_column_poly: SecureCirclePoly, + /// Oracle for the polynomial encoding the multilinear Lagrange basis coefficients of the MLE. + /// + /// The oracle values should match `mle_coeff_column_poly` for any given evaluation point. The + /// polynomial is only stored directly to speed up constraint evaluation. The oracle is stored + /// to perform consistency checks with `mle_coeff_column_poly`. + mle_coeff_column_oracle: &'oracle O, + /// Multilinear evaluation point. + mle_eval_point: MleEvalPoint, + /// Equals `mle_claim / 2^mle_n_variables`. + mle_claim_shift: SecureField, + /// Commitment tree index for the trace. + interaction: usize, + /// Location in the trace for the this component. + trace_locations: TreeVec, + /// Precomputed twiddles tree. + twiddles: &'twiddles TwiddleTree, +} + +impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> MleEvalProverComponent<'twiddles, 'oracle, O> { + // TODO(andrew): Some eval points may affect completeness. Document. + pub fn generate( + location_allocator: &mut TraceLocationAllocator, + mle_coeff_column_oracle: &'oracle O, + mle_eval_point: &[SecureField], + mle: Mle, + mle_claim: SecureField, + twiddles: &'twiddles TwiddleTree, + interaction: usize, + ) -> Self { + #[cfg(test)] + assert_eq!(mle_claim, mle.eval_at_point(mle_eval_point)); + let n_variables = mle.n_variables(); + let mle_claim_shift = mle_claim / BaseField::from(1 << n_variables); + + let domain = CanonicCoset::new(n_variables as u32).circle_domain(); + let values = mle.into_evals().into_secure_column_by_coords(); + let mle_trace = SecureEvaluation::::new(domain, values); + let mle_coeff_column_poly = mle_trace.interpolate_with_twiddles(twiddles); + + let trace_structure = mle_eval_info(interaction, n_variables).mask_offsets; + let trace_locations = location_allocator.next_for_structure(&trace_structure); + + Self { + mle_coeff_column_poly, + mle_coeff_column_oracle, + mle_eval_point: MleEvalPoint::new(mle_eval_point), + mle_claim_shift, + interaction, + trace_locations, + twiddles, + } + } + + /// Size of this components trace columns. + pub fn log_size(&self) -> u32 { + self.mle_eval_point.n_variables() as u32 + } + + pub fn eval_info(&self) -> InfoEvaluator { + let n_variables = self.mle_eval_point.n_variables(); + mle_eval_info(self.interaction, n_variables) + } +} + +impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> Component + for MleEvalProverComponent<'twiddles, 'oracle, O> +{ + fn n_constraints(&self) -> usize { + self.eval_info().n_constraints + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size() + 1 + } + + fn trace_log_degree_bounds(&self) -> TreeVec> { + let log_size = self.log_size(); + let InfoEvaluator { mask_offsets, .. } = self.eval_info(); + mask_offsets.map(|tree_offsets| vec![log_size; tree_offsets.len()]) + } + + fn mask_points( + &self, + point: CirclePoint, + ) -> TreeVec>>> { + let trace_step = CanonicCoset::new(self.log_size()).step(); + let InfoEvaluator { mask_offsets, .. } = self.eval_info(); + mask_offsets.map_cols(|col_offsets| { + col_offsets + .iter() + .map(|offset| point + trace_step.mul_signed(*offset).into_ef()) + .collect() + }) + } + + fn evaluate_constraint_quotients_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + accumulator: &mut PointEvaluationAccumulator, + ) { + // Consistency check the MLE coeffs column polynomial and oracle. + let mle_coeff_col_eval = self.mle_coeff_column_poly.eval_at_point(point); + let oracle_mle_coeff_col_eval = self.mle_coeff_column_oracle.evaluate_at_point(point, mask); + assert_eq!(mle_coeff_col_eval, oracle_mle_coeff_col_eval); + + let component_mask = mask.sub_tree(&self.trace_locations); + let trace_coset = CanonicCoset::new(self.log_size()).coset; + let vanish_on_trace_eval_inv = coset_vanishing(trace_coset, point).inverse(); + let mut eval = PointEvaluator::new(component_mask, accumulator, vanish_on_trace_eval_inv); + + let carry_quotients_col_eval = eval_carry_quotient_col(&self.mle_eval_point, point); + let is_first = eval_is_first(trace_coset, point); + let is_second = eval_is_first(trace_coset, point - trace_coset.step.into_ef()); + + eval_mle_eval_constraints( + self.interaction, + &mut eval, + mle_coeff_col_eval, + &self.mle_eval_point, + self.mle_claim_shift, + carry_quotients_col_eval, + is_first, + is_second, + ) + } +} + +impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> ComponentProver + for MleEvalProverComponent<'twiddles, 'oracle, O> +{ + fn evaluate_constraint_quotients_on_domain( + &self, + trace: &Trace<'_, SimdBackend>, + accumulator: &mut DomainEvaluationAccumulator, + ) { + let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); + let trace_domain = CanonicCoset::new(self.log_size()); + + let component_trace = trace.evals.sub_tree(&self.trace_locations).map_cols(|c| *c); + + // Extend MLE coeffs column. + let span = span!(Level::INFO, "Extension").entered(); + let mle_coeffs_column_lde = VeryPackedSecureColumnByCoords::from( + self.mle_coeff_column_poly + .evaluate_with_twiddles(eval_domain, self.twiddles) + .values, + ); + let carry_quotients_column_lde = VeryPackedSecureColumnByCoords::from( + gen_carry_quotient_col(&self.mle_eval_point.p) + .interpolate_with_twiddles(self.twiddles) + .evaluate_with_twiddles(eval_domain, self.twiddles) + .values, + ); + let is_first_lde = VeryPackedBaseColumn::from( + gen_is_first::(self.log_size()) + .interpolate_with_twiddles(self.twiddles) + .evaluate_with_twiddles(eval_domain, self.twiddles) + .values, + ); + let is_second_lde = VeryPackedBaseColumn::from( + gen_is_offset::(self.log_size(), 1) + .interpolate_with_twiddles(self.twiddles) + .evaluate_with_twiddles(eval_domain, self.twiddles) + .values, + ); + span.exit(); + + // Denom inverses. + let log_expand = eval_domain.log_size() - trace_domain.log_size(); + let mut denom_inv = (0..1 << log_expand) + .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse()) + .collect_vec(); + utils::bit_reverse(&mut denom_inv); + + // Accumulator. + let [mut acc] = accumulator.columns([(eval_domain.log_size(), self.n_constraints())]); + acc.random_coeff_powers.reverse(); + let acc_col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(acc.col) }; + + let _span = span!(Level::INFO, "Constraint pointwise eval").entered(); + let n_very_packed_rows = + 1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS); + for vec_row in 0..n_very_packed_rows { + // Evaluate constrains at row. + let mut eval = SimdDomainEvaluator::new( + &component_trace, + vec_row, + &acc.random_coeff_powers, + trace_domain.log_size(), + eval_domain.log_size(), + ); + + let mle_coeffs_col_eval = unsafe { mle_coeffs_column_lde.packed_at(vec_row) }; + let carry_quotients_col_eval = unsafe { carry_quotients_column_lde.packed_at(vec_row) }; + let is_first = unsafe { *is_first_lde.data.get_unchecked(vec_row) }; + let is_second = unsafe { *is_second_lde.data.get_unchecked(vec_row) }; + eval_mle_eval_constraints( + self.interaction, + &mut eval, + mle_coeffs_col_eval, + &self.mle_eval_point, + self.mle_claim_shift, + carry_quotients_col_eval, + is_first, + is_second, + ); + + // Finalize row. + let row_res = eval.row_res; + let denom_inv = VeryPackedBaseField::broadcast( + denom_inv + [vec_row >> (trace_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)], + ); + unsafe { acc_col.set_packed(vec_row, acc_col.packed_at(vec_row) + row_res * denom_inv) } + } + } +} + +fn mle_eval_info(interaction: usize, n_variables: usize) -> InfoEvaluator { + let mut eval = InfoEvaluator::default(); + let mle_eval_point = MleEvalPoint::new(&vec![SecureField::from(2); n_variables]); + let mle_claim_shift = SecureField::zero(); + let mle_coeffs_col_eval = SecureField::zero(); + let carry_quotients_col_eval = SecureField::zero(); + let is_first = BaseField::zero(); + let is_second = BaseField::zero(); + eval_mle_eval_constraints( + interaction, + &mut eval, + mle_coeffs_col_eval, + &mle_eval_point, + mle_claim_shift, + carry_quotients_col_eval, + is_first, + is_second, + ); + eval +} + +/// Univariate polynomial oracle that encodes multilinear Lagrange basis coefficients of a MLE. +/// +/// The column should encode the MLE coefficients ordered on a circle domain. +pub trait MleCoeffColumnOracle { + fn evaluate_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + ) -> SecureField; +} /// Evaluates constraints that guarantee an MLE evaluates to a claim at a given point. /// /// `mle_coeffs_col_eval` should be the evaluation of the column containing the coefficients of the /// MLE in the multilinear Lagrange basis. `mle_claim_shift` should equal `claim / 2^N_VARIABLES`. #[allow(clippy::too_many_arguments)] -pub fn eval_mle_eval_constraints( +pub fn eval_mle_eval_constraints( interaction: usize, eval: &mut E, mle_coeffs_col_eval: E::EF, - mle_eval_point: MleEvalPoint, + mle_eval_point: &MleEvalPoint, mle_claim_shift: SecureField, carry_quotients_col_eval: E::EF, is_first: E::F, @@ -54,37 +326,49 @@ pub fn eval_mle_eval_constraints( eval_prefix_sum_constraints(interaction, eval, terms_col_eval, mle_claim_shift) } -#[derive(Debug, Clone, Copy)] -pub struct MleEvalPoint { +#[derive(Debug, Clone)] +pub struct MleEvalPoint { // Equals `eq({0}^|p|, p)`. eq_0_p: SecureField, // Equals `eq({1}^|p|, p)`. eq_1_p: SecureField, // Index `i` stores `eq(({1}^|i|, 0), p[0..i+1]) / eq(({0}^|i|, 1), p[0..i+1])`. - eq_carry_quotients: [SecureField; N_VARIABLES], + eq_carry_quotients: Vec, // Point `p`. - p: [SecureField; N_VARIABLES], + p: Vec, } -impl MleEvalPoint { +impl MleEvalPoint { /// Creates new metadata from point `p`. - pub fn new(p: [SecureField; N_VARIABLES]) -> Self { + /// + /// # Panics + /// + /// Panics if the point is empty. + pub fn new(p: &[SecureField]) -> Self { + assert!(!p.is_empty()); + let n_variables = p.len(); let zero = SecureField::zero(); let one = SecureField::one(); Self { - eq_0_p: eq(&[zero; N_VARIABLES], &p), - eq_1_p: eq(&[one; N_VARIABLES], &p), - eq_carry_quotients: array::from_fn(|i| { - let mut numer_assignment = vec![one; i + 1]; - numer_assignment[i] = zero; - let mut denom_assignment = vec![zero; i + 1]; - denom_assignment[i] = one; - eq(&numer_assignment, &p[..i + 1]) / eq(&denom_assignment, &p[..i + 1]) - }), - p, + eq_0_p: eq(&vec![zero; n_variables], p), + eq_1_p: eq(&vec![one; n_variables], p), + eq_carry_quotients: (0..n_variables) + .map(|i| { + let mut numer_assignment = vec![one; i + 1]; + numer_assignment[i] = zero; + let mut denom_assignment = vec![zero; i + 1]; + denom_assignment[i] = one; + eq(&numer_assignment, &p[..i + 1]) / eq(&denom_assignment, &p[..i + 1]) + }) + .collect(), + p: p.to_vec(), } } + + pub fn n_variables(&self) -> usize { + self.p.len() + } } /// Evaluates EqEvals constraints on a column. @@ -95,10 +379,10 @@ impl MleEvalPoint { /// evaluates constraints that guarantee: `c(D[b0, b1, ...]) = eq((b0, b1, ...), (r0, r1, ...))`. /// /// See (Section 5.1). -fn eval_eq_constraints( +fn eval_eq_constraints( eq_interaction: usize, eval: &mut E, - mle_eval_point: MleEvalPoint, + mle_eval_point: &MleEvalPoint, carry_quotients_col_eval: E::EF, is_first: E::F, is_second: E::F, @@ -179,14 +463,15 @@ pub fn build_trace( /// `c(-C[i]) = c(-C[i + 1]) * q(-C[i])`. /// /// [`CircleDomain`]: crate::core::poly::circle::CircleDomain -fn gen_carry_quotient_col( - eval_point: &[SecureField; N_VARIABLES], +fn gen_carry_quotient_col( + eval_point: &[SecureField], ) -> SecureEvaluation { - let mle_eval_point = MleEvalPoint::new(*eval_point); + assert!(!eval_point.is_empty()); + let mle_eval_point = MleEvalPoint::new(eval_point); let (half_coset0_carry_quotients, half_coset1_carry_quotients) = gen_half_coset_carry_quotients(&mle_eval_point); - let log_size = N_VARIABLES as u32; + let log_size = mle_eval_point.n_variables() as u32; let size = 1 << log_size; let half_coset_size = size / 2; let mut col = SecureColumnByCoords::::zeros(size); @@ -216,11 +501,9 @@ fn gen_carry_quotient_col( // TODO(andrew): Optimize further. Inline `eval_step_selector` and get runtime down to // O(N_VARIABLES) vs current O(N_VARIABLES^2). Can also use vanishing evals to compute // half_coset0_last half_coset1_first. -fn eval_carry_quotient_col( - eval_point: &MleEvalPoint, - p: CirclePoint, -) -> SecureField { - let log_size = N_VARIABLES as u32; +fn eval_carry_quotient_col(eval_point: &MleEvalPoint, p: CirclePoint) -> SecureField { + let n_variables = eval_point.n_variables(); + let log_size = n_variables as u32; let coset = CanonicCoset::new(log_size).coset(); let (half_coset0_carry_quotients, half_coset1_carry_quotients) = @@ -228,7 +511,7 @@ fn eval_carry_quotient_col( let mut eval = SecureField::zero(); - for variable_i in 0..N_VARIABLES.saturating_sub(1) { + for variable_i in 0..n_variables.saturating_sub(1) { let log_step = variable_i as u32 + 2; let offset = (1 << (log_step - 1)) - 2; let half_coset0_selector = eval_step_selector_with_offset(coset, offset, log_step, p); @@ -292,14 +575,17 @@ fn eval_is_first(coset: Coset, p: CirclePoint) -> SecureField { } /// Output of the form: `(half_coset0_carry_quotients, half_coset1_carry_quotients)`. -fn gen_half_coset_carry_quotients( - eval_point: &MleEvalPoint, -) -> ([SecureField; N_VARIABLES], [SecureField; N_VARIABLES]) { +fn gen_half_coset_carry_quotients( + eval_point: &MleEvalPoint, +) -> (Vec, Vec) { let last_variable = *eval_point.p.last().unwrap(); - let mut half_coset0_carry_quotients = eval_point.eq_carry_quotients; + let mut half_coset0_carry_quotients = eval_point.eq_carry_quotients.clone(); *half_coset0_carry_quotients.last_mut().unwrap() *= eq(&[SecureField::one()], &[last_variable]) / eq(&[SecureField::zero()], &[last_variable]); - let half_coset1_carry_quotients = half_coset0_carry_quotients.map(|v| v.inverse()); + let half_coset1_carry_quotients = half_coset0_carry_quotients + .iter() + .map(|v| v.inverse()) + .collect(); (half_coset0_carry_quotients, half_coset1_carry_quotients) } @@ -321,6 +607,7 @@ mod tests { use std::iter::{repeat, zip}; use itertools::{chain, Itertools}; + use mle_coeff_column::{MleCoeffColumnComponent, MleCoeffColumnEval}; use num_traits::One; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; @@ -330,23 +617,90 @@ mod tests { eval_prefix_sum_constraints, gen_carry_quotient_col, MleEvalPoint, }; use crate::constraint_framework::constant_columns::{gen_is_first, gen_is_step_with_offset}; - use crate::constraint_framework::{assert_constraints, EvalAtRow}; + use crate::constraint_framework::{assert_constraints, EvalAtRow, TraceLocationAllocator}; + use crate::core::air::{Component, ComponentProver, Components}; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; + use crate::core::channel::Blake2sChannel; use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; use crate::core::lookups::mle::Mle; - use crate::core::pcs::TreeVec; - use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps, SecureEvaluation}; + use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; + use crate::core::prover::{prove, verify, VerificationError}; use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; + use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; + use crate::examples::xor::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR; use crate::examples::xor::gkr_lookups::mle_eval::{ - build_trace, eval_step_selector_with_offset, + build_trace, eval_step_selector_with_offset, MleEvalProverComponent, }; + #[test] + fn mle_eval_prover_component() -> Result<(), VerificationError> { + const N_VARIABLES: usize = 8; + const COEFFS_COL_TRACE: usize = 0; + const MLE_EVAL_TRACE: usize = 1; + const LOG_EXPAND: u32 = 1; + // Create the test MLE. + let mut rng = SmallRng::seed_from_u64(0); + let log_size = N_VARIABLES as u32; + let size = 1 << log_size; + let mle_coeffs = (0..size).map(|_| rng.gen::()).collect(); + let mle = Mle::::new(mle_coeffs); + let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let claim = mle.eval_at_point(&eval_point); + // Setup protocol. + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(log_size + LOG_EXPAND + MIN_LOG_BLOWUP_FACTOR) + .circle_domain() + .half_coset, + ); + let config = PcsConfig::default(); + let commitment_scheme = &mut CommitmentSchemeProver::new(config, &twiddles); + let channel = &mut Blake2sChannel::default(); + // Build trace. + // 1. MLE coeffs trace. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(mle_coeff_column::build_trace(&mle)); + tree_builder.commit(channel); + // 2. MLE eval trace (eq evals + prefix sum). + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(build_trace(&mle, &eval_point, claim)); + tree_builder.commit(channel); + // Create components. + let trace_location_allocator = &mut TraceLocationAllocator::default(); + let mle_coeffs_col_component = MleCoeffColumnComponent::new( + trace_location_allocator, + MleCoeffColumnEval::new(COEFFS_COL_TRACE, mle.n_variables()), + ); + let mle_eval_component = MleEvalProverComponent::generate( + trace_location_allocator, + &mle_coeffs_col_component, + &eval_point, + mle, + claim, + &twiddles, + MLE_EVAL_TRACE, + ); + let components: &[&dyn ComponentProver] = + &[&mle_coeffs_col_component, &mle_eval_component]; + // Generate proof. + let proof = prove(components, channel, commitment_scheme).unwrap(); + + // Verify. + let components = Components(components.iter().map(|&c| c as &dyn Component).collect()); + let log_sizes = components.column_log_sizes(); + let channel = &mut Blake2sChannel::default(); + let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); + commitment_scheme.commit(proof.commitments[0], &log_sizes[0], channel); + commitment_scheme.commit(proof.commitments[1], &log_sizes[1], channel); + verify(&components.0, channel, commitment_scheme, proof) + } + #[test] fn test_mle_eval_constraints_with_log_size_5() { const N_VARIABLES: usize = 5; @@ -360,9 +714,9 @@ mod tests { let mle = Mle::::new(mle_coeffs); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); let claim = mle.eval_at_point(&eval_point); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let mle_eval_trace = build_trace(&mle, &eval_point, claim); - let mle_coeffs_col_trace = build_mle_coeffs_trace(mle); + let mle_coeffs_col_trace = mle_coeff_column::build_trace(&mle); let claim_shift = claim / BaseField::from(size); let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); let is_first_col = [gen_is_first(log_size)]; @@ -379,7 +733,7 @@ mod tests { MLE_EVAL_TRACE, &mut eval, mle_coeff_col_eval, - mle_eval_point, + &mle_eval_point, claim_shift, carry_quotients_col_eval, is_first_eval, @@ -397,14 +751,14 @@ mod tests { let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); let is_first_col = [gen_is_first(N_VARIABLES as u32)]; let aux_trace = chain![carry_quotients_col, is_first_col].collect(); let traces = TreeVec::new(vec![trace, aux_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(eval_point.len() as u32); + let trace_domain = CanonicCoset::new(N_VARIABLES as u32); assert_constraints(&trace_polys, trace_domain, |mut eval| { let [carry_quotients_col_eval] = eval.next_extension_interaction_mask(AUX_TRACE, [0]); @@ -412,7 +766,7 @@ mod tests { eval_eq_constraints( EQ_EVAL_TRACE, &mut eval, - mle_eval_point, + &mle_eval_point, carry_quotients_col_eval, is_first, is_second, @@ -428,14 +782,14 @@ mod tests { let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); let is_first_col = [gen_is_first(N_VARIABLES as u32)]; let aux_trace = chain![carry_quotients_col, is_first_col].collect(); let traces = TreeVec::new(vec![trace, aux_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(eval_point.len() as u32); + let trace_domain = CanonicCoset::new(N_VARIABLES as u32); assert_constraints(&trace_polys, trace_domain, |mut eval| { let [carry_quotients_col_eval] = eval.next_extension_interaction_mask(AUX_TRACE, [0]); @@ -443,7 +797,7 @@ mod tests { eval_eq_constraints( EQ_EVAL_TRACE, &mut eval, - mle_eval_point, + &mle_eval_point, carry_quotients_col_eval, is_first, is_second, @@ -459,14 +813,14 @@ mod tests { let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); let is_first_col = [gen_is_first(N_VARIABLES as u32)]; let aux_trace = chain![carry_quotients_col, is_first_col].collect(); let traces = TreeVec::new(vec![trace, aux_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(eval_point.len() as u32); + let trace_domain = CanonicCoset::new(N_VARIABLES as u32); assert_constraints(&trace_polys, trace_domain, |mut eval| { let [carry_quotients_col_eval] = eval.next_extension_interaction_mask(AUX_TRACE, [0]); @@ -474,7 +828,7 @@ mod tests { eval_eq_constraints( EQ_EVAL_TRACE, &mut eval, - mle_eval_point, + &mle_eval_point, carry_quotients_col_eval, is_first, is_second, @@ -519,7 +873,7 @@ mod tests { const N_VARIABLES: usize = 5; let mut rng = SmallRng::seed_from_u64(0); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_point = MleEvalPoint::new(&eval_point); let col_eval = gen_carry_quotient_col(&eval_point); let twiddles = SimdBackend::precompute_twiddles(col_eval.domain.half_coset); let col_poly = col_eval.interpolate_with_twiddles(&twiddles); @@ -570,26 +924,99 @@ mod tests { .collect() } - /// Generates a trace. - /// - /// Trace structure: - /// - /// ```text - /// ----------------------------- - /// | MLE coeffs col | - /// ----------------------------- - /// | c0 | c1 | c2 | c3 | - /// ----------------------------- - /// ``` - fn build_mle_coeffs_trace( - mle: Mle, - ) -> Vec> { - let log_size = mle.n_variables() as u32; - let trace_domain = CanonicCoset::new(log_size).circle_domain(); - let mle_coeffs_col_by_coords = mle.into_evals().into_secure_column_by_coords(); - SecureEvaluation::new(trace_domain, mle_coeffs_col_by_coords) - .into_coordinate_evals() - .into_iter() - .collect() + mod mle_coeff_column { + use num_traits::One; + + use crate::constraint_framework::{ + EvalAtRow, FrameworkComponent, FrameworkEval, PointEvaluator, + }; + use crate::core::air::accumulation::PointEvaluationAccumulator; + use crate::core::backend::simd::SimdBackend; + use crate::core::circle::CirclePoint; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::lookups::mle::Mle; + use crate::core::pcs::TreeVec; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation}; + use crate::core::poly::BitReversedOrder; + use crate::core::ColumnVec; + use crate::examples::xor::gkr_lookups::mle_eval::MleCoeffColumnOracle; + + pub type MleCoeffColumnComponent = FrameworkComponent; + + pub struct MleCoeffColumnEval { + interaction: usize, + n_variables: usize, + } + + impl MleCoeffColumnEval { + pub fn new(interaction: usize, n_variables: usize) -> Self { + Self { + interaction, + n_variables, + } + } + } + + impl FrameworkEval for MleCoeffColumnEval { + fn log_size(&self) -> u32 { + self.n_variables as u32 + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size() + } + + fn evaluate(&self, mut eval: E) -> E { + let _ = eval_mle_coeff_col(self.interaction, &mut eval); + eval + } + } + + impl MleCoeffColumnOracle for MleCoeffColumnComponent { + fn evaluate_at_point( + &self, + _point: CirclePoint, + mask: &TreeVec>>, + ) -> SecureField { + // Create dummy point evaluator just to extract the value we need from the mask + let mut accumulator = PointEvaluationAccumulator::new(SecureField::one()); + let mut eval = PointEvaluator::new( + mask.sub_tree(self.trace_locations()), + &mut accumulator, + SecureField::one(), + ); + + eval_mle_coeff_col(self.interaction, &mut eval) + } + } + + fn eval_mle_coeff_col(interaction: usize, eval: &mut E) -> E::EF { + let [mle_coeff_col_eval] = eval.next_extension_interaction_mask(interaction, [0]); + mle_coeff_col_eval + } + + /// Generates a trace. + /// + /// Trace structure: + /// + /// ```text + /// ----------------------------- + /// | MLE coeffs col | + /// ----------------------------- + /// | c0 | c1 | c2 | c3 | + /// ----------------------------- + /// ``` + pub fn build_trace( + mle: &Mle, + ) -> Vec> { + let log_size = mle.n_variables() as u32; + let trace_domain = CanonicCoset::new(log_size).circle_domain(); + let mle_coeffs_col_by_coords = mle.clone().into_evals().into_secure_column_by_coords(); + SecureEvaluation::new(trace_domain, mle_coeffs_col_by_coords) + .into_coordinate_evals() + .into_iter() + .collect() + } } }