Skip to content

Commit

Permalink
Make CommitmentSchemeProver::prove_values take ownership
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Sep 25, 2024
1 parent a46c994 commit 64b9479
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 80 deletions.
2 changes: 1 addition & 1 deletion crates/prover/src/core/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,7 @@ impl<B: FriOps + MerkleOps<H>, H: MerkleHasher> FriLayerProver<B, H> {
let commitment = self.merkle_tree.root();
// TODO(andrew): Use _evals.
let (_evals, decommitment) = self.merkle_tree.decommit(
[(self.evaluation.len().ilog2(), decommit_positions)]
&[(self.evaluation.len().ilog2(), decommit_positions)]
.into_iter()
.collect(),
self.evaluation.values.columns.iter().collect_vec(),
Expand Down
8 changes: 5 additions & 3 deletions crates/prover/src/core/pcs/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
}

pub fn prove_values(
&self,
self,
sampled_points: TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>>,
channel: &mut MC::C,
) -> CommitmentSchemeProof<MC::H> {
Expand Down Expand Up @@ -134,13 +134,14 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
.iter()
.map(|(&log_size, domain)| (log_size, domain.flatten()))
.collect();
tree.decommit(queries)
tree.decommit(&queries)
});

let queried_values = decommitment_results.as_ref().map(|(v, _)| v.clone());
let decommitments = decommitment_results.map(|(_, d)| d);

