From fab4ec2a343993a4aa3a78d758dab1d5f3813ee9 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Mon, 28 Oct 2024 09:41:33 +0800 Subject: [PATCH 1/9] Snapshot --- ceno_zkvm/src/chip_handler/general.rs | 6 +- ceno_zkvm/src/chip_handler/register.rs | 10 +-- ceno_zkvm/src/expression.rs | 63 ++++++++------ ceno_zkvm/src/gadgets/is_lt.rs | 20 ++--- ceno_zkvm/src/gadgets/is_zero.rs | 9 +- ceno_zkvm/src/instructions/riscv/b_insn.rs | 17 ++-- .../src/instructions/riscv/ecall/halt.rs | 6 +- .../src/instructions/riscv/ecall_insn.rs | 12 +-- ceno_zkvm/src/instructions/riscv/i_insn.rs | 6 +- ceno_zkvm/src/instructions/riscv/im_insn.rs | 6 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 38 ++++----- ceno_zkvm/src/instructions/riscv/j_insn.rs | 6 +- .../src/instructions/riscv/jump/auipc.rs | 8 +- ceno_zkvm/src/instructions/riscv/jump/jal.rs | 2 +- ceno_zkvm/src/instructions/riscv/jump/jalr.rs | 12 +-- .../src/instructions/riscv/memory/gadget.rs | 31 +++---- .../src/instructions/riscv/memory/store.rs | 4 +- ceno_zkvm/src/instructions/riscv/r_insn.rs | 8 +- ceno_zkvm/src/instructions/riscv/s_insn.rs | 6 +- ceno_zkvm/src/instructions/riscv/shift.rs | 6 +- ceno_zkvm/src/instructions/riscv/u_insn.rs | 4 +- ceno_zkvm/src/scheme/mock_prover.rs | 14 ++-- ceno_zkvm/src/scheme/tests.rs | 4 +- ceno_zkvm/src/scheme/utils.rs | 13 +-- ceno_zkvm/src/state.rs | 8 +- ceno_zkvm/src/tables/ops/ops_impl.rs | 2 +- ceno_zkvm/src/tables/program.rs | 9 +- ceno_zkvm/src/tables/ram/ram_impl.rs | 6 +- ceno_zkvm/src/tables/range/range_impl.rs | 2 +- ceno_zkvm/src/uint.rs | 31 +++---- ceno_zkvm/src/uint/arithmetic.rs | 84 +++++++++++-------- ceno_zkvm/src/uint/logic.rs | 7 +- ceno_zkvm/src/virtual_polys.rs | 7 +- 33 files changed, 258 insertions(+), 209 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 255ae20d3..1e068e1ce 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -370,12 +370,12 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.require_zero( || "is equal", - is_eq.expr().clone() * lhs.clone() - is_eq.expr() * rhs.clone(), + is_eq.expr_fnord().clone() * lhs.clone() - is_eq.expr_fnord() * rhs.clone(), )?; self.require_zero( || "is equal", - Expression::from(1) - is_eq.expr().clone() - diff_inverse.expr() * lhs - + diff_inverse.expr() * rhs, + Expression::from(1) - is_eq.expr_fnord().clone() - diff_inverse.expr_fnord() * lhs + + diff_inverse.expr_fnord() * rhs, )?; Ok((is_eq, diff_inverse)) diff --git a/ceno_zkvm/src/chip_handler/register.rs b/ceno_zkvm/src/chip_handler/register.rs index 254a8b6bc..ce01f7976 100644 --- a/ceno_zkvm/src/chip_handler/register.rs +++ b/ceno_zkvm/src/chip_handler/register.rs @@ -29,7 +29,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe vec![Expression::::Constant(E::BaseField::from( RAMType::Register as u64, ))], - vec![register_id.expr()], + vec![register_id.expr_fnord()], value.to_vec(), vec![prev_ts.clone()], ] @@ -41,7 +41,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe vec![Expression::::Constant(E::BaseField::from( RAMType::Register as u64, ))], - vec![register_id.expr()], + vec![register_id.expr_fnord()], value.to_vec(), vec![ts.clone()], ] @@ -74,7 +74,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe prev_values: RegisterExpr, value: RegisterExpr, ) -> Result<(Expression, AssertLTConfig), ZKVMError> { - assert!(register_id.expr().degree() <= 1); + assert!(register_id.expr_fnord().degree() <= 1); self.namespace(name_fn, |cb| { // READ (a, v, t) let read_record = cb.rlc_chip_record( @@ -82,7 +82,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe vec![Expression::::Constant(E::BaseField::from( RAMType::Register as u64, ))], - vec![register_id.expr()], + vec![register_id.expr_fnord()], prev_values.to_vec(), vec![prev_ts.clone()], ] @@ -94,7 +94,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe vec![Expression::::Constant(E::BaseField::from( RAMType::Register as u64, ))], - vec![register_id.expr()], + vec![register_id.expr_fnord()], value.to_vec(), vec![ts.clone()], ] diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 90f2bba97..ddca37206 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -623,7 +623,7 @@ impl WitIn { let name = name().into(); let wit = cb.create_witin(|| name.clone())?; if !debug { - cb.require_zero(|| name.clone(), wit.expr() - input)?; + cb.require_zero(|| name.clone(), wit.expr_fnord() - input)?; } Ok(wit) }, @@ -657,51 +657,65 @@ macro_rules! create_witin_from_expr { pub trait ToExpr { type Output; - fn expr(&self) -> Self::Output; + fn expr_fnord(&self) -> Self::Output; } impl ToExpr for WitIn { type Output = Expression; - fn expr(&self) -> Expression { + fn expr_fnord(&self) -> Expression { Expression::WitIn(self.id) } } impl ToExpr for &WitIn { type Output = Expression; - fn expr(&self) -> Expression { + fn expr_fnord(&self) -> Expression { Expression::WitIn(self.id) } } impl ToExpr for Fixed { type Output = Expression; - fn expr(&self) -> Expression { + fn expr_fnord(&self) -> Expression { Expression::Fixed(*self) } } impl ToExpr for &Fixed { type Output = Expression; - fn expr(&self) -> Expression { + fn expr_fnord(&self) -> Expression { Expression::Fixed(**self) } } impl ToExpr for Instance { type Output = Expression; - fn expr(&self) -> Expression { + fn expr_fnord(&self) -> Expression { Expression::Instance(*self) } } impl> ToExpr for F { type Output = Expression; - fn expr(&self) -> Expression { + fn expr_fnord(&self) -> Expression { Expression::Constant(*self) } } +macro_rules! impl_from_via_ToExpr { + ($($t:ty),*) => { + $( + impl From<$t> for Expression { + fn from(value: $t) -> Self { + value.expr_fnord() + } + } + )* + }; +} +impl_from_via_ToExpr!(WitIn, Fixed, Instance); +impl_from_via_ToExpr!(&WitIn, &Fixed, &Instance); + // Implement From trait for unsigned types of at most 64 bits macro_rules! impl_from_unsigned { ($($t:ty),*) => { @@ -894,8 +908,8 @@ mod tests { // scaledsum * challenge // 3 * x + 2 - let expr: Expression = - Into::>::into(3usize) * x.expr() + Into::>::into(2usize); + let expr: Expression = Into::>::into(3usize) * x.expr_fnord() + + Into::>::into(2usize); // c^3 + 1 let c = Expression::Challenge(0, 3, 1.into(), 1.into()); // res @@ -903,7 +917,7 @@ mod tests { assert_eq!( c * expr, Expression::ScaledSum( - Box::new(x.expr()), + Box::new(x.expr_fnord()), Box::new(Expression::Challenge(0, 3, 3.into(), 3.into())), Box::new(Expression::Challenge(0, 3, 2.into(), 2.into())) ) @@ -911,11 +925,11 @@ mod tests { // constant * witin // 3 * x - let expr: Expression = Into::>::into(3usize) * x.expr(); + let expr: Expression = Into::>::into(3usize) * x.expr_fnord(); assert_eq!( expr, Expression::ScaledSum( - Box::new(x.expr()), + Box::new(x.expr_fnord()), Box::new(Expression::Constant(3.into())), Box::new(Expression::Constant(0.into())) ) @@ -961,32 +975,33 @@ mod tests { let z = cb.create_witin(|| "z").unwrap(); // scaledsum * challenge // 3 * x + 2 - let expr: Expression = - Into::>::into(3usize) * x.expr() + Into::>::into(2usize); + let expr: Expression = Into::>::into(3usize) * x.expr_fnord() + + Into::>::into(2usize); assert!(expr.is_monomial_form()); // 2 product term - let expr: Expression = Into::>::into(3usize) * x.expr() * y.expr() - + Into::>::into(2usize) * x.expr(); + let expr: Expression = + Into::>::into(3usize) * x.expr_fnord() * y.expr_fnord() + + Into::>::into(2usize) * x.expr_fnord(); assert!(expr.is_monomial_form()); // complex linear operation // (2c + 3) * x * y - 6z let expr: Expression = - Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr() * y.expr() - - Into::>::into(6usize) * z.expr(); + Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr_fnord() * y.expr_fnord() + - Into::>::into(6usize) * z.expr_fnord(); assert!(expr.is_monomial_form()); // complex linear operation // (2c + 3) * x * y - 6z let expr: Expression = - Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr() * y.expr() - - Into::>::into(6usize) * z.expr(); + Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr_fnord() * y.expr_fnord() + - Into::>::into(6usize) * z.expr_fnord(); assert!(expr.is_monomial_form()); // complex linear operation // (2 * x + 3) * 3 + 6 * 8 - let expr: Expression = (Into::>::into(2usize) * x.expr() + let expr: Expression = (Into::>::into(2usize) * x.expr_fnord() + Into::>::into(3usize)) * Into::>::into(3usize) + Into::>::into(6usize) * Into::>::into(8usize); @@ -1002,8 +1017,8 @@ mod tests { let y = cb.create_witin(|| "y").unwrap(); // scaledsum * challenge // (x + 1) * (y + 1) - let expr: Expression = (Into::>::into(1usize) + x.expr()) - * (Into::>::into(2usize) + y.expr()); + let expr: Expression = (Into::>::into(1usize) + x.expr_fnord()) + * (Into::>::into(2usize) + y.expr_fnord()); assert!(!expr.is_monomial_form()); } diff --git a/ceno_zkvm/src/gadgets/is_lt.rs b/ceno_zkvm/src/gadgets/is_lt.rs index e35ea6b7a..84396602e 100644 --- a/ceno_zkvm/src/gadgets/is_lt.rs +++ b/ceno_zkvm/src/gadgets/is_lt.rs @@ -68,7 +68,7 @@ pub struct IsLtConfig { impl IsLtConfig { pub fn expr(&self) -> Expression { - self.is_lt.expr() + self.is_lt.expr_fnord() } pub fn construct_circuit< @@ -87,14 +87,14 @@ impl IsLtConfig { |cb| { let name = name_fn(); let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?; - cb.assert_bit(|| "is_lt_bit", is_lt.expr())?; + cb.assert_bit(|| "is_lt_bit", is_lt.expr_fnord())?; let config = InnerLtConfig::construct_circuit( cb, name, lhs, rhs, - is_lt.expr(), + is_lt.expr_fnord(), max_num_u16_limbs, )?; Ok(Self { is_lt, config }) @@ -142,7 +142,7 @@ impl InnerLtConfig { || format!("var {var_name}"), |cb| { let witin = cb.create_witin(|| var_name.to_string())?; - cb.assert_ux::<_, _, 16>(|| name.clone(), witin.expr())?; + cb.assert_ux::<_, _, 16>(|| name.clone(), witin.expr_fnord())?; Ok(witin) }, ) @@ -155,7 +155,7 @@ impl InnerLtConfig { let pows = power_sequence((1 << u16::BITS).into()); let diff_expr = izip!(&diff, pows) - .map(|(record, beta)| beta * record.expr()) + .map(|(record, beta)| beta * record.expr_fnord()) .sum::>(); let range = Self::range(max_num_u16_limbs); @@ -264,7 +264,7 @@ pub struct SignedLtConfig { impl SignedLtConfig { pub fn expr(&self) -> Expression { - self.is_lt.expr() + self.is_lt.expr_fnord() } pub fn construct_circuit< @@ -282,9 +282,9 @@ impl SignedLtConfig { |cb| { let name = name_fn(); let is_lt = cb.create_witin(|| format!("{name} is_signed_lt witin"))?; - cb.assert_bit(|| "is_lt_bit", is_lt.expr())?; + cb.assert_bit(|| "is_lt_bit", is_lt.expr_fnord())?; let config = - InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, is_lt.expr())?; + InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, is_lt.expr_fnord())?; Ok(SignedLtConfig { is_lt, config }) }, @@ -326,14 +326,14 @@ impl InnerSignedLtConfig { cb, || "lhs_msb", max_signed_limb_expr.clone(), - lhs.limbs.iter().last().unwrap().expr(), // msb limb + lhs.limbs.iter().last().unwrap().expr_fnord(), // msb limb 1, )?; let is_rhs_neg = IsLtConfig::construct_circuit( cb, || "rhs_msb", max_signed_limb_expr, - rhs.limbs.iter().last().unwrap().expr(), // msb limb + rhs.limbs.iter().last().unwrap().expr_fnord(), // msb limb 1, )?; diff --git a/ceno_zkvm/src/gadgets/is_zero.rs b/ceno_zkvm/src/gadgets/is_zero.rs index 02994e4f7..fde48edfb 100644 --- a/ceno_zkvm/src/gadgets/is_zero.rs +++ b/ceno_zkvm/src/gadgets/is_zero.rs @@ -17,7 +17,7 @@ pub struct IsZeroConfig { impl IsZeroConfig { pub fn expr(&self) -> Expression { - self.is_zero.expr() + self.is_zero.expr_fnord() } pub fn construct_circuit, N: FnOnce() -> NR>( @@ -30,10 +30,13 @@ impl IsZeroConfig { let inverse = cb.create_witin(|| "inv")?; // x==0 => is_zero=1 - cb.require_one(|| "is_zero_1", is_zero.expr() + x.clone() * inverse.expr())?; + cb.require_one( + || "is_zero_1", + is_zero.expr_fnord() + x.clone() * inverse.expr_fnord(), + )?; // x!=0 => is_zero=0 - cb.require_zero(|| "is_zero_0", is_zero.expr() * x.clone())?; + cb.require_zero(|| "is_zero_0", is_zero.expr_fnord() * x.clone())?; Ok(IsZeroConfig { is_zero, inverse }) }) diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index b7c74543f..ee7dcdf1a 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -61,23 +61,24 @@ impl BInstructionConfig { // Fetch instruction circuit_builder.lk_fetch(&InsnRecord::new( - vm_state.pc.expr(), + vm_state.pc.expr_fnord(), insn_kind.codes().opcode.into(), 0.into(), insn_kind.codes().func3.into(), - rs1.id.expr(), - rs2.id.expr(), - imm.expr(), + rs1.id.expr_fnord(), + rs2.id.expr_fnord(), + imm.expr_fnord(), ))?; // Branch program counter - let pc_offset = - branch_taken_bit.clone() * imm.expr() - branch_taken_bit * PC_STEP_SIZE + PC_STEP_SIZE; + let pc_offset = branch_taken_bit.clone() * imm.expr_fnord() + - branch_taken_bit * PC_STEP_SIZE + + PC_STEP_SIZE; let next_pc = vm_state.next_pc.unwrap(); circuit_builder.require_equal( || "pc_branch", - next_pc.expr(), - vm_state.pc.expr() + pc_offset, + next_pc.expr_fnord(), + vm_state.pc.expr_fnord() + pc_offset, )?; Ok(BInstructionConfig { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index 21ba4de0c..2878e0493 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -37,7 +37,7 @@ impl Instruction for HaltInstruction { let prev_x10_ts = cb.create_witin(|| "prev_x10_ts")?; let exit_code = { let exit_code = cb.query_exit_code()?; - [exit_code[0].expr(), exit_code[1].expr()] + [exit_code[0].expr_fnord(), exit_code[1].expr_fnord()] }; let ecall_cfg = EcallInstructionConfig::construct_circuit( @@ -51,8 +51,8 @@ impl Instruction for HaltInstruction { let (_, lt_x10_cfg) = cb.register_read( || "read x10", E::BaseField::from(ceno_emul::CENO_PLATFORM.reg_arg0() as u64), - prev_x10_ts.expr(), - ecall_cfg.ts.expr() + Tracer::SUBCYCLE_RS2, + prev_x10_ts.expr_fnord(), + ecall_cfg.ts.expr_fnord() + Tracer::SUBCYCLE_RS2, exit_code, )?; diff --git a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs index 49bc1d67a..70eb038b4 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs @@ -31,14 +31,14 @@ impl EcallInstructionConfig { let pc = cb.create_witin(|| "pc")?; let ts = cb.create_witin(|| "cur_ts")?; - cb.state_in(pc.expr(), ts.expr())?; + cb.state_in(pc.expr_fnord(), ts.expr_fnord())?; cb.state_out( - next_pc.map_or(pc.expr() + PC_STEP_SIZE, |next_pc| next_pc), - ts.expr() + (Tracer::SUBCYCLES_PER_INSN as usize), + next_pc.map_or(pc.expr_fnord() + PC_STEP_SIZE, |next_pc| next_pc), + ts.expr_fnord() + (Tracer::SUBCYCLES_PER_INSN as usize), )?; cb.lk_fetch(&InsnRecord::new( - pc.expr(), + pc.expr_fnord(), (EANY.codes().opcode as usize).into(), 0.into(), (EANY.codes().func3 as usize).into(), @@ -53,8 +53,8 @@ impl EcallInstructionConfig { let (_, lt_x5_cfg) = cb.register_write( || "write x5", E::BaseField::from(CENO_PLATFORM.reg_ecall() as u64), - prev_x5_ts.expr(), - ts.expr() + Tracer::SUBCYCLE_RS1, + prev_x5_ts.expr_fnord(), + ts.expr_fnord() + Tracer::SUBCYCLE_RS1, syscall_id.clone(), syscall_ret_value.map_or(syscall_id, |v| v), )?; diff --git a/ceno_zkvm/src/instructions/riscv/i_insn.rs b/ceno_zkvm/src/instructions/riscv/i_insn.rs index 4a099477b..ffc78d4e4 100644 --- a/ceno_zkvm/src/instructions/riscv/i_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/i_insn.rs @@ -44,11 +44,11 @@ impl IInstructionConfig { // Fetch the instruction. circuit_builder.lk_fetch(&InsnRecord::new( - vm_state.pc.expr(), + vm_state.pc.expr_fnord(), insn_kind.codes().opcode.into(), - rd.id.expr(), + rd.id.expr_fnord(), insn_kind.codes().func3.into(), - rs1.id.expr(), + rs1.id.expr_fnord(), 0.into(), imm.clone(), ))?; diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index 0977dfa55..debbeb459 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -44,11 +44,11 @@ impl IMInstructionConfig { // Fetch the instruction circuit_builder.lk_fetch(&InsnRecord::new( - vm_state.pc.expr(), + vm_state.pc.expr_fnord(), (insn_kind.codes().opcode as usize).into(), - rd.id.expr(), + rd.id.expr_fnord(), (insn_kind.codes().func3 as usize).into(), - rs1.id.expr(), + rs1.id.expr_fnord(), 0.into(), imm.clone(), ))?; diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 5548786f5..3b8dfff24 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -40,13 +40,13 @@ impl StateInOut { let pc = circuit_builder.create_witin(|| "pc")?; let (next_pc_opt, next_pc_expr) = if branching { let next_pc = circuit_builder.create_witin(|| "next_pc")?; - (Some(next_pc), next_pc.expr()) + (Some(next_pc), next_pc.expr_fnord()) } else { - (None, pc.expr() + PC_STEP_SIZE) + (None, pc.expr_fnord() + PC_STEP_SIZE) }; let ts = circuit_builder.create_witin(|| "ts")?; - let next_ts = ts.expr() + Tracer::SUBCYCLES_PER_INSN; - circuit_builder.state_in(pc.expr(), ts.expr())?; + let next_ts = ts.expr_fnord() + Tracer::SUBCYCLES_PER_INSN; + circuit_builder.state_in(pc.expr_fnord(), ts.expr_fnord())?; circuit_builder.state_out(next_pc_expr, next_ts)?; Ok(StateInOut { @@ -92,8 +92,8 @@ impl ReadRS1 { let (_, lt_cfg) = circuit_builder.register_read( || "read_rs1", id, - prev_ts.expr(), - cur_ts.expr() + Tracer::SUBCYCLE_RS1, + prev_ts.expr_fnord(), + cur_ts.expr_fnord() + Tracer::SUBCYCLE_RS1, rs1_read, )?; @@ -147,8 +147,8 @@ impl ReadRS2 { let (_, lt_cfg) = circuit_builder.register_read( || "read_rs2", id, - prev_ts.expr(), - cur_ts.expr() + Tracer::SUBCYCLE_RS2, + prev_ts.expr_fnord(), + cur_ts.expr_fnord() + Tracer::SUBCYCLE_RS2, rs2_read, )?; @@ -203,8 +203,8 @@ impl WriteRD { let (_, lt_cfg) = circuit_builder.register_write( || "write_rd", id, - prev_ts.expr(), - cur_ts.expr() + Tracer::SUBCYCLE_RD, + prev_ts.expr_fnord(), + cur_ts.expr_fnord() + Tracer::SUBCYCLE_RD, prev_value.register_expr(), rd_written, )?; @@ -262,8 +262,8 @@ impl ReadMEM { let (_, lt_cfg) = circuit_builder.memory_read( || "read_memory", &mem_addr, - prev_ts.expr(), - cur_ts.expr() + Tracer::SUBCYCLE_MEM, + prev_ts.expr_fnord(), + cur_ts.expr_fnord() + Tracer::SUBCYCLE_MEM, mem_read, )?; @@ -318,8 +318,8 @@ impl WriteMEM { let (_, lt_cfg) = circuit_builder.memory_write( || "write_memory", &mem_addr, - prev_ts.expr(), - cur_ts.expr() + Tracer::SUBCYCLE_MEM, + prev_ts.expr_fnord(), + cur_ts.expr_fnord() + Tracer::SUBCYCLE_MEM, prev_value, new_value, )?; @@ -393,7 +393,7 @@ impl MemAddr { /// Expressions of the low bits of the address, LSB-first: [bit_0, bit_1]. pub fn low_bit_exprs(&self) -> Vec> { iter::repeat_n(Expression::ZERO, self.n_zeros()) - .chain(self.low_bits.iter().map(ToExpr::expr)) + .chain(self.low_bits.iter().map(ToExpr::expr_fnord)) .collect() } @@ -403,13 +403,13 @@ impl MemAddr { // The address as two u16 limbs. // Soundness: This does not use the UInt range-check but specialized checks instead. let addr = UInt::new_unchecked(|| "memory_addr", cb)?; - let limbs = addr.expr(); + let limbs = addr.expr_fnord(); // Witness and constrain the non-zero low bits. let low_bits = (n_zeros..Self::N_LOW_BITS) .map(|i| { let bit = cb.create_witin(|| format!("addr_bit_{}", i))?; - cb.assert_bit(|| format!("addr_bit_{}", i), bit.expr())?; + cb.assert_bit(|| format!("addr_bit_{}", i), bit.expr_fnord())?; Ok(bit) }) .collect::, ZKVMError>>()?; @@ -417,14 +417,14 @@ impl MemAddr { // Express the value of the low bits. let low_sum: Expression = (n_zeros..Self::N_LOW_BITS) .zip_eq(low_bits.iter()) - .map(|(pos, bit)| bit.expr() * (1 << pos)) + .map(|(pos, bit)| bit.expr_fnord() * (1 << pos)) .sum(); // Range check the middle bits, that is the low limb excluding the low bits. let shift_right = E::BaseField::from(1 << Self::N_LOW_BITS) .invert() .unwrap() - .expr(); + .expr_fnord(); let mid_u14 = (limbs[0].clone() - low_sum) * shift_right; cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?; diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs index 1ffec5b99..5c672b131 100644 --- a/ceno_zkvm/src/instructions/riscv/j_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -40,13 +40,13 @@ impl JInstructionConfig { // Fetch instruction circuit_builder.lk_fetch(&InsnRecord::new( - vm_state.pc.expr(), + vm_state.pc.expr_fnord(), (insn_kind.codes().opcode as usize).into(), - rd.id.expr(), + rd.id.expr_fnord(), 0.into(), 0.into(), 0.into(), - vm_state.next_pc.unwrap().expr() - vm_state.pc.expr(), + vm_state.next_pc.unwrap().expr_fnord() - vm_state.pc.expr_fnord(), ))?; Ok(JInstructionConfig { vm_state, rd }) diff --git a/ceno_zkvm/src/instructions/riscv/jump/auipc.rs b/ceno_zkvm/src/instructions/riscv/jump/auipc.rs index 6c979ec25..ea6a57f75 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/auipc.rs @@ -42,12 +42,12 @@ impl Instruction for AuipcInstruction { let u_insn = UInstructionConfig::construct_circuit( circuit_builder, InsnKind::AUIPC, - &imm.expr(), + &imm.expr_fnord(), rd_written.register_expr(), )?; let overflow_bit = circuit_builder.create_witin(|| "overflow_bit")?; - circuit_builder.assert_bit(|| "is_bit", overflow_bit.expr())?; + circuit_builder.assert_bit(|| "is_bit", overflow_bit.expr_fnord())?; // assert: imm + pc = rd_written + overflow_bit * 2^32 // valid formulation of mod 2^32 arithmetic because: @@ -55,8 +55,8 @@ impl Instruction for AuipcInstruction { // - rd_written is constrained to 4 bytes by UInt checked limbs circuit_builder.require_equal( || "imm+pc = rd_written+2^32*overflow", - imm.expr() + u_insn.vm_state.pc.expr(), - rd_written.value() + overflow_bit.expr() * (1u64 << 32), + imm.expr_fnord() + u_insn.vm_state.pc.expr_fnord(), + rd_written.value() + overflow_bit.expr_fnord() * (1u64 << 32), )?; Ok(AuipcConfig { diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal.rs b/ceno_zkvm/src/instructions/riscv/jump/jal.rs index 44facf944..c2feef9e0 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal.rs @@ -54,7 +54,7 @@ impl Instruction for JalInstruction { circuit_builder.require_equal( || "jal rd_written", rd_written.value(), - j_insn.vm_state.pc.expr() + PC_STEP_SIZE, + j_insn.vm_state.pc.expr_fnord() + PC_STEP_SIZE, )?; Ok(JalConfig { j_insn, rd_written }) diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index 889f3eca8..cf8efb5ef 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -50,7 +50,7 @@ impl Instruction for JalrInstruction { let i_insn = IInstructionConfig::construct_circuit( circuit_builder, InsnKind::JALR, - &imm.expr(), + &imm.expr_fnord(), rs1_read.register_expr(), rd_written.register_expr(), true, @@ -67,26 +67,26 @@ impl Instruction for JalrInstruction { circuit_builder.require_equal( || "rs1+imm = next_pc_unrounded + overflow*2^32", - rs1_read.value() + imm.expr(), - next_pc_addr.expr_unaligned() + overflow.expr() * (1u64 << 32), + rs1_read.value() + imm.expr_fnord(), + next_pc_addr.expr_unaligned() + overflow.expr_fnord() * (1u64 << 32), )?; circuit_builder.require_zero( || "overflow_0_or_pm1", - overflow.expr() * (overflow.expr() - 1) * (overflow.expr() + 1), + overflow.expr_fnord() * (overflow.expr_fnord() - 1) * (overflow.expr_fnord() + 1), )?; circuit_builder.require_equal( || "next_pc_addr = next_pc", next_pc_addr.expr_align2(), - i_insn.vm_state.next_pc.unwrap().expr(), + i_insn.vm_state.next_pc.unwrap().expr_fnord(), )?; // write pc+4 to rd circuit_builder.require_equal( || "rd_written = pc+4", rd_written.value(), - i_insn.vm_state.pc.expr() + PC_STEP_SIZE, + i_insn.vm_state.pc.expr_fnord() + PC_STEP_SIZE, )?; Ok(JalrConfig { diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index 0980253db..603d2a60f 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -34,7 +34,7 @@ impl MemWordChange { (0..num_bytes) .map(|i| { let byte = cb.create_witin(|| format!("{}.le_bytes[{}]", anno, i))?; - cb.assert_ux::<_, _, 8>(|| "byte range check", byte.expr())?; + cb.assert_ux::<_, _, 8>(|| "byte range check", byte.expr_fnord())?; Ok(byte) }) @@ -54,7 +54,7 @@ impl MemWordChange { bytes .iter() .enumerate() - .map(|(idx, byte)| (1 << (idx * 8)) * byte.expr()) + .map(|(idx, byte)| (1 << (idx * 8)) * byte.expr_fnord()) .sum(), )?; @@ -68,8 +68,8 @@ impl MemWordChange { assert!(prev_word.wits_in().is_some() && rs2_word.wits_in().is_some()); let low_bits = addr.low_bit_exprs(); - let prev_limbs = prev_word.expr(); - let rs2_limbs = rs2_word.expr(); + let prev_limbs = prev_word.expr_fnord(); + let rs2_limbs = rs2_word.expr_fnord(); // degree 2 expression let prev_target_limb = cb.select(&low_bits[1], &prev_limbs[1], &prev_limbs[0]); @@ -80,7 +80,8 @@ impl MemWordChange { let u8_base_inv = E::BaseField::from(1 << 8).invert().unwrap(); cb.assert_ux::<_, _, 8>( || "rs2_limb[0].le_bytes[1]", - u8_base_inv.expr() * (rs2_limbs[0].clone() - rs2_limb_bytes[0].expr()), + u8_base_inv.expr_fnord() + * (rs2_limbs[0].clone() - rs2_limb_bytes[0].expr_fnord()), )?; // alloc a new witIn to cache degree 2 expression @@ -88,9 +89,9 @@ impl MemWordChange { cb.condition_require_equal( || "expected_limb_change = select(low_bits[0], rs2 - prev)", low_bits[0].clone(), - expected_limb_change.expr(), - (1 << 8) * (rs2_limb_bytes[0].expr() - prev_limb_bytes[1].expr()), - rs2_limb_bytes[0].expr() - prev_limb_bytes[0].expr(), + expected_limb_change.expr_fnord(), + (1 << 8) * (rs2_limb_bytes[0].expr_fnord() - prev_limb_bytes[1].expr_fnord()), + rs2_limb_bytes[0].expr_fnord() - prev_limb_bytes[0].expr_fnord(), )?; // alloc a new witIn to cache degree 2 expression @@ -98,9 +99,9 @@ impl MemWordChange { cb.condition_require_equal( || "expected_change = select(low_bits[1], limb_change*2^16, limb_change)", low_bits[1].clone(), - expected_change.expr(), - (1 << 16) * expected_limb_change.expr(), - expected_limb_change.expr(), + expected_change.expr_fnord(), + (1 << 16) * expected_limb_change.expr_fnord(), + expected_limb_change.expr_fnord(), )?; Ok(MemWordChange { @@ -114,8 +115,8 @@ impl MemWordChange { assert!(prev_word.wits_in().is_some() && rs2_word.wits_in().is_some()); let low_bits = addr.low_bit_exprs(); - let prev_limbs = prev_word.expr(); - let rs2_limbs = rs2_word.expr(); + let prev_limbs = prev_word.expr_fnord(); + let rs2_limbs = rs2_word.expr_fnord(); let expected_change = cb.create_witin(|| "expected_change")?; @@ -124,7 +125,7 @@ impl MemWordChange { || "expected_change = select(low_bits[1], 2^16*(limb_change))", // degree 2 expression low_bits[1].clone(), - expected_change.expr(), + expected_change.expr_fnord(), (1 << 16) * (rs2_limbs[0].clone() - prev_limbs[1].clone()), rs2_limbs[0].clone() - prev_limbs[0].clone(), )?; @@ -142,7 +143,7 @@ impl MemWordChange { pub(crate) fn value(&self) -> Expression { assert!(N_ZEROS <= 1); - self.expected_changes[1 - N_ZEROS].expr() + self.expected_changes[1 - N_ZEROS].expr_fnord() } pub fn assign_instance( diff --git a/ceno_zkvm/src/instructions/riscv/memory/store.rs b/ceno_zkvm/src/instructions/riscv/memory/store.rs index fc8f0455f..71ebc5a88 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store.rs @@ -97,7 +97,7 @@ impl Instruction circuit_builder.require_equal( || "memory_addr = rs1_read + imm", memory_addr.expr_unaligned(), - rs1_read.value() + imm.expr(), + rs1_read.value() + imm.expr_fnord(), )?; let (new_memory_value, word_change) = match I::INST_KIND { @@ -117,7 +117,7 @@ impl Instruction let s_insn = SInstructionConfig::::construct_circuit( circuit_builder, I::INST_KIND, - &imm.expr(), + &imm.expr_fnord(), rs1_read.register_expr(), rs2_read.register_expr(), memory_addr.expr_align4(), diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index c5a19cbac..853fbe1f4 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -43,12 +43,12 @@ impl RInstructionConfig { // Fetch instruction circuit_builder.lk_fetch(&InsnRecord::new( - vm_state.pc.expr(), + vm_state.pc.expr_fnord(), insn_kind.codes().opcode.into(), - rd.id.expr(), + rd.id.expr_fnord(), insn_kind.codes().func3.into(), - rs1.id.expr(), - rs2.id.expr(), + rs1.id.expr_fnord(), + rs2.id.expr_fnord(), insn_kind.codes().func7.into(), ))?; diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index 702a3d5ea..e6723e6ca 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -43,12 +43,12 @@ impl SInstructionConfig { // Fetch instruction circuit_builder.lk_fetch(&InsnRecord::new( - vm_state.pc.expr(), + vm_state.pc.expr_fnord(), (insn_kind.codes().opcode as usize).into(), 0.into(), (insn_kind.codes().func3 as usize).into(), - rs1.id.expr(), - rs2.id.expr(), + rs1.id.expr_fnord(), + rs2.id.expr_fnord(), imm.clone(), ))?; diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 189811dd2..2d056b7de 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -99,12 +99,12 @@ impl Instruction for ShiftLogicalInstru rd_written.register_expr(), )?; - circuit_builder.lookup_pow2(rs2_low5.expr(), pow2_rs2_low5.value())?; - circuit_builder.assert_ux::<_, _, 5>(|| "rs2_low5 in u5", rs2_low5.expr())?; + circuit_builder.lookup_pow2(rs2_low5.expr_fnord(), pow2_rs2_low5.value())?; + circuit_builder.assert_ux::<_, _, 5>(|| "rs2_low5 in u5", rs2_low5.expr_fnord())?; circuit_builder.require_equal( || "rs2 == rs2_high * 2^5 + rs2_low5", rs2_read.value(), - rs2_high.value() * (1 << 5) + rs2_low5.expr(), + rs2_high.value() * (1 << 5) + rs2_low5.expr_fnord(), )?; Ok(ShiftConfig { diff --git a/ceno_zkvm/src/instructions/riscv/u_insn.rs b/ceno_zkvm/src/instructions/riscv/u_insn.rs index 1c6d7040e..a2c22cfae 100644 --- a/ceno_zkvm/src/instructions/riscv/u_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/u_insn.rs @@ -38,9 +38,9 @@ impl UInstructionConfig { // Fetch instruction circuit_builder.lk_fetch(&InsnRecord::new( - vm_state.pc.expr(), + vm_state.pc.expr_fnord(), (insn_kind.codes().opcode as usize).into(), - rd.id.expr(), + rd.id.expr_fnord(), 0.into(), 0.into(), 0.into(), diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 6c325573a..c26d1b51d 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -717,14 +717,14 @@ mod tests { let c = cb.create_witin(|| "c")?; // degree 1 - cb.require_equal(|| "a + 1 == b", b.expr(), a.expr() + 1)?; - cb.require_zero(|| "c - 2 == 0", c.expr() - 2)?; + cb.require_equal(|| "a + 1 == b", b.expr_fnord(), a.expr_fnord() + 1)?; + cb.require_zero(|| "c - 2 == 0", c.expr_fnord() - 2)?; // degree > 1 let d = cb.create_witin(|| "d")?; cb.require_zero( || "d*d - 6*d + 9 == 0", - d.expr() * d.expr() - d.expr() * 6 + 9, + d.expr_fnord() * d.expr_fnord() - d.expr_fnord() * 6 + 9, )?; Ok(Self { a, b, c }) @@ -767,7 +767,7 @@ mod tests { cb: &mut CircuitBuilder, ) -> Result { let a = cb.create_witin(|| "a")?; - cb.assert_ux::<_, _, 5>(|| "assert u5", a.expr())?; + cb.assert_ux::<_, _, 5>(|| "assert u5", a.expr_fnord())?; Ok(Self { a }) } } @@ -849,7 +849,8 @@ mod tests { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let a = cb.create_witin(|| "a")?; let b = cb.create_witin(|| "b")?; - let lt_wtns = AssertLTConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; + let lt_wtns = + AssertLTConfig::construct_circuit(cb, || "lt", a.expr_fnord(), b.expr_fnord(), 1)?; Ok(Self { a, b, lt_wtns }) } @@ -973,7 +974,8 @@ mod tests { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let a = cb.create_witin(|| "a")?; let b = cb.create_witin(|| "b")?; - let lt_wtns = IsLtConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; + let lt_wtns = + IsLtConfig::construct_circuit(cb, || "lt", a.expr_fnord(), b.expr_fnord(), 1)?; Ok(Self { a, b, lt_wtns }) } diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 80c43af0b..f4e998681 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -54,14 +54,14 @@ impl Instruction for Test (0..RW).try_for_each(|_| { let record = cb.rlc_chip_record(vec![ Expression::::Constant(E::BaseField::ONE), - reg_id.expr(), + reg_id.expr_fnord(), ]); cb.read_record(|| "read", record.clone())?; cb.write_record(|| "write", record)?; Result::<(), ZKVMError>::Ok(()) })?; (0..L).try_for_each(|_| { - cb.assert_ux::<_, _, 16>(|| "regid_in_range", reg_id.expr())?; + cb.assert_ux::<_, _, 16>(|| "regid_in_range", reg_id.expr_fnord())?; Result::<(), ZKVMError>::Ok(()) })?; assert_eq!(cb.cs.lk_expressions.len(), L); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index e09e70574..9058e69e8 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -677,7 +677,10 @@ mod tests { let b = cb.create_witin(|| "b").unwrap(); let c = cb.create_witin(|| "c").unwrap(); - let expr: Expression = a.expr() + b.expr() + a.expr() * b.expr() + (c.expr() * 3 + 2); + let expr: Expression = a.expr_fnord() + + b.expr_fnord() + + a.expr_fnord() * b.expr_fnord() + + (c.expr_fnord() * 3 + 2); let res = wit_infer_by_expr( &[], @@ -703,10 +706,10 @@ mod tests { let b = cb.create_witin(|| "b").unwrap(); let c = cb.create_witin(|| "c").unwrap(); - let expr: Expression = a.expr() - + b.expr() - + a.expr() * b.expr() - + (c.expr() * 3 + 2) + let expr: Expression = a.expr_fnord() + + b.expr_fnord() + + a.expr_fnord() * b.expr_fnord() + + (c.expr_fnord() * 3 + 2) + Expression::Challenge(0, 1, E::ONE, E::ONE); let res = wit_infer_by_expr( diff --git a/ceno_zkvm/src/state.rs b/ceno_zkvm/src/state.rs index 875e8fbfb..d3e173bb2 100644 --- a/ceno_zkvm/src/state.rs +++ b/ceno_zkvm/src/state.rs @@ -24,8 +24,8 @@ impl StateCircuit for GlobalState { ) -> Result, ZKVMError> { let states: Vec> = vec![ Expression::Constant(E::BaseField::from(RAMType::GlobalState as u64)), - circuit_builder.query_init_pc()?.expr(), - circuit_builder.query_init_cycle()?.expr(), + circuit_builder.query_init_pc()?.expr_fnord(), + circuit_builder.query_init_cycle()?.expr_fnord(), ]; Ok(circuit_builder.rlc_chip_record(states)) @@ -36,8 +36,8 @@ impl StateCircuit for GlobalState { ) -> Result, crate::error::ZKVMError> { let states: Vec> = vec![ Expression::Constant(E::BaseField::from(RAMType::GlobalState as u64)), - circuit_builder.query_end_pc()?.expr(), - circuit_builder.query_end_cycle()?.expr(), + circuit_builder.query_end_pc()?.expr_fnord(), + circuit_builder.query_end_cycle()?.expr_fnord(), ]; Ok(circuit_builder.rlc_chip_record(states)) diff --git a/ceno_zkvm/src/tables/ops/ops_impl.rs b/ceno_zkvm/src/tables/ops/ops_impl.rs index f9bf4d0c8..b1b127577 100644 --- a/ceno_zkvm/src/tables/ops/ops_impl.rs +++ b/ceno_zkvm/src/tables/ops/ops_impl.rs @@ -41,7 +41,7 @@ impl OpTableConfig { Expression::Fixed(abc[2]), ]); - cb.lk_table_record(|| "record", table_len, rlc_record, mlt.expr())?; + cb.lk_table_record(|| "record", table_len, rlc_record, mlt.expr_fnord())?; Ok(Self { abc, mlt }) } diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 3514365c8..a78800604 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -138,12 +138,17 @@ impl TableCircuit let mlt = cb.create_witin(|| "mlt")?; let record_exprs = { - let mut fields = vec![E::BaseField::from(ROMType::Instruction as u64).expr()]; + let mut fields = vec![E::BaseField::from(ROMType::Instruction as u64).expr_fnord()]; fields.extend(record.as_slice().iter().map(|f| Expression::Fixed(*f))); cb.rlc_chip_record(fields) }; - cb.lk_table_record(|| "prog table", PROGRAM_SIZE, record_exprs, mlt.expr())?; + cb.lk_table_record( + || "prog table", + PROGRAM_SIZE, + record_exprs, + mlt.expr_fnord(), + )?; Ok(ProgramTableConfig { record, mlt }) } diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index a4d263123..703bfa391 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -49,7 +49,7 @@ impl RamTableConfig { [ vec![(RAM::RAM_TYPE as usize).into()], vec![Expression::Fixed(addr)], - init_v.iter().map(|v| v.expr()).collect_vec(), + init_v.iter().map(|v| v.expr_fnord()).collect_vec(), vec![Expression::ZERO], // Initial cycle. ] .concat(), @@ -60,8 +60,8 @@ impl RamTableConfig { // a v t vec![(RAM::RAM_TYPE as usize).into()], vec![Expression::Fixed(addr)], - final_v.iter().map(|v| v.expr()).collect_vec(), - vec![final_cycle.expr()], + final_v.iter().map(|v| v.expr_fnord()).collect_vec(), + vec![final_cycle.expr_fnord()], ] .concat(), ); diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs index 8a14fe236..27dab43c2 100644 --- a/ceno_zkvm/src/tables/range/range_impl.rs +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -33,7 +33,7 @@ impl RangeTableConfig { let rlc_record = cb.rlc_chip_record(vec![(rom_type as usize).into(), Expression::Fixed(fixed)]); - cb.lk_table_record(|| "record", table_len, rlc_record, mlt.expr())?; + cb.lk_table_record(|| "record", table_len, rlc_record, mlt.expr_fnord())?; Ok(Self { fixed, mlt }) } diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index fef0c80bc..0735baf98 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -92,7 +92,10 @@ impl UIntLimbs { .map(|i| { let w = cb.create_witin(|| format!("limb_{i}"))?; if is_check { - cb.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), w.expr())?; + cb.assert_ux::<_, _, C>( + || format!("limb_{i}_in_{C}"), + w.expr_fnord(), + )?; } // skip range check Ok(w) @@ -164,12 +167,12 @@ impl UIntLimbs { .map(|i| { let w = circuit_builder.create_witin(|| "wit for limb").unwrap(); circuit_builder - .assert_ux::<_, _, C>(|| "range check", w.expr()) + .assert_ux::<_, _, C>(|| "range check", w.expr_fnord()) .unwrap(); circuit_builder .require_zero( || "create_witin_from_expr", - w.expr() - expr_limbs[i].clone(), + w.expr_fnord() - expr_limbs[i].clone(), ) .unwrap(); w @@ -297,7 +300,7 @@ impl UIntLimbs { chunk .iter() .zip(shift_pows.iter()) - .map(|(limb, shift)| shift.clone() * limb.expr()) + .map(|(limb, shift)| shift.clone() * limb.expr_fnord()) .reduce(|a, b| a + b) .unwrap() }) @@ -325,8 +328,8 @@ impl UIntLimbs { let limbs = (0..k) .map(|_| { let w = circuit_builder.create_witin(|| "").unwrap(); - circuit_builder.assert_byte(|| "", w.expr()).unwrap(); - w.expr() + circuit_builder.assert_byte(|| "", w.expr_fnord()).unwrap(); + w.expr_fnord() }) .collect_vec(); let combined_limb = limbs @@ -337,7 +340,7 @@ impl UIntLimbs { .unwrap(); circuit_builder - .require_zero(|| "zero check", large_limb.expr() - combined_limb) + .require_zero(|| "zero check", large_limb.expr_fnord() - combined_limb) .unwrap(); limbs }) @@ -372,7 +375,7 @@ impl UIntLimbs { (0..Self::NUM_LIMBS) .map(|i| { let w = cb.create_witin(|| format!("limb_{i}"))?; - cb.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), w.expr())?; + cb.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), w.expr_fnord())?; Ok(w) }) .collect::, ZKVMError>>()?, @@ -508,7 +511,7 @@ impl UIntLimbs { /// Get an Expression from the limbs, unsafe if Uint value exceeds field limit pub fn value(&self) -> Expression { let base = Expression::from(1 << C); - self.expr() + self.expr_fnord() .into_iter() .rev() .reduce(|sum, limb| sum * base.clone() + limb) @@ -520,7 +523,7 @@ impl UIntLimbs { &self, ) -> Result<(UIntLimbs, UIntLimbs), ZKVMError> { assert!(M == 2 * M2); - let mut self_lo = self.expr(); + let mut self_lo = self.expr_fnord(); let self_hi = self_lo.split_off(self_lo.len() / 2); Ok(( UIntLimbs::from_exprs_unchecked(self_lo)?, @@ -563,11 +566,11 @@ impl TryFrom<&[WitIn]> for UI impl ToExpr for UIntLimbs { type Output = Vec>; - fn expr(&self) -> Vec> { + fn expr_fnord(&self) -> Vec> { match &self.limbs { UintLimb::WitIn(limbs) => limbs .iter() - .map(ToExpr::expr) + .map(ToExpr::expr_fnord) .collect::>>(), UintLimb::Expression(e) => e.clone(), } @@ -577,7 +580,7 @@ impl ToExpr for UIntLimbs< impl UIntLimbs<32, 16, E> { /// Return a value suitable for register read/write. From [u16; 2] limbs. pub fn register_expr(&self) -> RegisterExpr { - let u16_limbs = self.expr(); + let u16_limbs = self.expr_fnord(); u16_limbs.try_into().expect("two limbs with M=32 and C=16") } @@ -595,7 +598,7 @@ impl UIntLimbs<32, 16, E> { impl UIntLimbs<32, 8, E> { /// Return a value suitable for register read/write. From [u8; 4] limbs. pub fn register_expr(&self) -> RegisterExpr { - let u8_limbs = self.expr(); + let u8_limbs = self.expr_fnord(); let u16_limbs = u8_limbs .chunks(2) .map(|chunk| { diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 3ce3b65f6..d1d4c283f 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -35,13 +35,13 @@ impl UIntLimbs { return Err(ZKVMError::CircuitError); }; carries.iter().enumerate().try_for_each(|(i, carry)| { - circuit_builder.assert_bit(|| format!("carry_{i}_in_as_bit"), carry.expr()) + circuit_builder.assert_bit(|| format!("carry_{i}_in_as_bit"), carry.expr_fnord()) })?; // perform add operation // c[i] = a[i] + b[i] + carry[i-1] - carry[i] * 2 ^ C c.limbs = UintLimb::Expression( - (self.expr()) + (self.expr_fnord()) .iter() .zip((*addend).iter()) .enumerate() @@ -52,10 +52,11 @@ impl UIntLimbs { let mut limb_expr = a.clone() + b.clone(); if carry.is_some() { - limb_expr = limb_expr.clone() + carry.unwrap().expr(); + limb_expr = limb_expr.clone() + carry.unwrap().expr_fnord(); } if next_carry.is_some() { - limb_expr = limb_expr.clone() - next_carry.unwrap().expr() * Self::POW_OF_C; + limb_expr = + limb_expr.clone() - next_carry.unwrap().expr_fnord() * Self::POW_OF_C; } circuit_builder @@ -101,7 +102,7 @@ impl UIntLimbs { with_overflow: bool, ) -> Result, ZKVMError> { circuit_builder.namespace(name_fn, |cb| { - self.internal_add(cb, &addend.expr(), with_overflow) + self.internal_add(cb, &addend.expr_fnord(), with_overflow) }) } @@ -121,7 +122,8 @@ impl UIntLimbs { // with high limb, overall cell will be double let c_limbs: Vec = (0..num_limbs).try_fold(vec![], |mut c_limbs, i| { let limb = circuit_builder.create_witin(|| format!("limb_{i}"))?; - circuit_builder.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), limb.expr())?; + circuit_builder + .assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), limb.expr_fnord())?; c_limbs.push(limb); Result::, ZKVMError>::Ok(c_limbs) })?; @@ -141,7 +143,7 @@ impl UIntLimbs { AssertLTConfig::construct_circuit( circuit_builder, || format!("carry_{i}_in_less_than"), - carry.expr(), + carry.expr_fnord(), (Self::MAX_DEGREE_2_MUL_CARRY_VALUE as usize).into(), Self::MAX_DEGREE_2_MUL_CARRY_U16_LIMB, ) @@ -156,18 +158,18 @@ impl UIntLimbs { circuit_builder.namespace( || name.to_owned(), |cb| { - let existing_expr = u.expr(); + let existing_expr = u.expr_fnord(); // this will overwrite existing expressions u.replace_limbs_with_witin(|| "replace_limbs_with_witin".to_string(), cb)?; // check if the new witness equals the existing expression - izip!(u.expr(), existing_expr).try_for_each(|(lhs, rhs)| { + izip!(u.expr_fnord(), existing_expr).try_for_each(|(lhs, rhs)| { cb.require_equal(|| "new_witin_equal_expr".to_string(), lhs, rhs) })?; Ok(()) }, )?; } - Ok(u.expr()) + Ok(u.expr_fnord()) }; let a_expr = swap_witin("lhs", self)?; @@ -193,12 +195,13 @@ impl UIntLimbs { c_limbs.iter().enumerate().try_for_each(|(i, c_limb)| { let carry = if i > 0 { c_carries.get(i - 1) } else { None }; let next_carry = c_carries.get(i); - result_c[i] = result_c[i].clone() - c_limb.expr(); + result_c[i] = result_c[i].clone() - c_limb.expr_fnord(); if carry.is_some() { - result_c[i] = result_c[i].clone() + carry.unwrap().expr(); + result_c[i] = result_c[i].clone() + carry.unwrap().expr_fnord(); } if next_carry.is_some() { - result_c[i] = result_c[i].clone() - next_carry.unwrap().expr() * Self::POW_OF_C; + result_c[i] = + result_c[i].clone() - next_carry.unwrap().expr_fnord() * Self::POW_OF_C; } circuit_builder.require_zero(|| format!("mul_zero_{i}"), result_c[i].clone())?; Ok::<(), ZKVMError>(()) @@ -242,11 +245,11 @@ impl UIntLimbs { mul_hi } else { // lo limb - UIntLimbs::from_exprs_unchecked(mul.expr())? + UIntLimbs::from_exprs_unchecked(mul.expr_fnord())? }; let add = cb.namespace( || "add", - |cb| mul_lo_or_hi.internal_add(cb, &addend.expr(), with_overflow), + |cb| mul_lo_or_hi.internal_add(cb, &addend.expr_fnord(), with_overflow), )?; Ok((add, mul)) }) @@ -272,18 +275,20 @@ impl UIntLimbs { .limbs .iter() .zip_eq(rhs.limbs.iter()) - .map(|(a, b)| circuit_builder.is_equal(a.expr(), b.expr())) + .map(|(a, b)| circuit_builder.is_equal(a.expr_fnord(), b.expr_fnord())) .collect::, ZKVMError>>()? .into_iter() .unzip(); let sum_expr = is_equal_per_limb .iter() - .fold(Expression::ZERO, |acc, flag| acc.clone() + flag.expr()); + .fold(Expression::ZERO, |acc, flag| { + acc.clone() + flag.expr_fnord() + }); let sum_flag = create_witin_from_expr!(|| "sum_flag", circuit_builder, false, sum_expr)?; let (is_equal, diff_inv) = - circuit_builder.is_equal(sum_flag.expr(), Expression::from(n_limbs))?; + circuit_builder.is_equal(sum_flag.expr_fnord(), Expression::from(n_limbs))?; Ok(IsEqualConfig { is_equal_per_limb, diff_inv_per_limb, @@ -304,16 +309,16 @@ impl UIntLimbs { E: ExtensionField, { let high_limb_no_msb = circuit_builder.create_witin(|| "high_limb_mask")?; - let high_limb = self.limbs[Self::NUM_LIMBS - 1].expr(); + let high_limb = self.limbs[Self::NUM_LIMBS - 1].expr_fnord(); circuit_builder.lookup_and_byte( high_limb.clone(), Expression::from(0b0111_1111), - high_limb_no_msb.expr(), + high_limb_no_msb.expr_fnord(), )?; let inv_128 = F::from(128).invert().unwrap(); - let msb = (high_limb - high_limb_no_msb.expr()) * Expression::Constant(inv_128); + let msb = (high_limb - high_limb_no_msb.expr_fnord()) * Expression::Constant(inv_128); let msb = create_witin_from_expr!(|| "msb", circuit_builder, false, msb)?; Ok(MsbConfig { msb, @@ -350,7 +355,7 @@ impl UIntLimbs { .iter() .rev() .scan(Expression::from(0), |state, idx| { - *state = state.clone() + idx.expr(); + *state = state.clone() + idx.expr_fnord(); Some(state.clone()) }) .collect(); @@ -371,7 +376,8 @@ impl UIntLimbs { .try_for_each(|(i, ((flag, a), b))| { circuit_builder.require_zero( || format!("byte diff {i} zero check"), - a.expr() - b.expr() - flag.expr() * a.expr() + flag.expr() * b.expr(), + a.expr_fnord() - b.expr_fnord() - flag.expr_fnord() * a.expr_fnord() + + flag.expr_fnord() * b.expr_fnord(), ) })?; @@ -382,14 +388,14 @@ impl UIntLimbs { .iter() .zip_eq(indexes.iter()) .fold(Expression::from(0), |acc, (ai, idx)| { - acc.clone() + ai.expr() * idx.expr() + acc.clone() + ai.expr_fnord() * idx.expr_fnord() }); let sb = rhs .limbs .iter() .zip_eq(indexes.iter()) .fold(Expression::from(0), |acc, (bi, idx)| { - acc.clone() + bi.expr() * idx.expr() + acc.clone() + bi.expr_fnord() * idx.expr_fnord() }); // check the first byte difference has a inverse @@ -401,14 +407,18 @@ impl UIntLimbs { let index_ne = si.first().unwrap(); circuit_builder.require_zero( || "byte inverse check", - lhs_ne_byte.expr() * byte_diff_inv.expr() - - rhs_ne_byte.expr() * byte_diff_inv.expr() - - index_ne.expr(), + lhs_ne_byte.expr_fnord() * byte_diff_inv.expr_fnord() + - rhs_ne_byte.expr_fnord() * byte_diff_inv.expr_fnord() + - index_ne.expr_fnord(), )?; let is_ltu = circuit_builder.create_witin(|| "is_ltu")?; // now we know the first non-equal byte pairs is (lhs_ne_byte, rhs_ne_byte) - circuit_builder.lookup_ltu_byte(lhs_ne_byte.expr(), rhs_ne_byte.expr(), is_ltu.expr())?; + circuit_builder.lookup_ltu_byte( + lhs_ne_byte.expr_fnord(), + rhs_ne_byte.expr_fnord(), + is_ltu.expr_fnord(), + )?; Ok(UIntLtuConfig { byte_diff_inv, indexes, @@ -442,12 +452,12 @@ impl UIntLimbs { // (2) compute $lt(a,b)=a_s\cdot (1-b_s)+eq(a_s,b_s)\cdot ltu(a_{ = witness_values.iter().map(|&w| w.into()).collect_vec(); - uint_c.expr().iter().zip(result).for_each(|(c, ret)| { + uint_c.expr_fnord().iter().zip(result).for_each(|(c, ret)| { assert_eq!(eval_by_expr(&wit, &challenges, c), E::from(ret)); }); // overflow if overflow { - let carries = uint_c.carries.unwrap().last().unwrap().expr(); + let carries = uint_c.carries.unwrap().last().unwrap().expr_fnord(); assert_eq!(eval_by_expr(&wit, &challenges, &carries), E::ONE); } else { // non-overflow case, the len of carries should be (NUM_CELLS - 1) @@ -831,13 +841,13 @@ mod tests { // verify let wit: Vec = witness_values.iter().map(|&w| w.into()).collect_vec(); - uint_c.expr().iter().zip(result).for_each(|(c, ret)| { + uint_c.expr_fnord().iter().zip(result).for_each(|(c, ret)| { assert_eq!(eval_by_expr(&wit, &challenges, c), E::from(ret)); }); // overflow if overflow { - let overflow = uint_c.carries.unwrap().last().unwrap().expr(); + let overflow = uint_c.carries.unwrap().last().unwrap().expr_fnord(); assert_eq!(eval_by_expr(&wit, &challenges, &overflow), E::ONE); } else { // non-overflow case, the len of carries should be (NUM_CELLS - 1) diff --git a/ceno_zkvm/src/uint/logic.rs b/ceno_zkvm/src/uint/logic.rs index b340df982..f4f8150e9 100644 --- a/ceno_zkvm/src/uint/logic.rs +++ b/ceno_zkvm/src/uint/logic.rs @@ -19,7 +19,12 @@ impl UIntLimbs { c: &Self, ) -> Result<(), ZKVMError> { for (a_byte, b_byte, c_byte) in izip!(a.limbs.iter(), b.limbs.iter(), c.limbs.iter()) { - cb.logic_u8(rom_type, a_byte.expr(), b_byte.expr(), c_byte.expr())?; + cb.logic_u8( + rom_type, + a_byte.expr_fnord(), + b_byte.expr_fnord(), + c_byte.expr_fnord(), + )?; } Ok(()) } diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index 4019f2d22..03f6dbd40 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -202,8 +202,8 @@ mod tests { let mut virtual_polys = VirtualPolynomials::new(1, 0); // 3xy + 2y - let expr: Expression = - Expression::from(3) * x.expr() * y.expr() + Expression::from(2) * y.expr(); + let expr: Expression = Expression::from(3) * x.expr_fnord() * y.expr_fnord() + + Expression::from(2) * y.expr_fnord(); let distrinct_zerocheck_terms_set = virtual_polys.add_mle_list_by_expr( None, @@ -216,7 +216,8 @@ mod tests { assert!(virtual_polys.degree() == 2); // 3x^3 - let expr: Expression = Expression::from(3) * x.expr() * x.expr() * x.expr(); + let expr: Expression = + Expression::from(3) * x.expr_fnord() * x.expr_fnord() * x.expr_fnord(); let distrinct_zerocheck_terms_set = virtual_polys.add_mle_list_by_expr( None, wits_in.iter().collect_vec(), From 5fcf5ab7bf10d7b0427533621d8882f9ad91f522 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Mon, 28 Oct 2024 10:00:44 +0800 Subject: [PATCH 2/9] Clean up --- ceno_zkvm/src/chip_handler/general.rs | 4 +- ceno_zkvm/src/expression.rs | 65 ++++++++++++++++++--------- 2 files changed, 47 insertions(+), 22 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 1e068e1ce..0fa058118 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -370,11 +370,11 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.require_zero( || "is equal", - is_eq.expr_fnord().clone() * lhs.clone() - is_eq.expr_fnord() * rhs.clone(), + is_eq.expr_fnord() * &lhs - is_eq.expr_fnord() * &rhs, )?; self.require_zero( || "is equal", - Expression::from(1) - is_eq.expr_fnord().clone() - diff_inverse.expr_fnord() * lhs + 1 - is_eq.expr_fnord() - diff_inverse.expr_fnord() * lhs + diff_inverse.expr_fnord() * rhs, )?; diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index ddca37206..e197a2836 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -442,6 +442,37 @@ impl Sub for Expression { } } +macro_rules! ref_binop_instances { + ($op: ident, $fun: ident) => { + impl $op<&Expression> for Expression { + type Output = Expression; + + fn $fun(self, rhs: &Expression) -> Expression { + self.$fun(rhs.clone()) + } + } + + impl $op> for &Expression { + type Output = Expression; + + fn $fun(self, rhs: Expression) -> Expression { + self.clone().$fun(rhs) + } + } + + impl $op<&Expression> for &Expression { + type Output = Expression; + + fn $fun(self, rhs: &Expression) -> Expression { + self.clone().$fun(rhs.clone()) + } + } + }; +} +ref_binop_instances!(Add, add); +ref_binop_instances!(Sub, sub); +ref_binop_instances!(Mul, mul); + macro_rules! binop_instances { ($op: ident, $fun: ident, ($($t:ty),*)) => { $(impl $op> for $t { @@ -908,8 +939,7 @@ mod tests { // scaledsum * challenge // 3 * x + 2 - let expr: Expression = Into::>::into(3usize) * x.expr_fnord() - + Into::>::into(2usize); + let expr: Expression = 3 * x.expr_fnord() + 2; // c^3 + 1 let c = Expression::Challenge(0, 3, 1.into(), 1.into()); // res @@ -925,7 +955,7 @@ mod tests { // constant * witin // 3 * x - let expr: Expression = Into::>::into(3usize) * x.expr_fnord(); + let expr: Expression = 3 * x.expr_fnord(); assert_eq!( expr, Expression::ScaledSum( @@ -975,36 +1005,32 @@ mod tests { let z = cb.create_witin(|| "z").unwrap(); // scaledsum * challenge // 3 * x + 2 - let expr: Expression = Into::>::into(3usize) * x.expr_fnord() - + Into::>::into(2usize); + let expr: Expression = 3 * x.expr_fnord() + 2; assert!(expr.is_monomial_form()); // 2 product term - let expr: Expression = - Into::>::into(3usize) * x.expr_fnord() * y.expr_fnord() - + Into::>::into(2usize) * x.expr_fnord(); + let expr: Expression = 3 * x.expr_fnord() * y.expr_fnord() + 2 * x.expr_fnord(); assert!(expr.is_monomial_form()); // complex linear operation // (2c + 3) * x * y - 6z - let expr: Expression = - Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr_fnord() * y.expr_fnord() - - Into::>::into(6usize) * z.expr_fnord(); + let expr: Expression = Expression::Challenge(0, 1, 2_u64.into(), 3_u64.into()) + * x.expr_fnord() + * y.expr_fnord() + - 6 * z.expr_fnord(); assert!(expr.is_monomial_form()); // complex linear operation // (2c + 3) * x * y - 6z - let expr: Expression = - Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr_fnord() * y.expr_fnord() - - Into::>::into(6usize) * z.expr_fnord(); + let expr: Expression = Expression::Challenge(0, 1, 2_u64.into(), 3_u64.into()) + * x.expr_fnord() + * y.expr_fnord() + - 6 * z.expr_fnord(); assert!(expr.is_monomial_form()); // complex linear operation // (2 * x + 3) * 3 + 6 * 8 - let expr: Expression = (Into::>::into(2usize) * x.expr_fnord() - + Into::>::into(3usize)) - * Into::>::into(3usize) - + Into::>::into(6usize) * Into::>::into(8usize); + let expr: Expression = (2 * x.expr_fnord() + 3) * 3 + 6 * 8; assert!(expr.is_monomial_form()); } @@ -1017,8 +1043,7 @@ mod tests { let y = cb.create_witin(|| "y").unwrap(); // scaledsum * challenge // (x + 1) * (y + 1) - let expr: Expression = (Into::>::into(1usize) + x.expr_fnord()) - * (Into::>::into(2usize) + y.expr_fnord()); + let expr: Expression = (1 + x.expr_fnord()) * (2 + y.expr_fnord()); assert!(!expr.is_monomial_form()); } From b59839657007801acefcda48b3c1a8b6fe2807fc Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Mon, 28 Oct 2024 10:28:31 +0800 Subject: [PATCH 3/9] More cleanup --- ceno_zkvm/src/chip_handler/general.rs | 7 +- ceno_zkvm/src/expression.rs | 84 +++++++++++++++++-- ceno_zkvm/src/expression/monomial.rs | 4 +- ceno_zkvm/src/gadgets/is_lt.rs | 2 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 6 +- .../src/instructions/riscv/memory/gadget.rs | 7 +- ceno_zkvm/src/scheme/utils.rs | 2 +- ceno_zkvm/src/uint.rs | 15 ++-- ceno_zkvm/src/uint/arithmetic.rs | 4 +- 9 files changed, 98 insertions(+), 33 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 0fa058118..f94c3e9f2 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -211,7 +211,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { when_true: &Expression, when_false: &Expression, ) -> Expression { - cond.clone() * when_true.clone() + (1 - cond.clone()) * when_false.clone() + cond * when_true + (1 - cond) * when_false } pub(crate) fn assert_ux( @@ -297,10 +297,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { { self.namespace( || "assert_bit", - |cb| { - cb.cs - .require_zero(name_fn, expr.clone() * (Expression::ONE - expr)) - }, + |cb| cb.cs.require_zero(name_fn, &expr * (1 - &expr)), ) } diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index e197a2836..6e10957a1 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -5,7 +5,7 @@ use std::{ fmt::Display, iter::Sum, mem::MaybeUninit, - ops::{Add, Deref, Mul, Neg, Sub}, + ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Sub, SubAssign}, }; use ff::Field; @@ -315,6 +315,36 @@ impl Add for Expression { } } +impl AddAssign for Expression +where + Expression: Add>, +{ + fn add_assign(&mut self, rhs: Rhs) { + // TODO: consider in-place? + *self = self.clone() + rhs; + } +} + +impl SubAssign for Expression +where + Expression: Sub>, +{ + fn sub_assign(&mut self, rhs: Rhs) { + // TODO: consider in-place? + *self = self.clone() - rhs; + } +} + +impl MulAssign for Expression +where + Expression: Mul>, +{ + fn mul_assign(&mut self, rhs: Rhs) { + // TODO: consider in-place? + *self = self.clone() * rhs; + } +} + impl Sum for Expression { fn sum>>(iter: I) -> Expression { iter.fold(Expression::ZERO, |acc, x| acc + x) @@ -467,13 +497,38 @@ macro_rules! ref_binop_instances { self.clone().$fun(rhs.clone()) } } + + // for mutable references + impl $op<&mut Expression> for Expression { + type Output = Expression; + + fn $fun(self, rhs: &mut Expression) -> Expression { + self.$fun(rhs.clone()) + } + } + + impl $op> for &mut Expression { + type Output = Expression; + + fn $fun(self, rhs: Expression) -> Expression { + self.clone().$fun(rhs) + } + } + + impl $op<&mut Expression> for &mut Expression { + type Output = Expression; + + fn $fun(self, rhs: &mut Expression) -> Expression { + self.clone().$fun(rhs.clone()) + } + } }; } ref_binop_instances!(Add, add); ref_binop_instances!(Sub, sub); ref_binop_instances!(Mul, mul); -macro_rules! binop_instances { +macro_rules! mixed_binop_instances { ($op: ident, $fun: ident, ($($t:ty),*)) => { $(impl $op> for $t { type Output = Expression; @@ -489,21 +544,38 @@ macro_rules! binop_instances { fn $fun(self, rhs: $t) -> Expression { self.$fun(Expression::::from(rhs)) } - })* + } + + impl $op<&Expression> for $t { + type Output = Expression; + + fn $fun(self, rhs: &Expression) -> Expression { + Expression::::from(self).$fun(rhs) + } + } + + impl $op<$t> for &Expression { + type Output = Expression; + + fn $fun(self, rhs: $t) -> Expression { + self.$fun(Expression::::from(rhs)) + } + } + )* }; } -binop_instances!( +mixed_binop_instances!( Add, add, (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) ); -binop_instances!( +mixed_binop_instances!( Sub, sub, (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) ); -binop_instances!( +mixed_binop_instances!( Mul, mul, (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) diff --git a/ceno_zkvm/src/expression/monomial.rs b/ceno_zkvm/src/expression/monomial.rs index fa030595f..1b18de534 100644 --- a/ceno_zkvm/src/expression/monomial.rs +++ b/ceno_zkvm/src/expression/monomial.rs @@ -39,7 +39,7 @@ impl Expression { for a in a { for b in &b { res.push(Term { - coeff: a.coeff.clone() * b.coeff.clone(), + coeff: &a.coeff * &b.coeff, vars: a.vars.iter().chain(b.vars.iter()).cloned().collect(), }); } @@ -54,7 +54,7 @@ impl Expression { for x in x { for a in &a { res.push(Term { - coeff: x.coeff.clone() * a.coeff.clone(), + coeff: &x.coeff * &a.coeff, vars: x.vars.iter().chain(a.vars.iter()).cloned().collect(), }); } diff --git a/ceno_zkvm/src/gadgets/is_lt.rs b/ceno_zkvm/src/gadgets/is_lt.rs index 84396602e..fc9d44eb8 100644 --- a/ceno_zkvm/src/gadgets/is_lt.rs +++ b/ceno_zkvm/src/gadgets/is_lt.rs @@ -340,7 +340,7 @@ impl InnerSignedLtConfig { // Convert two's complement representation into field arithmetic. // Example: 0xFFFF_FFFF = 2^32 - 1 --> shift --> -1 let neg_shift = -Expression::Constant((1_u64 << 32).into()); - let lhs_value = lhs.value() + is_lhs_neg.expr() * neg_shift.clone(); + let lhs_value = lhs.value() + is_lhs_neg.expr() * &neg_shift; let rhs_value = rhs.value() + is_rhs_neg.expr() * neg_shift; let config = InnerLtConfig::construct_circuit( diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 3b8dfff24..79f59887f 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -381,13 +381,13 @@ impl MemAddr { /// Represent the address aligned to 2 bytes. pub fn expr_align2(&self) -> AddressExpr { - self.addr.address_expr() - self.low_bit_exprs()[0].clone() + self.addr.address_expr() - &self.low_bit_exprs()[0] } /// Represent the address aligned to 4 bytes. pub fn expr_align4(&self) -> AddressExpr { let low_bits = self.low_bit_exprs(); - self.addr.address_expr() - low_bits[1].clone() * 2 - low_bits[0].clone() + self.addr.address_expr() - &low_bits[1] * 2 - &low_bits[0] } /// Expressions of the low bits of the address, LSB-first: [bit_0, bit_1]. @@ -425,7 +425,7 @@ impl MemAddr { .invert() .unwrap() .expr_fnord(); - let mid_u14 = (limbs[0].clone() - low_sum) * shift_right; + let mid_u14 = (&limbs[0] - low_sum) * shift_right; cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?; // Range check the high limb. diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index 603d2a60f..cce04f20e 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -80,8 +80,7 @@ impl MemWordChange { let u8_base_inv = E::BaseField::from(1 << 8).invert().unwrap(); cb.assert_ux::<_, _, 8>( || "rs2_limb[0].le_bytes[1]", - u8_base_inv.expr_fnord() - * (rs2_limbs[0].clone() - rs2_limb_bytes[0].expr_fnord()), + u8_base_inv.expr_fnord() * (&rs2_limbs[0] - rs2_limb_bytes[0].expr_fnord()), )?; // alloc a new witIn to cache degree 2 expression @@ -126,8 +125,8 @@ impl MemWordChange { // degree 2 expression low_bits[1].clone(), expected_change.expr_fnord(), - (1 << 16) * (rs2_limbs[0].clone() - prev_limbs[1].clone()), - rs2_limbs[0].clone() - prev_limbs[0].clone(), + (1 << 16) * (&rs2_limbs[0] - &prev_limbs[1]), + &rs2_limbs[0] - &prev_limbs[0], )?; Ok(MemWordChange { diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 9058e69e8..7588978ad 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -449,7 +449,7 @@ mod tests { let expected_final_product: E = last_layer .iter() .map(|f| match f.evaluations() { - FieldType::Ext(e) => e.iter().cloned().reduce(|a, b| a * b).unwrap(), + FieldType::Ext(e) => e.iter().copied().reduce(|a, b| a * b).unwrap(), _ => unreachable!(""), }) .product(); diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 0735baf98..035da8520 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -170,10 +170,7 @@ impl UIntLimbs { .assert_ux::<_, _, C>(|| "range check", w.expr_fnord()) .unwrap(); circuit_builder - .require_zero( - || "create_witin_from_expr", - w.expr_fnord() - expr_limbs[i].clone(), - ) + .require_zero(|| "create_witin_from_expr", w.expr_fnord() - &expr_limbs[i]) .unwrap(); w }) @@ -300,7 +297,7 @@ impl UIntLimbs { chunk .iter() .zip(shift_pows.iter()) - .map(|(limb, shift)| shift.clone() * limb.expr_fnord()) + .map(|(limb, shift)| shift * limb.expr_fnord()) .reduce(|a, b| a + b) .unwrap() }) @@ -318,7 +315,7 @@ impl UIntLimbs { let shift_pows = { let mut shift_pows = Vec::with_capacity(k); shift_pows.push(Expression::Constant(E::BaseField::ONE)); - (0..k - 1).for_each(|_| shift_pows.push(shift_pows.last().unwrap().clone() * (1 << 8))); + (0..k - 1).for_each(|_| shift_pows.push(shift_pows.last().unwrap() * (1 << 8))); shift_pows }; let split_limbs = x @@ -335,7 +332,7 @@ impl UIntLimbs { let combined_limb = limbs .iter() .zip(shift_pows.iter()) - .map(|(limb, shift)| shift.clone() * limb.clone()) + .map(|(limb, shift)| shift * limb) .reduce(|a, b| a + b) .unwrap(); @@ -514,7 +511,7 @@ impl UIntLimbs { self.expr_fnord() .into_iter() .rev() - .reduce(|sum, limb| sum * base.clone() + limb) + .reduce(|sum, limb| sum * &base + limb) .unwrap() } @@ -602,7 +599,7 @@ impl UIntLimbs<32, 8, E> { let u16_limbs = u8_limbs .chunks(2) .map(|chunk| { - let (a, b) = (chunk[0].clone(), chunk[1].clone()); + let (a, b) = (&chunk[0], &chunk[1]); a + b * 256 }) .collect_vec(); diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index d1d4c283f..7ab971c23 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -182,9 +182,9 @@ impl UIntLimbs { let idx = i + j; if idx < c_limbs.len() { if result_c.get(idx).is_none() { - result_c.push(a.clone() * b.clone()); + result_c.push(a * b); } else { - result_c[idx] = result_c[idx].clone() + a.clone() * b.clone(); + result_c[idx] += a * b; } } }); From 8cffbe84376a84558def6c03782af929abcc2156 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Mon, 28 Oct 2024 15:50:40 +0800 Subject: [PATCH 4/9] format --- ceno_zkvm/src/chip_handler/general.rs | 8 ++------ ceno_zkvm/src/expression.rs | 14 ++++++-------- ceno_zkvm/src/gadgets/is_zero.rs | 5 +---- ceno_zkvm/src/instructions/riscv/b_insn.rs | 5 ++--- ceno_zkvm/src/scheme/mock_prover.rs | 6 ++---- ceno_zkvm/src/scheme/utils.rs | 5 +---- ceno_zkvm/src/tables/program.rs | 7 +------ ceno_zkvm/src/uint.rs | 5 +---- ceno_zkvm/src/uint/arithmetic.rs | 22 ++++++---------------- ceno_zkvm/src/uint/logic.rs | 7 +------ ceno_zkvm/src/virtual_polys.rs | 7 +++---- 11 files changed, 26 insertions(+), 65 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 27ce14ab7..f8150db78 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -365,14 +365,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { let is_eq = self.create_witin(|| "is_eq")?; let diff_inverse = self.create_witin(|| "diff_inverse")?; + self.require_zero(|| "is equal", is_eq.expr() * &lhs - is_eq.expr() * &rhs)?; self.require_zero( || "is equal", - is_eq.expr() * &lhs - is_eq.expr() * &rhs, - )?; - self.require_zero( - || "is equal", - 1 - is_eq.expr() - diff_inverse.expr() * lhs - + diff_inverse.expr() * rhs, + 1 - is_eq.expr() - diff_inverse.expr() * lhs + diff_inverse.expr() * rhs, )?; Ok((is_eq, diff_inverse)) diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index dbf19ba34..a47ef90a3 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -1070,18 +1070,16 @@ mod tests { // complex linear operation // (2c + 3) * x * y - 6z - let expr: Expression = Expression::Challenge(0, 1, 2_u64.into(), 3_u64.into()) - * x.expr() - * y.expr() - - 6 * z.expr(); + let expr: Expression = + Expression::Challenge(0, 1, 2_u64.into(), 3_u64.into()) * x.expr() * y.expr() + - 6 * z.expr(); assert!(expr.is_monomial_form()); // complex linear operation // (2c + 3) * x * y - 6z - let expr: Expression = Expression::Challenge(0, 1, 2_u64.into(), 3_u64.into()) - * x.expr() - * y.expr() - - 6 * z.expr(); + let expr: Expression = + Expression::Challenge(0, 1, 2_u64.into(), 3_u64.into()) * x.expr() * y.expr() + - 6 * z.expr(); assert!(expr.is_monomial_form()); // complex linear operation diff --git a/ceno_zkvm/src/gadgets/is_zero.rs b/ceno_zkvm/src/gadgets/is_zero.rs index 4a883349c..02994e4f7 100644 --- a/ceno_zkvm/src/gadgets/is_zero.rs +++ b/ceno_zkvm/src/gadgets/is_zero.rs @@ -30,10 +30,7 @@ impl IsZeroConfig { let inverse = cb.create_witin(|| "inv")?; // x==0 => is_zero=1 - cb.require_one( - || "is_zero_1", - is_zero.expr() + x.clone() * inverse.expr(), - )?; + cb.require_one(|| "is_zero_1", is_zero.expr() + x.clone() * inverse.expr())?; // x!=0 => is_zero=0 cb.require_zero(|| "is_zero_0", is_zero.expr() * x.clone())?; diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index 4a57e5493..b7c74543f 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -71,9 +71,8 @@ impl BInstructionConfig { ))?; // Branch program counter - let pc_offset = branch_taken_bit.clone() * imm.expr() - - branch_taken_bit * PC_STEP_SIZE - + PC_STEP_SIZE; + let pc_offset = + branch_taken_bit.clone() * imm.expr() - branch_taken_bit * PC_STEP_SIZE + PC_STEP_SIZE; let next_pc = vm_state.next_pc.unwrap(); circuit_builder.require_equal( || "pc_branch", diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 38e6bf0a0..6c325573a 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -849,8 +849,7 @@ mod tests { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let a = cb.create_witin(|| "a")?; let b = cb.create_witin(|| "b")?; - let lt_wtns = - AssertLTConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; + let lt_wtns = AssertLTConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; Ok(Self { a, b, lt_wtns }) } @@ -974,8 +973,7 @@ mod tests { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let a = cb.create_witin(|| "a")?; let b = cb.create_witin(|| "b")?; - let lt_wtns = - IsLtConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; + let lt_wtns = IsLtConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; Ok(Self { a, b, lt_wtns }) } diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index d2f8aa0a9..9a80cf758 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -677,10 +677,7 @@ mod tests { let b = cb.create_witin(|| "b").unwrap(); let c = cb.create_witin(|| "c").unwrap(); - let expr: Expression = a.expr() - + b.expr() - + a.expr() * b.expr() - + (c.expr() * 3 + 2); + let expr: Expression = a.expr() + b.expr() + a.expr() * b.expr() + (c.expr() * 3 + 2); let res = wit_infer_by_expr( &[], diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index e0343843f..3514365c8 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -143,12 +143,7 @@ impl TableCircuit cb.rlc_chip_record(fields) }; - cb.lk_table_record( - || "prog table", - PROGRAM_SIZE, - record_exprs, - mlt.expr(), - )?; + cb.lk_table_record(|| "prog table", PROGRAM_SIZE, record_exprs, mlt.expr())?; Ok(ProgramTableConfig { record, mlt }) } diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index ba8efebf3..498bc626a 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -92,10 +92,7 @@ impl UIntLimbs { .map(|i| { let w = cb.create_witin(|| format!("limb_{i}"))?; if is_check { - cb.assert_ux::<_, _, C>( - || format!("limb_{i}_in_{C}"), - w.expr(), - )?; + cb.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), w.expr())?; } // skip range check Ok(w) diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 6266fb662..1f16c2d89 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -54,8 +54,7 @@ impl UIntLimbs { limb_expr = limb_expr.clone() + carry.unwrap().expr(); } if next_carry.is_some() { - limb_expr = - limb_expr.clone() - next_carry.unwrap().expr() * Self::POW_OF_C; + limb_expr = limb_expr.clone() - next_carry.unwrap().expr() * Self::POW_OF_C; } circuit_builder @@ -121,8 +120,7 @@ impl UIntLimbs { // with high limb, overall cell will be double let c_limbs: Vec = (0..num_limbs).try_fold(vec![], |mut c_limbs, i| { let limb = circuit_builder.create_witin(|| format!("limb_{i}"))?; - circuit_builder - .assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), limb.expr())?; + circuit_builder.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), limb.expr())?; c_limbs.push(limb); Result::, ZKVMError>::Ok(c_limbs) })?; @@ -199,8 +197,7 @@ impl UIntLimbs { result_c[i] = result_c[i].clone() + carry.unwrap().expr(); } if next_carry.is_some() { - result_c[i] = - result_c[i].clone() - next_carry.unwrap().expr() * Self::POW_OF_C; + result_c[i] = result_c[i].clone() - next_carry.unwrap().expr() * Self::POW_OF_C; } circuit_builder.require_zero(|| format!("mul_zero_{i}"), result_c[i].clone())?; Ok::<(), ZKVMError>(()) @@ -281,9 +278,7 @@ impl UIntLimbs { let sum_expr = is_equal_per_limb .iter() - .fold(Expression::ZERO, |acc, flag| { - acc.clone() + flag.expr() - }); + .fold(Expression::ZERO, |acc, flag| acc.clone() + flag.expr()); let sum_flag = WitIn::from_expr(|| "sum_flag", circuit_builder, sum_expr, false)?; let (is_equal, diff_inv) = @@ -375,8 +370,7 @@ impl UIntLimbs { .try_for_each(|(i, ((flag, a), b))| { circuit_builder.require_zero( || format!("byte diff {i} zero check"), - a.expr() - b.expr() - flag.expr() * a.expr() - + flag.expr() * b.expr(), + a.expr() - b.expr() - flag.expr() * a.expr() + flag.expr() * b.expr(), ) })?; @@ -411,11 +405,7 @@ impl UIntLimbs { let is_ltu = circuit_builder.create_witin(|| "is_ltu")?; // now we know the first non-equal byte pairs is (lhs_ne_byte, rhs_ne_byte) - circuit_builder.lookup_ltu_byte( - lhs_ne_byte.expr(), - rhs_ne_byte.expr(), - is_ltu.expr(), - )?; + circuit_builder.lookup_ltu_byte(lhs_ne_byte.expr(), rhs_ne_byte.expr(), is_ltu.expr())?; Ok(UIntLtuConfig { byte_diff_inv, indexes, diff --git a/ceno_zkvm/src/uint/logic.rs b/ceno_zkvm/src/uint/logic.rs index b6891668f..b340df982 100644 --- a/ceno_zkvm/src/uint/logic.rs +++ b/ceno_zkvm/src/uint/logic.rs @@ -19,12 +19,7 @@ impl UIntLimbs { c: &Self, ) -> Result<(), ZKVMError> { for (a_byte, b_byte, c_byte) in izip!(a.limbs.iter(), b.limbs.iter(), c.limbs.iter()) { - cb.logic_u8( - rom_type, - a_byte.expr(), - b_byte.expr(), - c_byte.expr(), - )?; + cb.logic_u8(rom_type, a_byte.expr(), b_byte.expr(), c_byte.expr())?; } Ok(()) } diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index f79d77fd1..4019f2d22 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -202,8 +202,8 @@ mod tests { let mut virtual_polys = VirtualPolynomials::new(1, 0); // 3xy + 2y - let expr: Expression = Expression::from(3) * x.expr() * y.expr() - + Expression::from(2) * y.expr(); + let expr: Expression = + Expression::from(3) * x.expr() * y.expr() + Expression::from(2) * y.expr(); let distrinct_zerocheck_terms_set = virtual_polys.add_mle_list_by_expr( None, @@ -216,8 +216,7 @@ mod tests { assert!(virtual_polys.degree() == 2); // 3x^3 - let expr: Expression = - Expression::from(3) * x.expr() * x.expr() * x.expr(); + let expr: Expression = Expression::from(3) * x.expr() * x.expr() * x.expr(); let distrinct_zerocheck_terms_set = virtual_polys.add_mle_list_by_expr( None, wits_in.iter().collect_vec(), From f7e20a6ad989e1bb790f520df9d8a1f43e032525 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Mon, 28 Oct 2024 15:57:05 +0800 Subject: [PATCH 5/9] Fix --- ceno_zkvm/src/expression.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index a47ef90a3..c4e0d6713 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -794,7 +794,7 @@ macro_rules! impl_from_via_ToExpr { $( impl From<$t> for Expression { fn from(value: $t) -> Self { - value.expr_fnord() + value.expr() } } )* From 9018237417c96b14af27fc07ae8bd8105de2bf1e Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Mon, 28 Oct 2024 16:05:01 +0800 Subject: [PATCH 6/9] Macros --- ceno_zkvm/src/expression.rs | 42 +++++++++++++------------------------ 1 file changed, 15 insertions(+), 27 deletions(-) diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index c4e0d6713..c8af6de80 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -315,35 +315,23 @@ impl Add for Expression { } } -impl AddAssign for Expression -where - Expression: Add>, -{ - fn add_assign(&mut self, rhs: Rhs) { - // TODO: consider in-place? - *self = self.clone() + rhs; - } -} - -impl SubAssign for Expression -where - Expression: Sub>, -{ - fn sub_assign(&mut self, rhs: Rhs) { - // TODO: consider in-place? - *self = self.clone() - rhs; - } +macro_rules! binop_assign_instances { + ($op_assign: ident, $fun_assign: ident, $op: ident, $fun: ident) => { + impl $op_assign for Expression + where + Expression: $op>, + { + fn $fun_assign(&mut self, rhs: Rhs) { + // TODO: consider in-place? + *self = self.clone().$fun(rhs); + } + } + }; } -impl MulAssign for Expression -where - Expression: Mul>, -{ - fn mul_assign(&mut self, rhs: Rhs) { - // TODO: consider in-place? - *self = self.clone() * rhs; - } -} +binop_assign_instances!(AddAssign, add_assign, Add, add); +binop_assign_instances!(SubAssign, sub_assign, Sub, sub); +binop_assign_instances!(MulAssign, mul_assign, Mul, mul); impl Sum for Expression { fn sum>>(iter: I) -> Expression { From 8c2e4e7c1aa45a951cc7014bece54697296cf3c4 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Mon, 28 Oct 2024 16:07:41 +0800 Subject: [PATCH 7/9] Convert more --- ceno_zkvm/src/chip_handler/general.rs | 5 +---- ceno_zkvm/src/uint.rs | 3 +-- ceno_zkvm/src/virtual_polys.rs | 5 ++--- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index f8150db78..a4765bf97 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -176,10 +176,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { NR: Into, N: FnOnce() -> NR, { - self.namespace( - || "require_one", - |cb| cb.cs.require_zero(name_fn, Expression::from(1) - expr), - ) + self.namespace(|| "require_one", |cb| cb.cs.require_zero(name_fn, 1 - expr)) } pub fn condition_require_equal( diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 498bc626a..22534e8d4 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -504,11 +504,10 @@ impl UIntLimbs { /// Get an Expression from the limbs, unsafe if Uint value exceeds field limit pub fn value(&self) -> Expression { - let base = Expression::from(1 << C); self.expr() .into_iter() .rev() - .reduce(|sum, limb| sum * &base + limb) + .reduce(|sum, limb| sum * (1 << C) + limb) .unwrap() } diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index 4019f2d22..cca66867d 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -202,8 +202,7 @@ mod tests { let mut virtual_polys = VirtualPolynomials::new(1, 0); // 3xy + 2y - let expr: Expression = - Expression::from(3) * x.expr() * y.expr() + Expression::from(2) * y.expr(); + let expr: Expression = 3 * x.expr() * y.expr() + 2 * y.expr(); let distrinct_zerocheck_terms_set = virtual_polys.add_mle_list_by_expr( None, @@ -216,7 +215,7 @@ mod tests { assert!(virtual_polys.degree() == 2); // 3x^3 - let expr: Expression = Expression::from(3) * x.expr() * x.expr() * x.expr(); + let expr: Expression = 3 * x.expr() * x.expr() * x.expr(); let distrinct_zerocheck_terms_set = virtual_polys.add_mle_list_by_expr( None, wits_in.iter().collect_vec(), From c9c564a118956164dde75e89eb3ca7446022205f Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Mon, 28 Oct 2024 16:09:39 +0800 Subject: [PATCH 8/9] Doc --- ceno_zkvm/src/expression.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index c8af6de80..d13f56397 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -460,6 +460,7 @@ impl Sub for Expression { } } +/// Instances that Expression and &Expression macro_rules! ref_binop_instances { ($op: ident, $fun: ident) => { impl $op<&Expression> for Expression { From 1d7e7340e83c44f8056267e4bfe02fa4bde29473 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Mon, 28 Oct 2024 16:10:03 +0800 Subject: [PATCH 9/9] Fix --- ceno_zkvm/src/expression.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index d13f56397..6983b820d 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -460,7 +460,7 @@ impl Sub for Expression { } } -/// Instances that Expression and &Expression +/// Instances for binary operations that mix Expression and &Expression macro_rules! ref_binop_instances { ($op: ident, $fun: ident) => { impl $op<&Expression> for Expression {