diff --git a/src/stdlib/bits.rs b/src/stdlib/bits.rs index c9f55c20d..750ec7270 100644 --- a/src/stdlib/bits.rs +++ b/src/stdlib/bits.rs @@ -12,10 +12,7 @@ use crate::{ var::{ConstOrCell, Value, Var}, }; -use super::{FnInfoType, Module}; - -const NTH_BIT_FN: &str = "nth_bit(val: Field, const nth: Field) -> Field"; -const CHECK_FIELD_SIZE_FN: &str = "check_field_size(cmp: Field)"; +use super::{builtins::Builtin, FnInfoType, Module}; pub struct BitsLib {} @@ -24,81 +21,95 @@ impl Module for BitsLib { fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { vec![ - (NTH_BIT_FN, nth_bit, false), - (CHECK_FIELD_SIZE_FN, check_field_size, false), + (NthBitFn::SIGNATURE, NthBitFn::builtin, false), + ( + CheckFieldSizeFn::SIGNATURE, + CheckFieldSizeFn::builtin, + false, + ), ] } } -fn nth_bit( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // should be two input vars - assert_eq!(vars.len(), 2); - - // these should be type checked already, unless it is called by other low level functions - // eg. builtins - let var_info = &vars[0]; - let val = &var_info.var; - assert_eq!(val.len(), 1); - - let var_info = &vars[1]; - let nth = &var_info.var; - assert_eq!(nth.len(), 1); - - let nth: usize = match &nth[0] { - ConstOrCell::Cell(_) => unreachable!("nth should be a constant"), - ConstOrCell::Const(cst) => cst.to_u64() as usize, - }; - - let val = match &val[0] { - ConstOrCell::Cell(cvar) => cvar.clone(), - ConstOrCell::Const(cst) => { - // directly return the nth bit without adding symbolic value as it doesn't depend on a cell var - let bit = cst.to_bits(); - return Ok(Some(Var::new_cvar( - ConstOrCell::Const(B::Field::from(bit[nth])), - span, - ))); - } - }; +struct NthBitFn {} +struct CheckFieldSizeFn {} + +impl Builtin for NthBitFn { + const SIGNATURE: &'static str = "nth_bit(val: Field, const nth: Field) -> Field"; + + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + // should be two input vars + assert_eq!(vars.len(), 2); + + // these should be type checked already, unless it is called by other low level functions + // eg. builtins + let var_info = &vars[0]; + let val = &var_info.var; + assert_eq!(val.len(), 1); + + let var_info = &vars[1]; + let nth = &var_info.var; + assert_eq!(nth.len(), 1); + + let nth: usize = match &nth[0] { + ConstOrCell::Cell(_) => unreachable!("nth should be a constant"), + ConstOrCell::Const(cst) => cst.to_u64() as usize, + }; + + let val = match &val[0] { + ConstOrCell::Cell(cvar) => cvar.clone(), + ConstOrCell::Const(cst) => { + // directly return the nth bit without adding symbolic value as it doesn't depend on a cell var + let bit = cst.to_bits(); + return Ok(Some(Var::new_cvar( + ConstOrCell::Const(B::Field::from(bit[nth])), + span, + ))); + } + }; - let bit = compiler - .backend - .new_internal_var(Value::NthBit(val.clone(), nth), span); + let bit = compiler + .backend + .new_internal_var(Value::NthBit(val.clone(), nth), span); - Ok(Some(Var::new(vec![ConstOrCell::Cell(bit)], span))) + Ok(Some(Var::new(vec![ConstOrCell::Cell(bit)], span))) + } } -// Ensure that the field size is not exceeded -fn check_field_size( - _compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - let var = &vars[0].var[0]; - let bit_len = B::Field::size_in_bits() as u64; - - match var { - ConstOrCell::Const(cst) => { - let to_cmp = cst.to_u64(); - if to_cmp >= bit_len { - return Err(Error::new( - "constraint-generation", - ErrorKind::AssertionFailed, - span, - )); +impl Builtin for CheckFieldSizeFn { + const SIGNATURE: &'static str = "check_field_size(cmp: Field)"; + + fn builtin( + _compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + let var = &vars[0].var[0]; + let bit_len = B::Field::size_in_bits() as u64; + + match var { + ConstOrCell::Const(cst) => { + let to_cmp = cst.to_u64(); + if to_cmp >= bit_len { + return Err(Error::new( + "constraint-generation", + ErrorKind::AssertionFailed, + span, + )); + } + Ok(None) } - Ok(None) + ConstOrCell::Cell(_) => Err(Error::new( + "constraint-generation", + ErrorKind::ExpectedConstant, + span, + )), } - ConstOrCell::Cell(_) => Err(Error::new( - "constraint-generation", - ErrorKind::ExpectedConstant, - span, - )), } } diff --git a/src/stdlib/builtins.rs b/src/stdlib/builtins.rs index 65e607cdd..793e49a9e 100644 --- a/src/stdlib/builtins.rs +++ b/src/stdlib/builtins.rs @@ -21,10 +21,6 @@ use super::{FnInfoType, Module}; pub const QUALIFIED_BUILTINS: &str = "std/builtins"; pub const BUILTIN_FN_NAMES: [&str; 3] = ["assert", "assert_eq", "log"]; -const ASSERT_FN: &str = "assert(condition: Bool)"; -const ASSERT_EQ_FN: &str = "assert_eq(lhs: Field, rhs: Field)"; -const LOG_FN: &str = "log(var: Field)"; - pub struct BuiltinsLib {} impl Module for BuiltinsLib { @@ -32,136 +28,165 @@ impl Module for BuiltinsLib { fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { vec![ - (ASSERT_FN, assert_fn, false), - (ASSERT_EQ_FN, assert_eq_fn, false), + (AssertEqFn::SIGNATURE, AssertEqFn::builtin, false), + (AssertFn::SIGNATURE, AssertFn::builtin, false), // true -> skip argument type checking for log - (LOG_FN, log_fn, true), + (LogFn::SIGNATURE, LogFn::builtin, true), ] } } -/// Asserts that two vars are equal. -fn assert_eq_fn( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // we get two vars - assert_eq!(vars.len(), 2); - let lhs_info = &vars[0]; - let rhs_info = &vars[1]; - - // they are both of type field - if !matches!(lhs_info.typ, Some(TyKind::Field { .. })) { - let lhs = lhs_info.typ.clone().ok_or_else(|| { - Error::new( +pub trait Builtin { + const SIGNATURE: &'static str; + + fn builtin( + compiler: &mut CircuitWriter, + generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>>; +} + +struct AssertEqFn {} +struct AssertFn {} +struct LogFn {} + +impl Builtin for AssertEqFn { + const SIGNATURE: &'static str = "assert_eq(lhs: Field, rhs: Field)"; + + /// Asserts that two vars are equal. + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo<::Field, ::Var>], + span: Span, + ) -> Result::Field, ::Var>>> { + // we get two vars + assert_eq!(vars.len(), 2); + let lhs_info = &vars[0]; + let rhs_info = &vars[1]; + + // they are both of type field + if !matches!(lhs_info.typ, Some(TyKind::Field { .. })) { + let lhs = lhs_info.typ.clone().ok_or_else(|| { + Error::new( + "constraint-generation", + ErrorKind::UnexpectedError("No type info for lhs of assertion"), + span, + ) + })?; + + Err(Error::new( "constraint-generation", - ErrorKind::UnexpectedError("No type info for lhs of assertion"), + ErrorKind::AssertTypeMismatch("rhs", lhs), span, - ) - })?; - - Err(Error::new( - "constraint-generation", - ErrorKind::AssertTypeMismatch("rhs", lhs), - span, - ))? - } + ))? + } + + if !matches!(rhs_info.typ, Some(TyKind::Field { .. })) { + let rhs = rhs_info.typ.clone().ok_or_else(|| { + Error::new( + "constraint-generation", + ErrorKind::UnexpectedError("No type info for rhs of assertion"), + span, + ) + })?; - if !matches!(rhs_info.typ, Some(TyKind::Field { .. })) { - let rhs = rhs_info.typ.clone().ok_or_else(|| { - Error::new( + Err(Error::new( "constraint-generation", - ErrorKind::UnexpectedError("No type info for rhs of assertion"), + ErrorKind::AssertTypeMismatch("rhs", rhs), span, - ) - })?; - - Err(Error::new( - "constraint-generation", - ErrorKind::AssertTypeMismatch("rhs", rhs), - span, - ))? - } - - // retrieve the values - let lhs_var = &lhs_info.var; - assert_eq!(lhs_var.len(), 1); - let lhs_cvar = &lhs_var[0]; + ))? + } - let rhs_var = &rhs_info.var; - assert_eq!(rhs_var.len(), 1); - let rhs_cvar = &rhs_var[0]; + // retrieve the values + let lhs_var = &lhs_info.var; + assert_eq!(lhs_var.len(), 1); + let lhs_cvar = &lhs_var[0]; + + let rhs_var = &rhs_info.var; + assert_eq!(rhs_var.len(), 1); + let rhs_cvar = &rhs_var[0]; + + match (lhs_cvar, rhs_cvar) { + // two constants + (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { + if a != b { + Err(Error::new( + "constraint-generation", + ErrorKind::AssertionFailed, + span, + ))? + } + } - match (lhs_cvar, rhs_cvar) { - // two constants - (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { - if a != b { - Err(Error::new( - "constraint-generation", - ErrorKind::AssertionFailed, - span, - ))? + // a const and a var + (ConstOrCell::Const(cst), ConstOrCell::Cell(cvar)) + | (ConstOrCell::Cell(cvar), ConstOrCell::Const(cst)) => { + compiler.backend.assert_eq_const(cvar, *cst, span) + } + (ConstOrCell::Cell(lhs), ConstOrCell::Cell(rhs)) => { + compiler.backend.assert_eq_var(lhs, rhs, span) } } - // a const and a var - (ConstOrCell::Const(cst), ConstOrCell::Cell(cvar)) - | (ConstOrCell::Cell(cvar), ConstOrCell::Const(cst)) => { - compiler.backend.assert_eq_const(cvar, *cst, span) - } - (ConstOrCell::Cell(lhs), ConstOrCell::Cell(rhs)) => { - compiler.backend.assert_eq_var(lhs, rhs, span) - } + Ok(None) } - - Ok(None) } -/// Asserts that a condition is true. -fn assert_fn( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // we get a single var - assert_eq!(vars.len(), 1); - - // of type bool - let var_info = &vars[0]; - assert!(matches!(var_info.typ, Some(TyKind::Bool))); - - // of only one field element - let var = &var_info.var; - assert_eq!(var.len(), 1); - let cond = &var[0]; - - match cond { - ConstOrCell::Const(cst) => { - assert!(cst.is_one()); - } - ConstOrCell::Cell(cvar) => { - let one = B::Field::one(); - compiler.backend.assert_eq_const(cvar, one, span); +impl Builtin for AssertFn { + const SIGNATURE: &'static str = "assert(condition: Bool)"; + + /// Asserts that a condition is true. + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo<::Field, ::Var>], + span: Span, + ) -> Result::Field, ::Var>>> { + // we get a single var + assert_eq!(vars.len(), 1); + + // of type bool + let var_info = &vars[0]; + assert!(matches!(var_info.typ, Some(TyKind::Bool))); + + // of only one field element + let var = &var_info.var; + assert_eq!(var.len(), 1); + let cond = &var[0]; + + match cond { + ConstOrCell::Const(cst) => { + assert!(cst.is_one()); + } + ConstOrCell::Cell(cvar) => { + let one = B::Field::one(); + compiler.backend.assert_eq_const(cvar, one, span); + } } - } - Ok(None) + Ok(None) + } } -/// Logging -fn log_fn( - compiler: &mut CircuitWriter, - generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - for var in vars { - // todo: will need to support string argument in order to customize msg - compiler.backend.log_var(var, "log".to_owned(), span); - } +impl Builtin for LogFn { + // todo: currently only supports a single field var + // to support all the types, we can bypass the type check for this log function for now + const SIGNATURE: &'static str = "log(var: Field)"; + + /// Logging + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo<::Field, ::Var>], + span: Span, + ) -> Result::Field, ::Var>>> { + for var in vars { + // todo: will need to support string argument in order to customize msg + compiler.backend.log_var(var, "log".to_owned(), span); + } - Ok(None) + Ok(None) + } } diff --git a/src/stdlib/crypto.rs b/src/stdlib/crypto.rs index 66113cddd..13ff91a86 100644 --- a/src/stdlib/crypto.rs +++ b/src/stdlib/crypto.rs @@ -1,14 +1,27 @@ -use super::{FnInfoType, Module}; +use super::{builtins::Builtin, FnInfoType, Module}; use crate::backends::Backend; -const POSEIDON_FN: &str = "poseidon(input: [Field; 2]) -> [Field; 3]"; - pub struct CryptoLib {} impl Module for CryptoLib { const MODULE: &'static str = "crypto"; fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { - vec![(POSEIDON_FN, B::poseidon(), false)] + vec![(PoseidonFn::SIGNATURE, PoseidonFn::builtin, false)] + } +} + +struct PoseidonFn {} + +impl Builtin for PoseidonFn { + const SIGNATURE: &'static str = "poseidon(input: [Field; 2]) -> [Field; 3]"; + + fn builtin( + compiler: &mut crate::circuit_writer::CircuitWriter, + generics: &crate::parser::types::GenericParameters, + vars: &[crate::circuit_writer::VarInfo], + span: crate::constants::Span, + ) -> crate::error::Result>> { + B::poseidon()(compiler, generics, vars, span) } } diff --git a/src/stdlib/int.rs b/src/stdlib/int.rs index 03c574890..94f40ff89 100644 --- a/src/stdlib/int.rs +++ b/src/stdlib/int.rs @@ -11,9 +11,7 @@ use crate::{ var::{ConstOrCell, Value, Var}, }; -use super::{FnInfoType, Module}; - -const DIVMOD_FN: &str = "divmod(dividend: Field, divisor: Field) -> [Field; 2]"; +use super::{builtins::Builtin, FnInfoType, Module}; pub struct IntLib {} @@ -21,57 +19,63 @@ impl Module for IntLib { const MODULE: &'static str = "int"; fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { - vec![(DIVMOD_FN, divmod_fn, false)] + vec![(DivmodFn::SIGNATURE, DivmodFn::builtin, false)] } } /// Divides two field elements and returns the quotient and remainder. -fn divmod_fn( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // we get two vars - let dividend_info = &vars[0]; - let divisor_info = &vars[1]; - - // retrieve the values - let dividend_var = ÷nd_info.var[0]; - let divisor_var = &divisor_info.var[0]; - - match (dividend_var, divisor_var) { - // two constants - (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { - // convert to bigints - let a = a.to_biguint(); - let b = b.to_biguint(); - - let quotient = a.clone() / b.clone(); - let remainder = a % b; - - // convert back to fields - let quotient = B::Field::from_biguint("ient).unwrap(); - let remainder = B::Field::from_biguint(&remainder).unwrap(); - - Ok(Some(Var::new( - vec![ConstOrCell::Const(quotient), ConstOrCell::Const(remainder)], - span, - ))) - } +struct DivmodFn {} + +impl Builtin for DivmodFn { + const SIGNATURE: &'static str = "divmod(dividend: Field, divisor: Field) -> [Field; 2]"; + + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + // we get two vars + let dividend_info = &vars[0]; + let divisor_info = &vars[1]; + + // retrieve the values + let dividend_var = ÷nd_info.var[0]; + let divisor_var = &divisor_info.var[0]; + + match (dividend_var, divisor_var) { + // two constants + (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { + // convert to bigints + let a = a.to_biguint(); + let b = b.to_biguint(); + + let quotient = a.clone() / b.clone(); + let remainder = a % b; + + // convert back to fields + let quotient = B::Field::from_biguint("ient).unwrap(); + let remainder = B::Field::from_biguint(&remainder).unwrap(); + + Ok(Some(Var::new( + vec![ConstOrCell::Const(quotient), ConstOrCell::Const(remainder)], + span, + ))) + } + + _ => { + let quotient = compiler + .backend + .new_internal_var(Value::Div(dividend_var.clone(), divisor_var.clone()), span); + let remainder = compiler + .backend + .new_internal_var(Value::Mod(dividend_var.clone(), divisor_var.clone()), span); - _ => { - let quotient = compiler - .backend - .new_internal_var(Value::Div(dividend_var.clone(), divisor_var.clone()), span); - let remainder = compiler - .backend - .new_internal_var(Value::Mod(dividend_var.clone(), divisor_var.clone()), span); - - Ok(Some(Var::new( - vec![ConstOrCell::Cell(quotient), ConstOrCell::Cell(remainder)], - span, - ))) + Ok(Some(Var::new( + vec![ConstOrCell::Cell(quotient), ConstOrCell::Cell(remainder)], + span, + ))) + } } } }