CommitmentSchemeProof {
commitments: self.roots(),
sampled_values,
decommitments,
queried_values,
Expand All @@ -152,6 +153,7 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,

#[derive(Debug, Serialize, Deserialize)]
pub struct CommitmentSchemeProof<H: MerkleHasher> {
pub commitments: TreeVec<H::Hash>,
pub sampled_values: TreeVec<ColumnVec<Vec<SecureField>>>,
pub decommitments: TreeVec<MerkleDecommitment<H>>,
pub queried_values: TreeVec<ColumnVec<Vec<BaseField>>>,
Expand Down Expand Up @@ -243,7 +245,7 @@ impl<B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentTreeProver<B, MC> {
/// positions on each column of that size.
fn decommit(
&self,
queries: BTreeMap<u32, Vec<usize>>,
queries: &BTreeMap<u32, Vec<usize>>,
) -> (ColumnVec<Vec<BaseField>>, MerkleDecommitment<MC::H>) {
let eval_vec = self
.evaluations
Expand Down
103 changes: 51 additions & 52 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::ops::Deref;
use std::{array, mem};

use serde::{Deserialize, Serialize};
Expand All @@ -9,7 +10,7 @@ use super::backend::BackendForChannel;
use super::channel::MerkleChannel;
use super::fields::secure_column::SECURE_EXTENSION_DEGREE;
use super::fri::FriVerificationError;
use super::pcs::{CommitmentSchemeProof, TreeVec};
use super::pcs::CommitmentSchemeProof;
use super::vcs::ops::MerkleHasher;
use crate::core::channel::Channel;
use crate::core::circle::CirclePoint;
Expand All @@ -21,17 +22,11 @@ use crate::core::vcs::hash::Hash;
use crate::core::vcs::prover::MerkleDecommitment;
use crate::core::vcs::verifier::MerkleVerificationError;

#[derive(Debug, Serialize, Deserialize)]
pub struct StarkProof<H: MerkleHasher> {
pub commitments: TreeVec<H::Hash>,
pub commitment_scheme_proof: CommitmentSchemeProof<H>,
}

#[instrument(skip_all)]
pub fn prove<B: BackendForChannel<MC>, MC: MerkleChannel>(
components: &[&dyn ComponentProver<B>],
channel: &mut MC::C,
commitment_scheme: &mut CommitmentSchemeProver<'_, B, MC>,
mut commitment_scheme: CommitmentSchemeProver<'_, B, MC>,
) -> Result<StarkProof<MC::H>, ProvingError> {
let component_provers = ComponentProvers(components.to_vec());
let trace = commitment_scheme.trace();
Expand Down Expand Up @@ -59,25 +54,19 @@ pub fn prove<B: BackendForChannel<MC>, MC: MerkleChannel>(

// Prove the trace and composition OODS values, and retrieve them.
let commitment_scheme_proof = commitment_scheme.prove_values(sample_points, channel);

let sampled_oods_values = &commitment_scheme_proof.sampled_values;
let composition_oods_eval = extract_composition_eval(sampled_oods_values).unwrap();
let proof = StarkProof(commitment_scheme_proof);
info!(proof_size_estimate = proof.size_estimate());

// Evaluate composition polynomial at OODS point and check that it matches the trace OODS
// values. This is a sanity check.
if composition_oods_eval
if proof.extract_composition_oods_eval().unwrap()
!= component_provers
.components()
.eval_composition_polynomial_at_point(oods_point, sampled_oods_values, random_coeff)
.eval_composition_polynomial_at_point(oods_point, &proof.sampled_values, random_coeff)
{
return Err(ProvingError::ConstraintsNotSatisfied);
}

let proof = StarkProof {
commitments: commitment_scheme.roots(),
commitment_scheme_proof,
};
info!(proof_size_estimate = proof.size_estimate());
Ok(proof)
}

Expand Down Expand Up @@ -105,42 +94,21 @@ pub fn verify<MC: MerkleChannel>(
// Add the composition polynomial mask points.
sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]);

let sampled_oods_values = &proof.commitment_scheme_proof.sampled_values;
let composition_oods_eval = extract_composition_eval(sampled_oods_values).map_err(|_| {
let composition_oods_eval = proof.extract_composition_oods_eval().map_err(|_| {
VerificationError::InvalidStructure("Unexpected sampled_values structure".to_string())
})?;

if composition_oods_eval
!= components.eval_composition_polynomial_at_point(
oods_point,
sampled_oods_values,
&proof.sampled_values,
random_coeff,
)
{
return Err(VerificationError::OodsNotMatching);
}

commitment_scheme.verify_values(sample_points, proof.commitment_scheme_proof, channel)
}

/// Extracts the composition trace evaluation from the mask.
fn extract_composition_eval(
mask: &TreeVec<Vec<Vec<SecureField>>>,
) -> Result<SecureField, InvalidOodsSampleStructure> {
let mut composition_cols = mask.last().into_iter().flatten();

let coordinate_evals = array::try_from_fn(|_| {
let col = &**composition_cols.next().ok_or(InvalidOodsSampleStructure)?;
let [eval] = col.try_into().map_err(|_| InvalidOodsSampleStructure)?;
Ok(eval)
})?;

// Too many columns.
if composition_cols.next().is_some() {
return Err(InvalidOodsSampleStructure);
}

Ok(SecureField::from_partial_evals(coordinate_evals))
commitment_scheme.verify_values(sample_points, proof.0, channel)
}

/// Error when the sampled values have an invalid structure.
Expand Down Expand Up @@ -172,20 +140,44 @@ pub enum VerificationError {
ProofOfWork,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct StarkProof<H: MerkleHasher>(pub CommitmentSchemeProof<H>);

impl<H: MerkleHasher> StarkProof<H> {
/// Extracts the composition trace Out-Of-Domain-Sample evaluation from the mask.
fn extract_composition_oods_eval(&self) -> Result<SecureField, InvalidOodsSampleStructure> {
// TODO(andrew): `[.., composition_mask, _quotients_mask]` when add quotients commitment.
let [.., composition_mask] = &**self.sampled_values else {
return Err(InvalidOodsSampleStructure);
};

let mut composition_cols = composition_mask.iter();

let coordinate_evals = array::try_from_fn(|_| {
let col = &**composition_cols.next().ok_or(InvalidOodsSampleStructure)?;
let [eval] = col.try_into().map_err(|_| InvalidOodsSampleStructure)?;
Ok(eval)
})?;

// Too many columns.
if composition_cols.next().is_some() {
return Err(InvalidOodsSampleStructure);
}

Ok(SecureField::from_partial_evals(coordinate_evals))
}

/// Returns the estimate size (in bytes) of the proof.
pub fn size_estimate(&self) -> usize {
SizeEstimate::size_estimate(self)
}

/// Returns size estimates (in bytes) for different parts of the proof.
pub fn size_breakdown_estimate(&self) -> StarkProofSizeBreakdown {
let Self {
commitments,
commitment_scheme_proof,
} = self;
let Self(commitment_scheme_proof) = self;

let CommitmentSchemeProof {
commitments,
sampled_values,
decommitments,
queried_values,
Expand Down Expand Up @@ -221,6 +213,14 @@ impl<H: MerkleHasher> StarkProof<H> {
}
}

impl<H: MerkleHasher> Deref for StarkProof<H> {
type Target = CommitmentSchemeProof<H>;

fn deref(&self) -> &CommitmentSchemeProof<H> {
&self.0
}
}

/// Size estimate (in bytes) for different parts of the proof.
pub struct StarkProofSizeBreakdown {
pub oods_samples: usize,
Expand Down Expand Up @@ -298,13 +298,15 @@ impl<H: MerkleHasher> SizeEstimate for FriProof<H> {
impl<H: MerkleHasher> SizeEstimate for CommitmentSchemeProof<H> {
fn size_estimate(&self) -> usize {
let Self {
commitments,
sampled_values,
decommitments,
queried_values,
proof_of_work,
fri_proof,
} = self;
sampled_values.size_estimate()
commitments.size_estimate()
+ sampled_values.size_estimate()
+ decommitments.size_estimate()
+ queried_values.size_estimate()
+ mem::size_of_val(proof_of_work)
Expand All @@ -314,11 +316,8 @@ impl<H: MerkleHasher> SizeEstimate for CommitmentSchemeProof<H> {

impl<H: MerkleHasher> SizeEstimate for StarkProof<H> {
fn size_estimate(&self) -> usize {
let Self {
commitments,
commitment_scheme_proof,
} = self;
commitments.size_estimate() + commitment_scheme_proof.size_estimate()
let Self(commitment_scheme_proof) = self;
commitment_scheme_proof.size_estimate()
}
}

Expand Down
11 changes: 1 addition & 10 deletions crates/prover/src/core/vcs/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,9 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
/// * A `MerkleDecommitment` containing the hash and column witnesses.
pub fn decommit(
&self,
queries_per_log_size: BTreeMap<u32, Vec<usize>>,
queries_per_log_size: &BTreeMap<u32, Vec<usize>>,
columns: Vec<&Col<B, BaseField>>,
) -> (ColumnVec<Vec<BaseField>>, MerkleDecommitment<H>) {
// Check that queries are sorted and deduped.
// TODO(andrew): Consider using a Queries struct to prevent this.
for queries in queries_per_log_size.values() {
assert!(
queries.windows(2).all(|w| w[0] < w[1]),
"Queries are not sorted."
);
}

// Prepare output buffers.
let mut queried_values_by_layer = vec![];
let mut decommitment = MerkleDecommitment::empty();
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/vcs/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ where
queries.insert(log_size, layer_queries);
}

let (values, decommitment) = merkle.decommit(queries.clone(), cols.iter().collect_vec());
let (values, decommitment) = merkle.decommit(&queries, cols.iter().collect_vec());

let verifier = MerkleVerifier {
root: merkle.root(),
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/blake/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ where

// Setup protocol.
let channel = &mut MC::C::default();
let commitment_scheme = &mut CommitmentSchemeProver::new(config, &twiddles);
let mut commitment_scheme = CommitmentSchemeProver::new(config, &twiddles);

let span = span!(Level::INFO, "Trace").entered();

Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/examples/plonk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ pub fn prove_fibonacci_plonk(

// Setup protocol.
let channel = &mut Blake2sChannel::default();
let commitment_scheme =
&mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);
let mut commitment_scheme =
CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);

// Trace.
let span = span!(Level::INFO, "Trace").entered();
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,8 @@ pub fn prove_poseidon(

// Setup protocol.
let channel = &mut Blake2sChannel::default();
let commitment_scheme =
&mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);
let mut commitment_scheme =
CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);

// Trace.
let span = span!(Level::INFO, "Trace").entered();
Expand Down
12 changes: 4 additions & 8 deletions crates/prover/src/examples/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,8 @@ mod tests {

// Setup protocol.
let prover_channel = &mut Blake2sChannel::default();
let commitment_scheme =
&mut CommitmentSchemeProver::<SimdBackend, Blake2sMerkleChannel>::new(
config, &twiddles,
);
let mut commitment_scheme =
CommitmentSchemeProver::<SimdBackend, Blake2sMerkleChannel>::new(config, &twiddles);

// Trace.
let trace = generate_test_trace(log_n_instances);
Expand Down Expand Up @@ -232,10 +230,8 @@ mod tests {

// Setup protocol.
let prover_channel = &mut Poseidon252Channel::default();
let commitment_scheme =
&mut CommitmentSchemeProver::<SimdBackend, Poseidon252MerkleChannel>::new(
config, &twiddles,
);
let mut commitment_scheme =
CommitmentSchemeProver::<SimdBackend, Poseidon252MerkleChannel>::new(config, &twiddles);

// Trace.
let trace = generate_test_trace(LOG_N_INSTANCES);
Expand Down

0 comments on commit 64b9479

Please sign in to comment.