From fb1df16064e441c1dfd0e27e5f04772f8e1538fe Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Sun, 25 Aug 2024 13:03:34 -0400 Subject: [PATCH] Create blake component that uses GKR for lookups --- .../prover/src/constraint_framework/logup.rs | 2 +- .../prover/src/core/lookups/gkr_verifier.rs | 6 + crates/prover/src/examples/blake/air.rs | 31 +-- crates/prover/src/examples/blake/mod.rs | 44 +-- crates/prover/src/examples/blake/round/gen.rs | 4 +- crates/prover/src/examples/blake/round/mod.rs | 2 +- .../src/examples/blake/scheduler/gen.rs | 8 +- .../src/examples/blake/scheduler/mod.rs | 2 +- .../src/examples/blake/xor_table/gen.rs | 2 +- .../src/examples/blake/xor_table/mod.rs | 6 +- crates/prover/src/examples/blake_gkr/air.rs | 253 ++++++++++++++++++ .../gkr_lookups/accumulation.rs | 2 +- .../gkr_lookups/mle_eval.rs | 9 +- .../{xor => blake_gkr}/gkr_lookups/mod.rs | 0 crates/prover/src/examples/blake_gkr/mod.rs | 5 + crates/prover/src/examples/blake_gkr/round.rs | 49 ++++ .../src/examples/blake_gkr/scheduler.rs | 56 ++++ .../src/examples/blake_gkr/xor_table.rs | 65 +++++ crates/prover/src/examples/mod.rs | 2 +- crates/prover/src/examples/xor/mod.rs | 1 - 20 files changed, 489 insertions(+), 60 deletions(-) create mode 100644 crates/prover/src/examples/blake_gkr/air.rs rename crates/prover/src/examples/{xor => blake_gkr}/gkr_lookups/accumulation.rs (98%) rename crates/prover/src/examples/{xor => blake_gkr}/gkr_lookups/mle_eval.rs (99%) rename crates/prover/src/examples/{xor => blake_gkr}/gkr_lookups/mod.rs (100%) create mode 100644 crates/prover/src/examples/blake_gkr/mod.rs create mode 100644 crates/prover/src/examples/blake_gkr/round.rs create mode 100644 crates/prover/src/examples/blake_gkr/scheduler.rs create mode 100644 crates/prover/src/examples/blake_gkr/xor_table.rs delete mode 100644 crates/prover/src/examples/xor/mod.rs diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index b9c7b0866..1aeef7730 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -140,7 +140,7 @@ impl LookupElements { } pub fn combine(&self, values: &[F]) -> EF where - EF: Copy + Zero + From + From + Mul + Sub, + EF: Copy + Zero + From + From + Mul + Sub, { EF::from(values[0]) + values[1..] diff --git a/crates/prover/src/core/lookups/gkr_verifier.rs b/crates/prover/src/core/lookups/gkr_verifier.rs index b65ceb162..7fdbab4c9 100644 --- a/crates/prover/src/core/lookups/gkr_verifier.rs +++ b/crates/prover/src/core/lookups/gkr_verifier.rs @@ -168,6 +168,12 @@ pub struct GkrArtifact { pub n_variables_by_instance: Vec, } +impl GkrArtifact { + pub fn ood_point(&self, instance_n_variables: usize) -> &[SecureField] { + &self.ood_point[self.ood_point.len() - instance_n_variables..] + } +} + /// Defines how a circuit operates locally on two input rows to produce a single output row. /// This local 2-to-1 constraint is what gives the whole circuit its "binary tree" structure. /// diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index 58bb8633b..651f44753 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -61,9 +61,9 @@ impl BlakeStatement0 { } pub struct AllElements { - blake_elements: BlakeElements, - round_elements: RoundElements, - xor_elements: BlakeXorElements, + pub blake_elements: BlakeElements, + pub round_elements: RoundElements, + pub xor_elements: BlakeXorElements, } impl AllElements { pub fn draw(channel: &mut impl Channel) -> Self { @@ -223,7 +223,7 @@ where { assert!(log_size >= LOG_N_LANES); assert_eq!( - ROUND_LOG_SPLIT.map(|x| (1 << x)).into_iter().sum::() as usize, + ROUND_LOG_SPLIT.map(|x| 1 << x).iter().sum::(), N_ROUNDS ); @@ -240,7 +240,7 @@ where span.exit(); // Prepare inputs. - let blake_inputs = (0..(1 << (log_size - LOG_N_LANES))) + let blake_inputs = (0..1 << (log_size - LOG_N_LANES)) .map(|i| { let v = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j) as u32)); 16]; let m = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j + 1) as u32)); 16]; @@ -282,18 +282,15 @@ where // Trace commitment. let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals( - chain![ - scheduler_trace, - round_traces.into_iter().flatten(), - xor_trace12, - xor_trace9, - xor_trace8, - xor_trace7, - xor_trace4, - ] - .collect_vec(), - ); + tree_builder.extend_evals(chain![ + scheduler_trace, + round_traces.into_iter().flatten(), + xor_trace12, + xor_trace9, + xor_trace8, + xor_trace7, + xor_trace4, + ]); tree_builder.commit(channel); span.exit(); diff --git a/crates/prover/src/examples/blake/mod.rs b/crates/prover/src/examples/blake/mod.rs index 6fbe6d81b..686f4733b 100644 --- a/crates/prover/src/examples/blake/mod.rs +++ b/crates/prover/src/examples/blake/mod.rs @@ -12,28 +12,28 @@ use crate::core::channel::Channel; use crate::core::fields::m31::BaseField; use crate::core::fields::FieldExpOps; -mod air; -mod round; -mod scheduler; -mod xor_table; +pub mod air; +pub mod round; +pub mod scheduler; +pub mod xor_table; -const STATE_SIZE: usize = 16; -const MESSAGE_SIZE: usize = 16; -const N_FELTS_IN_U32: usize = 2; -const N_ROUND_INPUT_FELTS: usize = (STATE_SIZE + STATE_SIZE + MESSAGE_SIZE) * N_FELTS_IN_U32; +pub const STATE_SIZE: usize = 16; +pub const MESSAGE_SIZE: usize = 16; +pub const N_FELTS_IN_U32: usize = 2; +pub const N_ROUND_INPUT_FELTS: usize = (STATE_SIZE + STATE_SIZE + MESSAGE_SIZE) * N_FELTS_IN_U32; // Parameters for Blake2s. Change these for blake3. -const N_ROUNDS: usize = 10; +pub const N_ROUNDS: usize = 10; /// A splitting N_ROUNDS into several powers of 2. -const ROUND_LOG_SPLIT: [u32; 2] = [3, 1]; +pub const ROUND_LOG_SPLIT: [u32; 2] = [3, 1]; #[derive(Default)] -struct XorAccums { - xor12: XorAccumulator<12, 4>, - xor9: XorAccumulator<9, 2>, - xor8: XorAccumulator<8, 2>, - xor7: XorAccumulator<7, 2>, - xor4: XorAccumulator<4, 0>, +pub struct XorAccums { + pub xor12: XorAccumulator<12, 4>, + pub xor9: XorAccumulator<9, 2>, + pub xor8: XorAccumulator<8, 2>, + pub xor7: XorAccumulator<7, 2>, + pub xor4: XorAccumulator<4, 0>, } impl XorAccums { fn add_input(&mut self, w: u32, a: u32x16, b: u32x16) { @@ -50,11 +50,11 @@ impl XorAccums { #[derive(Clone)] pub struct BlakeXorElements { - xor12: XorElements, - xor9: XorElements, - xor8: XorElements, - xor7: XorElements, - xor4: XorElements, + pub xor12: XorElements, + pub xor9: XorElements, + pub xor8: XorElements, + pub xor7: XorElements, + pub xor4: XorElements, } impl BlakeXorElements { fn draw(channel: &mut impl Channel) -> Self { @@ -75,7 +75,7 @@ impl BlakeXorElements { xor4: XorElements::dummy(), } } - fn get(&self, w: u32) -> &XorElements { + pub fn get(&self, w: u32) -> &XorElements { match w { 12 => &self.xor12, 9 => &self.xor9, diff --git a/crates/prover/src/examples/blake/round/gen.rs b/crates/prover/src/examples/blake/round/gen.rs index 9bddcdd40..f6adf58c2 100644 --- a/crates/prover/src/examples/blake/round/gen.rs +++ b/crates/prover/src/examples/blake/round/gen.rs @@ -23,9 +23,9 @@ use crate::examples::blake::{to_felts, XorAccums, N_ROUND_INPUT_FELTS, STATE_SIZ pub struct BlakeRoundLookupData { /// A vector of (w, [a_col, b_col, c_col]) for each xor lookup. /// w is the xor width. c_col is the xor col of a_col and b_col. - xor_lookups: Vec<(u32, [BaseColumn; 3])>, + pub xor_lookups: Vec<(u32, [BaseColumn; 3])>, /// A column of round lookup values (v_in, v_out, m). - round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], + pub round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], } pub struct TraceGenerator { diff --git a/crates/prover/src/examples/blake/round/mod.rs b/crates/prover/src/examples/blake/round/mod.rs index cf8311339..cca09afef 100644 --- a/crates/prover/src/examples/blake/round/mod.rs +++ b/crates/prover/src/examples/blake/round/mod.rs @@ -1,7 +1,7 @@ mod constraints; mod gen; -pub use gen::{generate_interaction_trace, generate_trace, BlakeRoundInput}; +pub use gen::{generate_interaction_trace, generate_trace, BlakeRoundInput, BlakeRoundLookupData}; use num_traits::Zero; use super::{BlakeXorElements, N_ROUND_INPUT_FELTS}; diff --git a/crates/prover/src/examples/blake/scheduler/gen.rs b/crates/prover/src/examples/blake/scheduler/gen.rs index cd6a99b2f..ae3569ed0 100644 --- a/crates/prover/src/examples/blake/scheduler/gen.rs +++ b/crates/prover/src/examples/blake/scheduler/gen.rs @@ -58,7 +58,7 @@ pub fn gen_trace( .map(|_| unsafe { BaseColumn::uninitialized(1 << log_size) }) .collect_vec(); - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let mut col_index = 0; let mut write_u32_array = |x: [u32x16; STATE_SIZE], col_index: &mut usize| { @@ -125,11 +125,11 @@ pub fn gen_interaction_trace( let mut logup_gen = LogupTraceGenerator::new(log_size); - for [l0, l1] in lookup_data.round_lookups.array_chunks::<2>() { + for [l0, l1] in lookup_data.round_lookups.array_chunks() { let mut col_gen = logup_gen.new_col(); #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let p0: PackedSecureField = round_lookup_elements.combine(&l0.each_ref().map(|l| l.data[vec_row])); let p1: PackedSecureField = @@ -145,7 +145,7 @@ pub fn gen_interaction_trace( // with the entire blake lookup. let mut col_gen = logup_gen.new_col(); #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let p_blake: PackedSecureField = blake_lookup_elements.combine( &lookup_data .blake_lookups diff --git a/crates/prover/src/examples/blake/scheduler/mod.rs b/crates/prover/src/examples/blake/scheduler/mod.rs index e8a8c32f3..d758d41de 100644 --- a/crates/prover/src/examples/blake/scheduler/mod.rs +++ b/crates/prover/src/examples/blake/scheduler/mod.rs @@ -2,7 +2,7 @@ mod constraints; mod gen; use constraints::eval_blake_scheduler_constraints; -pub use gen::{gen_interaction_trace, gen_trace, BlakeInput}; +pub use gen::{gen_interaction_trace, gen_trace, BlakeInput, BlakeSchedulerLookupData}; use num_traits::Zero; use super::round::RoundElements; diff --git a/crates/prover/src/examples/blake/xor_table/gen.rs b/crates/prover/src/examples/blake/xor_table/gen.rs index 195a6ca46..46309e640 100644 --- a/crates/prover/src/examples/blake/xor_table/gen.rs +++ b/crates/prover/src/examples/blake/xor_table/gen.rs @@ -74,7 +74,7 @@ pub fn generate_interaction_trace( // Each column has 2^(2*LIMB_BITS) rows, packed in N_LANES. #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (column_bits::() - LOG_N_LANES)) { + for vec_row in 0..1 << (column_bits::() - LOG_N_LANES) { // vec_row is LIMB_BITS of al and LIMB_BITS - LOG_N_LANES of bl. // Extract al, blh from vec_row. let al = vec_row >> (limb_bits - LOG_N_LANES); diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs index bd74ea040..9f344b18d 100644 --- a/crates/prover/src/examples/blake/xor_table/mod.rs +++ b/crates/prover/src/examples/blake/xor_table/mod.rs @@ -17,7 +17,9 @@ use std::simd::u32x16; use itertools::Itertools; use num_traits::Zero; -pub use r#gen::{generate_constant_trace, generate_interaction_trace, generate_trace}; +pub use r#gen::{ + generate_constant_trace, generate_interaction_trace, generate_trace, XorTableLookupData, +}; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; @@ -37,7 +39,7 @@ pub fn trace_sizes() -> TreeVec()) } -const fn limb_bits() -> u32 { +pub const fn limb_bits() -> u32 { ELEM_BITS - EXPAND_BITS } pub const fn column_bits() -> u32 { diff --git a/crates/prover/src/examples/blake_gkr/air.rs b/crates/prover/src/examples/blake_gkr/air.rs new file mode 100644 index 000000000..3ad707a37 --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/air.rs @@ -0,0 +1,253 @@ +use std::array; +use std::collections::BTreeMap; +use std::iter::zip; +use std::simd::u32x16; + +use itertools::{chain, multiunzip, Itertools}; +use tracing::{span, Level}; + +use crate::core::air::accumulation::PointEvaluationAccumulator; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::BackendForChannel; +use crate::core::channel::{Channel, MerkleChannel}; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::gkr_prover::prove_batch; +use crate::core::lookups::gkr_verifier::{GkrArtifact, GkrBatchProof}; +use crate::core::pcs::{CommitmentSchemeProver, PcsConfig}; +use crate::core::poly::circle::{CanonicCoset, PolyOps}; +use crate::core::prover::StarkProof; +use crate::core::vcs::ops::MerkleHasher; +use crate::examples::blake::air::AllElements; +use crate::examples::blake::scheduler::{self as air_scheduler, BlakeInput}; +use crate::examples::blake::{ + round as air_round, xor_table as air_xor_table, XorAccums, N_ROUNDS, ROUND_LOG_SPLIT, +}; +use crate::examples::blake_gkr::gkr_lookups::accumulation::MleCollection; +use crate::examples::blake_gkr::{round, scheduler, xor_table}; + +pub struct BlakeClaim { + log_size: u32, +} + +impl BlakeClaim { + fn mix_into(&self, channel: &mut impl Channel) { + // TODO(spapini): Do this better. + channel.mix_nonce(self.log_size as u64); + } +} + +pub struct BlakeProof { + pub claim: BlakeClaim, + pub gkr_proof: GkrBatchProof, + pub stark_proof: StarkProof, +} + +pub struct BlakeComponents { + // scheduler_component: BlakeSchedulerComponent, + // round_components: Vec, + // xor12: XorTableComponent<12, 4>, + // xor9: XorTableComponent<9, 2>, + // xor8: XorTableComponent<8, 2>, + // xor7: XorTableComponent<7, 2>, + // xor4: XorTableComponent<4, 2>, +} + +pub fn prove_blake(log_size: u32, config: PcsConfig) -> BlakeProof +where + SimdBackend: BackendForChannel, +{ + assert!(log_size >= LOG_N_LANES); + assert_eq!( + ROUND_LOG_SPLIT.map(|x| 1 << x).iter().sum::(), + N_ROUNDS + ); + + // Precompute twiddles. + let span = span!(Level::INFO, "Precompute twiddles").entered(); + const XOR_TABLE_MAX_LOG_SIZE: u32 = 16; + let max_log_size = + (log_size + *ROUND_LOG_SPLIT.iter().max().unwrap()).max(XOR_TABLE_MAX_LOG_SIZE); + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(max_log_size + 1 + config.fri_config.log_blowup_factor) + .circle_domain() + .half_coset, + ); + span.exit(); + + // Prepare inputs. + let blake_inputs = (0..1 << (log_size - LOG_N_LANES)) + .map(|i| { + let v = [u32x16::from_array(array::from_fn(|j| (i + 2 * j) as u32)); 16]; + let m = [u32x16::from_array(array::from_fn(|j| (i + 2 * j + 1) as u32)); 16]; + BlakeInput { v, m } + }) + .collect_vec(); + + // Setup protocol. + let channel = &mut MC::C::default(); + let commitment_scheme = &mut CommitmentSchemeProver::new(config, &twiddles); + + let span = span!(Level::INFO, "Trace").entered(); + + // Scheduler. + let (scheduler_trace, scheduler_lookup_data, round_inputs) = + air_scheduler::gen_trace(log_size, &blake_inputs); + + // Rounds. + let mut xor_accums = XorAccums::default(); + let mut rest = &round_inputs[..]; + // Split round inputs to components, according to [ROUND_LOG_SPLIT]. + let (round_traces, round_lookup_datas): (Vec<_>, Vec<_>) = + multiunzip(ROUND_LOG_SPLIT.map(|l| { + let (cur_inputs, r) = rest.split_at(1 << (log_size - LOG_N_LANES + l)); + rest = r; + air_round::generate_trace(log_size + l, cur_inputs, &mut xor_accums) + })); + + // Xor tables. + let (xor_trace12, xor_lookup_data12) = air_xor_table::generate_trace(xor_accums.xor12); + let (xor_trace9, xor_lookup_data9) = air_xor_table::generate_trace(xor_accums.xor9); + let (xor_trace8, xor_lookup_data8) = air_xor_table::generate_trace(xor_accums.xor8); + let (xor_trace7, xor_lookup_data7) = air_xor_table::generate_trace(xor_accums.xor7); + let (xor_trace4, xor_lookup_data4) = air_xor_table::generate_trace(xor_accums.xor4); + + // Claim. + let claim = BlakeClaim { log_size }; + claim.mix_into(channel); + + // Trace commitment. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(chain![ + scheduler_trace, + round_traces.into_iter().flatten(), + xor_trace12, + xor_trace9, + xor_trace8, + xor_trace7, + xor_trace4, + ]); + tree_builder.commit(channel); + span.exit(); + + // Draw lookup element. + let all_elements = AllElements::draw(channel); + + // Interaction trace. + let span = span!(Level::INFO, "Interaction").entered(); + let mut lookup_input_layers = Vec::new(); + let mut mle_eval_at_point_collection = MleCollection::default(); + + lookup_input_layers.extend(scheduler::generate_lookup_instances( + log_size, + scheduler_lookup_data, + &all_elements.round_elements, + &all_elements.blake_elements, + &mut mle_eval_at_point_collection, + )); + + ROUND_LOG_SPLIT + .iter() + .zip(round_lookup_datas) + .for_each(|(l, lookup_data)| { + lookup_input_layers.extend(round::generate_lookup_instances( + log_size + l, + lookup_data, + &all_elements.xor_elements, + &all_elements.round_elements, + &mut mle_eval_at_point_collection, + )); + }); + + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data12, + &all_elements.xor_elements.xor12, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data9, + &all_elements.xor_elements.xor9, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data8, + &all_elements.xor_elements.xor8, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data7, + &all_elements.xor_elements.xor7, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data4, + &all_elements.xor_elements.xor4, + &mut mle_eval_at_point_collection, + )); + + let (_gkr_proof, gkr_artifact) = prove_batch(channel, lookup_input_layers); + let mle_acc_coeff = channel.draw_felt(); + let _mles = mle_eval_at_point_collection.random_linear_combine_by_n_variables(mle_acc_coeff); + let _claims = accumulate_claims_by_n_variables(mle_acc_coeff, &gkr_artifact); + span.exit(); + + #[cfg(test)] + for mle in &_mles { + let n_variables = mle.n_variables(); + let claim = _claims[&n_variables]; + let eval_point = gkr_artifact.ood_point(n_variables); + assert_eq!(mle.eval_at_point(eval_point), claim); + } + + todo!() +} + +fn accumulate_claims_by_n_variables( + random_coeff: SecureField, + GkrArtifact { + claims_to_verify_by_instance, + n_variables_by_instance, + .. + }: &GkrArtifact, +) -> BTreeMap { + let mut acc_by_n_variables = BTreeMap::new(); + + for (n_variables, claims) in zip(n_variables_by_instance, claims_to_verify_by_instance) { + let acc = acc_by_n_variables + .entry(n_variables) + .or_insert_with(|| PointEvaluationAccumulator::new(random_coeff)); + claims.iter().for_each(|claim| acc.accumulate(*claim)); + } + + acc_by_n_variables + .into_iter() + .map(|(&n_variables, acc)| (n_variables, acc.finalize())) + .collect() +} + +#[cfg(test)] +mod tests { + use std::env; + + use crate::core::pcs::PcsConfig; + use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; + use crate::examples::blake_gkr::air::prove_blake; + + // Note: this test is slow. Only run in release. + #[cfg_attr(not(feature = "slow-tests"), ignore)] + #[test_log::test] + fn test_simd_blake_gkr_prove() { + // Get from environment variable: + let log_n_instances = env::var("LOG_N_INSTANCES") + .unwrap_or_else(|_| "6".to_string()) + .parse::() + .unwrap(); + let config = PcsConfig::default(); + + // Prove. + let _proof = prove_blake::(log_n_instances, config); + + // Verify. + // verify_blake::(proof, config).unwrap(); + } +} diff --git a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs b/crates/prover/src/examples/blake_gkr/gkr_lookups/accumulation.rs similarity index 98% rename from crates/prover/src/examples/xor/gkr_lookups/accumulation.rs rename to crates/prover/src/examples/blake_gkr/gkr_lookups/accumulation.rs index a63b62503..565dd07ae 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs +++ b/crates/prover/src/examples/blake_gkr/gkr_lookups/accumulation.rs @@ -146,7 +146,7 @@ mod tests { use crate::core::fields::qm31::SecureField; use crate::core::fields::Field; use crate::core::lookups::mle::{Mle, MleOps}; - use crate::examples::xor::gkr_lookups::accumulation::MleCollection; + use crate::examples::blake_gkr::gkr_lookups::accumulation::MleCollection; #[test] fn random_linear_combine_by_n_variables() { diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/blake_gkr/gkr_lookups/mle_eval.rs similarity index 99% rename from crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs rename to crates/prover/src/examples/blake_gkr/gkr_lookups/mle_eval.rs index 7f3805787..6b5cc1312 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/blake_gkr/gkr_lookups/mle_eval.rs @@ -1,7 +1,4 @@ //! Multilinear extension (MLE) eval at point constraints. -// TODO(andrew): Remove in downstream PR. -#![allow(dead_code)] - use std::iter::zip; use itertools::{chain, zip_eq, Itertools}; @@ -733,8 +730,8 @@ mod tests { use crate::core::prover::{prove, verify, VerificationError}; use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; - use crate::examples::xor::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR; - use crate::examples::xor::gkr_lookups::mle_eval::eval_step_selector_with_offset; + use crate::examples::blake_gkr::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR; + use crate::examples::blake_gkr::gkr_lookups::mle_eval::eval_step_selector_with_offset; #[test] fn mle_eval_prover_component() -> Result<(), VerificationError> { @@ -1112,7 +1109,7 @@ mod tests { use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation}; use crate::core::poly::BitReversedOrder; use crate::core::ColumnVec; - use crate::examples::xor::gkr_lookups::mle_eval::MleCoeffColumnOracle; + use crate::examples::blake_gkr::gkr_lookups::mle_eval::MleCoeffColumnOracle; pub type MleCoeffColumnComponent = FrameworkComponent; diff --git a/crates/prover/src/examples/xor/gkr_lookups/mod.rs b/crates/prover/src/examples/blake_gkr/gkr_lookups/mod.rs similarity index 100% rename from crates/prover/src/examples/xor/gkr_lookups/mod.rs rename to crates/prover/src/examples/blake_gkr/gkr_lookups/mod.rs diff --git a/crates/prover/src/examples/blake_gkr/mod.rs b/crates/prover/src/examples/blake_gkr/mod.rs new file mode 100644 index 000000000..a1f767c6b --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/mod.rs @@ -0,0 +1,5 @@ +pub mod air; +pub mod gkr_lookups; +pub mod round; +pub mod scheduler; +pub mod xor_table; diff --git a/crates/prover/src/examples/blake_gkr/round.rs b/crates/prover/src/examples/blake_gkr/round.rs new file mode 100644 index 000000000..179f284dd --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/round.rs @@ -0,0 +1,49 @@ +use tracing::{span, Level}; + +use super::gkr_lookups::accumulation::MleCollection; +use crate::core::backend::simd::column::SecureColumn; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::gkr_prover::Layer; +use crate::core::lookups::mle::Mle; +use crate::examples::blake::round::{BlakeRoundLookupData, RoundElements}; +use crate::examples::blake::BlakeXorElements; + +pub fn generate_lookup_instances( + log_size: u32, + lookup_data: BlakeRoundLookupData, + xor_lookup_elements: &BlakeXorElements, + round_lookup_elements: &RoundElements, + collection_for_univariate_iop: &mut MleCollection, +) -> Vec> { + let _span = span!(Level::INFO, "Generate round interaction trace").entered(); + let size = 1 << log_size; + let mut round_lookup_layers = Vec::new(); + + for (w, l) in &lookup_data.xor_lookups { + let lookup_elements = xor_lookup_elements.get(*w); + let mut denominators = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let denom = lookup_elements.combine(&l.each_ref().map(|l| l.data[vec_row])); + denominators.data[vec_row] = denom; + } + collection_for_univariate_iop.push(denominators.clone()); + round_lookup_layers.push(Layer::LogUpSingles { denominators }); + } + + // Blake round lookup. + let mut round_denominators = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let denom = round_lookup_elements + .combine(&lookup_data.round_lookup.each_ref().map(|l| l.data[vec_row])); + round_denominators.data[vec_row] = denom; + } + collection_for_univariate_iop.push(round_denominators.clone()); + round_lookup_layers.push(Layer::LogUpSingles { + denominators: round_denominators, + }); + + round_lookup_layers +} diff --git a/crates/prover/src/examples/blake_gkr/scheduler.rs b/crates/prover/src/examples/blake_gkr/scheduler.rs new file mode 100644 index 000000000..4dc2610d9 --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/scheduler.rs @@ -0,0 +1,56 @@ +use tracing::{span, Level}; + +use super::gkr_lookups::accumulation::MleCollection; +use crate::core::backend::simd::column::{BaseColumn, SecureColumn}; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::gkr_prover::Layer; +use crate::core::lookups::mle::Mle; +use crate::examples::blake::round::RoundElements; +use crate::examples::blake::scheduler::{BlakeElements, BlakeSchedulerLookupData}; + +pub fn generate_lookup_instances( + log_size: u32, + lookup_data: BlakeSchedulerLookupData, + round_lookup_elements: &RoundElements, + blake_lookup_elements: &BlakeElements, + collection_for_univariate_iop: &mut MleCollection, +) -> Vec> { + let _span = span!(Level::INFO, "Generate scheduler interaction trace").entered(); + let size = 1 << log_size; + let mut round_lookup_layers = Vec::new(); + + for l0 in &lookup_data.round_lookups { + let mut denominators = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let denom = round_lookup_elements.combine(&l0.each_ref().map(|l| l.data[vec_row])); + denominators.data[vec_row] = denom; + } + collection_for_univariate_iop.push(denominators.clone()); + round_lookup_layers.push(Layer::LogUpSingles { denominators }) + } + + // Blake hash lookup. + let blake_numers = Mle::::new(BaseColumn::zeros(size)); + let mut blake_denoms = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let blake_denom: PackedSecureField = blake_lookup_elements.combine( + &lookup_data + .blake_lookups + .each_ref() + .map(|l| l.data[vec_row]), + ); + blake_denoms.data[vec_row] = blake_denom; + } + collection_for_univariate_iop.push(blake_denoms.clone()); + round_lookup_layers.push(Layer::LogUpMultiplicities { + numerators: blake_numers, + denominators: blake_denoms, + }); + + round_lookup_layers +} diff --git a/crates/prover/src/examples/blake_gkr/xor_table.rs b/crates/prover/src/examples/blake_gkr/xor_table.rs new file mode 100644 index 000000000..611f701f5 --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/xor_table.rs @@ -0,0 +1,65 @@ +use std::array; +use std::simd::u32x16; + +use tracing::{span, Level}; + +use super::gkr_lookups::accumulation::MleCollection; +use crate::core::backend::simd::column::SecureColumn; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::gkr_prover::Layer; +use crate::core::lookups::mle::Mle; +use crate::examples::blake::xor_table::{column_bits, limb_bits, XorElements, XorTableLookupData}; + +pub fn generate_lookup_instances( + lookup_data: XorTableLookupData, + lookup_elements: &XorElements, + collection_for_univariate_iop: &mut MleCollection, +) -> Vec> { + let _span = span!(Level::INFO, "Xor interaction trace").entered(); + let limb_bits = limb_bits::(); + let col_bits = column_bits::(); + let col_size = 1 << col_bits; + let offsets_vec = u32x16::from_array(array::from_fn(|i| i as u32)); + let mut xor_lookup_layers = Vec::new(); + + // There are 2^(2*EXPAND_BITS) columns, for each combination of ah, bh. + for (i, mults) in lookup_data.xor_accum.mults.iter().enumerate() { + let numerators = Mle::::new(mults.clone()); + let mut denominators = Mle::::new(SecureColumn::zeros(col_size)); + + // Extract ah, bh from column index. + let ah = i as u32 >> EXPAND_BITS; + let bh = i as u32 & ((1 << EXPAND_BITS) - 1); + + // Each column has 2^(2*LIMB_BITS) rows, packed in N_LANES. + #[allow(clippy::needless_range_loop)] + for vec_row in 0..1 << (col_bits - LOG_N_LANES) { + // vec_row is LIMB_BITS of al and LIMB_BITS - LOG_N_LANES of bl. + // Extract al, blh from vec_row. + let al = vec_row >> (limb_bits - LOG_N_LANES); + let blh = vec_row & ((1 << (limb_bits - LOG_N_LANES)) - 1); + + // Construct the 3 vectors a, b, c. + let a = u32x16::splat((ah << limb_bits) | al); + // bll is just the consecutive numbers 0..N_LANES-1. + let b = u32x16::splat((bh << limb_bits) | (blh << LOG_N_LANES)) | offsets_vec; + let c = a ^ b; + + let denom = lookup_elements + .combine(&[a, b, c].map(|x| unsafe { PackedBaseField::from_simd_unchecked(x) })); + denominators.data[vec_row as usize] = denom; + } + + collection_for_univariate_iop.push(numerators.clone()); + xor_lookup_layers.push(Layer::LogUpMultiplicities { + numerators, + denominators, + }); + } + + xor_lookup_layers +} diff --git a/crates/prover/src/examples/mod.rs b/crates/prover/src/examples/mod.rs index c5e3a4eda..d4beaabe2 100644 --- a/crates/prover/src/examples/mod.rs +++ b/crates/prover/src/examples/mod.rs @@ -1,4 +1,4 @@ pub mod blake; +pub mod blake_gkr; pub mod plonk; pub mod poseidon; -pub mod xor; diff --git a/crates/prover/src/examples/xor/mod.rs b/crates/prover/src/examples/xor/mod.rs deleted file mode 100644 index 34e702a9b..000000000 --- a/crates/prover/src/examples/xor/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod gkr_lookups;