Skip to content

Commit

Permalink
Create MLE eval component
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Aug 28, 2024
1 parent e67c1c5 commit 4953abd
Show file tree
Hide file tree
Showing 5 changed files with 543 additions and 100 deletions.
32 changes: 17 additions & 15 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub struct TraceLocationAllocator {
}

impl TraceLocationAllocator {
fn next_for_structure<T>(
pub fn next_for_structure<T>(
&mut self,
structure: &TreeVec<ColumnVec<T>>,
) -> TreeVec<TreeColumnSpan> {
Expand Down Expand Up @@ -75,14 +75,18 @@ pub struct FrameworkComponent<C: FrameworkEval> {
}

impl<E: FrameworkEval> FrameworkComponent<E> {
pub fn new(provider: &mut TraceLocationAllocator, eval: E) -> Self {
pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E) -> Self {
let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets;
let trace_locations = provider.next_for_structure(&eval_tree_structure);
let trace_locations = location_allocator.next_for_structure(&eval_tree_structure);
Self {
eval,
trace_locations,
}
}

pub fn trace_locations(&self) -> &[TreeColumnSpan] {
&self.trace_locations
}
}

impl<E: FrameworkEval> Component for FrameworkComponent<E> {
Expand All @@ -95,26 +99,20 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
}

fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::new(
self.eval
.evaluate(InfoEvaluator::default())
.mask_offsets
.iter()
.map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()])
.collect(),
)
let InfoEvaluator { mask_offsets, .. } = self.eval.evaluate(InfoEvaluator::default());
mask_offsets.map(|tree_offsets| vec![self.eval.log_size(); tree_offsets.len()])
}

fn mask_points(
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
let info = self.eval.evaluate(InfoEvaluator::default());
let trace_step = CanonicCoset::new(self.eval.log_size()).step();
info.mask_offsets.map_cols(|col_mask| {
col_mask
let InfoEvaluator { mask_offsets, .. } = self.eval.evaluate(InfoEvaluator::default());
mask_offsets.map_cols(|col_offsets| {
col_offsets
.iter()
.map(|off| point + trace_step.mul_signed(*off).into_ef())
.map(|offset| point + trace_step.mul_signed(*offset).into_ef())
.collect()
})
}
Expand All @@ -139,6 +137,10 @@ impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
trace: &Trace<'_, SimdBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<SimdBackend>,
) {
if self.n_constraints() == 0 {
return;
}

let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain();
let trace_domain = CanonicCoset::new(self.eval.log_size());

Expand Down
12 changes: 11 additions & 1 deletion crates/prover/src/constraint_framework/constant_columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,18 @@ use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};

/// Generates a column with a single one at the first position, and zeros elsewhere.
pub fn gen_is_first<B: Backend>(log_size: u32) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
gen_is_offset(log_size, 0)
}

/// Generates a column with a single one at the `offset`, and zeros elsewhere.
pub fn gen_is_offset<B: Backend>(
log_size: u32,
offset: isize,
) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
let mut col = Col::<B, BaseField>::zeros(1 << log_size);
col.set(0, BaseField::one());
let offset = offset.rem_euclid(col.len() as isize) as usize;
let circle_domain_offset = coset_index_to_circle_domain_index(offset, log_size);
col.set(circle_domain_offset, BaseField::one());
CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col)
}

Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/core/air/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::core::utils::generate_secure_powers;
/// Accumulates N evaluations of u_i(P0) at a single point.
/// Computes f(P0), the combined polynomial at that point.
/// For n accumulated evaluations, the i'th evaluation is multiplied by alpha^(N-1-i).
#[derive(Debug, Clone, Copy)]
pub struct PointEvaluationAccumulator {
random_coeff: SecureField,
accumulation: SecureField,
Expand Down
13 changes: 8 additions & 5 deletions crates/prover/src/examples/xor/gkr_lookups/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1;
/// IOP for multilinear eval at point.
pub const MAX_MLE_N_VARIABLES: u32 = M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR;

/// Accumulates [`Mle`]s grouped by their number of variables.
/// Collection of [`Mle`]s grouped by their number of variables.
pub struct MleCollection<B: Backend> {
mles_by_n_variables: Vec<Option<Vec<DynMle<B>>>>,
}

impl<B: Backend> MleCollection<B> {
/// Appends an [`Mle`] to the collection.
/// Appends an [`Mle`] to the back of the collection.
pub fn push(&mut self, mle: impl Into<DynMle<B>>) {
let mle = mle.into();
let mles = self.mles_by_n_variables[mle.n_variables()].get_or_insert(Vec::new());
Expand All @@ -35,6 +35,7 @@ impl<B: Backend> MleCollection<B> {
impl MleCollection<SimdBackend> {
/// Performs a random linear combination of all MLEs, grouped by their number of variables.
///
/// For `n` accumulated MLEs in a group, the `i`th MLE is multiplied by `alpha^(n-1-i)`.
/// MLEs are returned in ascending order by number of variables.
pub fn random_linear_combine_by_n_variables(
self,
Expand All @@ -53,13 +54,15 @@ impl MleCollection<SimdBackend> {
/// Panics if `mles` is empty or all MLEs don't have the same number of variables.
fn mle_random_linear_combination(
mles: Vec<DynMle<SimdBackend>>,
alpha: SecureField,
random_coeff: SecureField,
) -> Mle<SimdBackend, SecureField> {
assert!(!mles.is_empty());
let n_variables = mles[0].n_variables();
assert!(mles.iter().all(|mle| mle.n_variables() == n_variables));
let alpha_powers = generate_secure_powers(alpha, mles.len()).into_iter().rev();
let mut mle_and_coeff = zip(mles, alpha_powers);
let coeff_powers = generate_secure_powers(random_coeff, mles.len())
.into_iter()
.rev();
let mut mle_and_coeff = zip(mles, coeff_powers);

// The last value can initialize the accumulator.
let (mle, coeff) = mle_and_coeff.next_back().unwrap();
Expand Down
Loading

0 comments on commit 4953abd

Please sign in to comment.