Skip to content

Commit

Permalink
Fix poseidon sponge bug (#148)
Browse files Browse the repository at this point in the history
* demo bug

* fix collusion bug

cross test with independent implementation
  • Loading branch information
kilic authored Oct 11, 2024
1 parent 6b19555 commit 8fbf5ef
Show file tree
Hide file tree
Showing 3 changed files with 242 additions and 7 deletions.
7 changes: 4 additions & 3 deletions crypto-primitives/src/sponge/poseidon/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,13 @@ impl<F: PrimeField> PoseidonSpongeVar<F> {
..(self.parameters.capacity + num_elements_squeezed + rate_start_index)],
);

// Repeat with updated output slices and rate start index
remaining_output = &mut remaining_output[num_elements_squeezed..];

// Unless we are done with squeezing in this call, permute.
if remaining_output.len() != self.parameters.rate {
if !remaining_output.is_empty() {
self.permute()?;
}
// Repeat with updated output slices and rate start index
remaining_output = &mut remaining_output[num_elements_squeezed..];
rate_start_index = 0;
}
}
Expand Down
7 changes: 4 additions & 3 deletions crypto-primitives/src/sponge/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,13 @@ impl<F: PrimeField> PoseidonSponge<F> {
..(self.parameters.capacity + num_elements_squeezed + rate_start_index)],
);

// Repeat with updated output slices
output_remaining = &mut output_remaining[num_elements_squeezed..];
// Unless we are done with squeezing in this call, permute.
if output_remaining.len() != self.parameters.rate {
if !output_remaining.is_empty() {
self.permute();
}
// Repeat with updated output slices
output_remaining = &mut output_remaining[num_elements_squeezed..];

rate_start_index = 0;
}
}
Expand Down
235 changes: 234 additions & 1 deletion crypto-primitives/src/sponge/poseidon/tests.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,243 @@
use crate::sponge::poseidon::{PoseidonConfig, PoseidonSponge};
use crate::sponge::poseidon::{PoseidonConfig, PoseidonDefaultConfigField, PoseidonSponge};
use crate::sponge::test::Fr;
use crate::sponge::{Absorb, AbsorbWithLength, CryptographicSponge, FieldBasedCryptographicSponge};
use crate::{absorb, collect_sponge_bytes, collect_sponge_field_elements};
use ark_ff::{One, PrimeField, UniformRand};
use ark_std::test_rng;

#[test]
// Remove once this PR matures
fn demo_bug() {
let sponge_params = Fr::get_default_poseidon_parameters(2, false).unwrap();

let rng = &mut test_rng();
let input = (0..3).map(|_| Fr::rand(rng)).collect::<Vec<_>>();

// works good
let e0 = {
let mut sponge = PoseidonSponge::<Fr>::new(&sponge_params);
sponge.absorb(&input);
sponge.squeeze_native_field_elements(3)
};

// works good
let e1 = {
let mut sponge = PoseidonSponge::<Fr>::new(&sponge_params);
sponge.absorb(&input);
let e0 = sponge.squeeze_native_field_elements(1);
let e1 = sponge.squeeze_native_field_elements(1);
let e2 = sponge.squeeze_native_field_elements(1);
e0.iter()
.chain(e1.iter())
.chain(e2.iter())
.cloned()
.collect::<Vec<_>>()
};

// also works good
let e2 = {
let mut sponge = PoseidonSponge::<Fr>::new(&sponge_params);
sponge.absorb(&input);

let e0 = sponge.squeeze_native_field_elements(2);
let e1 = sponge.squeeze_native_field_elements(1);
e0.iter().chain(e1.iter()).cloned().collect::<Vec<_>>()
};

// skips a permutation if sponge
// * in squeezing mode
// * number of elements are equal to rate
let e3 = {
let mut sponge = PoseidonSponge::<Fr>::new(&sponge_params);
sponge.absorb(&input);
let e0 = sponge.squeeze_native_field_elements(1);
let e1 = sponge.squeeze_native_field_elements(2);
e0.iter().chain(e1.iter()).cloned().collect::<Vec<_>>()
};

assert_eq!(e0, e1);
assert_eq!(e0, e2);
assert_eq!(e0, e3); // this will fail
}

// Remove once this PR matures
fn run_cross_test<F: PrimeField + Absorb>(cfg: &PoseidonConfig<F>) {
#[derive(Debug, PartialEq, Eq)]
enum SpongeMode {
Absorbing,
Squeezing,
}

#[derive(Clone, Debug)]
struct Reference<F: PrimeField> {
cfg: PoseidonConfig<F>,
state: Vec<F>,
absorbing: Vec<F>,
squeeze_count: Option<usize>,
}

// workaround to permute a state
fn permute<F: PrimeField>(cfg: &PoseidonConfig<F>, state: &mut [F]) {
let mut sponge = PoseidonSponge::new(&cfg);
sponge.state.copy_from_slice(state);
sponge.permute();
state.copy_from_slice(&sponge.state)
}

impl<F: PrimeField> Reference<F> {
fn new(cfg: &PoseidonConfig<F>) -> Self {
let t = cfg.rate + cfg.capacity;
let state = vec![F::zero(); t];
Self {
cfg: cfg.clone(),
state,
absorbing: Vec::new(),
squeeze_count: None,
}
}

fn mode(&self) -> SpongeMode {
match self.squeeze_count {
Some(_) => {
assert!(self.absorbing.is_empty());
SpongeMode::Squeezing
}
None => SpongeMode::Absorbing,
}
}

fn absorb(&mut self, input: &[F]) {
if !input.is_empty() {
match self.mode() {
SpongeMode::Absorbing => self.absorbing.extend_from_slice(input),
SpongeMode::Squeezing => {
// Wash the state as mode changes
// This is not appied in SAFE sponge
permute(&self.cfg, &mut self.state);
// Append inputs to the absorbing line
self.absorbing.extend_from_slice(input);
// Change mode to absorbing
self.squeeze_count = None;
}
}
}
}

fn _absorb(&mut self) {
let rate = self.cfg.rate;
self.absorbing.chunks(rate).for_each(|chunk| {
self.state
.iter_mut()
.skip(self.cfg.capacity)
.zip(chunk.iter())
.for_each(|(s, c)| *s += *c);
permute(&self.cfg, &mut self.state);
});

// This case can only happen in the begining when the absorbing line is empty
// and user wants to squeeze elements. Notice that after moving to squueze mode
// if user calls absorb again with empty input it will be ignored
self.absorbing
.is_empty()
.then(|| permute(&self.cfg, &mut self.state));

// flush the absorbing line
self.absorbing.clear();

// Change to the squeezing mode
assert_eq!(self.mode(), SpongeMode::Absorbing);
self.squeeze_count = Some(0);
}

pub fn squeeze(&mut self, n: usize) -> Vec<F> {
match self.mode() {
SpongeMode::Absorbing => self._absorb(),
SpongeMode::Squeezing => {
assert!(self.absorbing.is_empty());
assert!(self.squeeze_count.is_some());

// ???
// **This seems nonsense to me**
// If,
// * number of squeeze is zero AND
// * in squeezing mode AND
// * output index is is at `rate`
// it applies a useless permutation.
// This is also not appied in SAFE sponge

if n == 0 {
let squeeze_count = self.squeeze_count.unwrap();
let out_index = self.squeeze_count.unwrap() % self.cfg.rate;
(out_index == 0 && squeeze_count != 0).then(|| {
permute(&self.cfg, &mut self.state);
self.squeeze_count = Some(0);
});
}
}
}

let rate = self.cfg.rate;
let mut output = Vec::new();
for _ in 0..n {
let squeeze_count = self.squeeze_count.unwrap();
let out_index = squeeze_count % rate;

// proceed with a permutation if
// * the rate is full
// * and it is not the first output
(out_index == 0 && squeeze_count != 0).then(|| permute(&self.cfg, &mut self.state));

// skip the capacity elements
let out_index = out_index + self.cfg.capacity;
output.push(self.state[out_index]);
self.squeeze_count.as_mut().map(|c| *c += 1);
}

output
}
}

let mut sponge = PoseidonSponge::new(cfg);
let mut sponge_ref = Reference::new(cfg);
let mut rng = test_rng();

for _ in 0..1000 {
let test = (0..100)
.map(|_| {
use crate::ark_std::rand::Rng;
let do_absorb = rng.gen_bool(0.5);
let do_squeeze = rng.gen_bool(0.5);

(
(do_absorb, rng.gen_range(0..=cfg.rate * 2 + 1)),
(do_squeeze, rng.gen_range(0..=cfg.rate * 2 + 1)),
)
})
.collect::<Vec<_>>();

// fuzz fuzz
for (_i, ((do_absorb, n_absorb), (do_squeeze, n_squeeze))) in test.into_iter().enumerate() {
do_absorb.then(|| {
let inputs = (0..n_absorb).map(|_| F::rand(&mut rng)).collect::<Vec<_>>();
sponge_ref.absorb(&inputs);
sponge.absorb(&inputs);
});
do_squeeze.then(|| {
let out0 = sponge_ref.squeeze(n_squeeze);
let out1 = sponge.squeeze_field_elements(n_squeeze);
assert_eq!(out0, out1);
});
}
}
}

#[test]
// Remove once this PR matures
fn test_cross() {
let cfg = Fr::get_default_poseidon_parameters(2, false).unwrap();
run_cross_test::<Fr>(&cfg);
}

fn assert_different_encodings<F: PrimeField, A: Absorb>(a: &A, b: &A) {
let bytes1 = a.to_sponge_bytes_as_vec();
let bytes2 = b.to_sponge_bytes_as_vec();
Expand Down

0 comments on commit 8fbf5ef

Please sign in to comment.