From f1f925d851d330cd21ebe2251f2eddc7617ab213 Mon Sep 17 00:00:00 2001 From: Shourya Goel Date: Fri, 27 Sep 2024 15:37:16 +0530 Subject: [PATCH] Implement ModBuiltin (#666) * implement partial functionalities and structs required * run fmt * Introduce max func * Added more functions * ran fmt * Add test for modulo * Added structure for integrating builtin with the vm * Added useful comments * Update FillMemory * Added more comments * Add test * Added TODOs * fixes * update loop * Added addmod and mulmod * added some fixes * test passes now * fix error and test * Make tests pass * mod subtraction tests * Added some comments * Added subtraction tests * Added recursive case as well * nit * Add multiplication test * refactor math utils * move one math utils function * All Tests pass * nit * Added mirror functionality * remove redundant test * updated checkwrite and infervalue * nit --- pkg/hintrunner/core/hint.go | 2 +- pkg/hintrunner/utils/math_utils.go | 65 +-- pkg/hintrunner/utils/math_utils_test.go | 66 --- pkg/hintrunner/zero/zerohint_ec.go | 2 +- pkg/hintrunner/zero/zerohint_signature.go | 3 +- pkg/runner/runner_test.go | 22 + pkg/utils/math.go | 74 +++ pkg/utils/math_test.go | 67 +++ pkg/vm/builtins/builtin_runner.go | 14 + pkg/vm/builtins/layouts.go | 4 +- pkg/vm/builtins/modulo.go | 541 ++++++++++++++++++++++ pkg/vm/builtins/modulo_test.go | 158 +++++++ 12 files changed, 884 insertions(+), 134 deletions(-) create mode 100644 pkg/vm/builtins/modulo.go create mode 100644 pkg/vm/builtins/modulo_test.go diff --git a/pkg/hintrunner/core/hint.go b/pkg/hintrunner/core/hint.go index 71b725f33..23b426b38 100644 --- a/pkg/hintrunner/core/hint.go +++ b/pkg/hintrunner/core/hint.go @@ -540,7 +540,7 @@ func (hint U256InvModN) Execute(vm *VM.VirtualMachine, _ *hinter.HintRunnerConte n := new(big.Int).Lsh(&N1BigInt, 128) n.Add(n, &N0BigInt) - _, r, g := u.Igcdex(n, b) + _, r, g := utils.Igcdex(n, b) mask := new(big.Int).Lsh(big.NewInt(1), 128) mask.Sub(mask, big.NewInt(1)) diff --git a/pkg/hintrunner/utils/math_utils.go b/pkg/hintrunner/utils/math_utils.go index 7f274f81d..678aca7d4 100644 --- a/pkg/hintrunner/utils/math_utils.go +++ b/pkg/hintrunner/utils/math_utils.go @@ -5,6 +5,7 @@ import ( "fmt" "math/big" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) @@ -55,7 +56,7 @@ func AsIntBig(value *big.Int) big.Int { func Divmod(n, m, p *big.Int) (big.Int, error) { // https://github.com/starkware-libs/cairo-lang/blob/efa9648f57568aad8f8a13fbf027d2de7c63c2c0/src/starkware/python/math_utils.py#L26 - a, _, c := Igcdex(m, p) + a, _, c := utils.Igcdex(m, p) if c.Cmp(big.NewInt(1)) != 0 { return *big.NewInt(0), errors.New("no solution exists (gcd(m, p) != 1)") } @@ -65,68 +66,6 @@ func Divmod(n, m, p *big.Int) (big.Int, error) { return *res, nil } -func Igcdex(a, b *big.Int) (big.Int, big.Int, big.Int) { - // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/core/intfunc.py#L362 - - if a.Cmp(big.NewInt(0)) == 0 && b.Cmp(big.NewInt(0)) == 0 { - return *big.NewInt(0), *big.NewInt(1), *big.NewInt(0) - } - g, x, y := gcdext(a, b) - return x, y, g -} - -func gcdext(a, b *big.Int) (big.Int, big.Int, big.Int) { - // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L125 - - if a.Cmp(big.NewInt(0)) == 0 || b.Cmp(big.NewInt(0)) == 0 { - g := new(big.Int) - if a.Cmp(big.NewInt(0)) == 0 { - g.Abs(b) - } else { - g.Abs(a) - } - - if g.Cmp(big.NewInt(0)) == 0 { - return *big.NewInt(0), *big.NewInt(0), *big.NewInt(0) - } - return *g, *new(big.Int).Div(a, g), *new(big.Int).Div(b, g) - } - - xSign, aSigned := sign(a) - ySign, bSigned := sign(b) - x, r := big.NewInt(1), big.NewInt(0) - y, s := big.NewInt(0), big.NewInt(1) - - for bSigned.Sign() != 0 { - q, c := new(big.Int).DivMod(&aSigned, &bSigned, new(big.Int)) - aSigned = bSigned - bSigned = *c - x, r = r, new(big.Int).Sub(x, new(big.Int).Mul(q, r)) - y, s = s, new(big.Int).Sub(y, new(big.Int).Mul(q, s)) - } - - return aSigned, *new(big.Int).Mul(x, big.NewInt(int64(xSign))), *new(big.Int).Mul(y, big.NewInt(int64(ySign))) -} - -func sign(n *big.Int) (int, big.Int) { - // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L119 - - if n.Sign() < 0 { - return -1, *new(big.Int).Abs(n) - } - return 1, *new(big.Int).Set(n) -} - -func SafeDiv(x, y *big.Int) (big.Int, error) { - if y.Cmp(big.NewInt(0)) == 0 { - return *big.NewInt(0), fmt.Errorf("division by zero") - } - if new(big.Int).Mod(x, y).Cmp(big.NewInt(0)) != 0 { - return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v", x, y) - } - return *new(big.Int).Div(x, y), nil -} - func IsQuadResidue(x *fp.Element) bool { // Implementation adapted from sympy implementation which can be found here : // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/ntheory/residue_ntheory.py#L689 diff --git a/pkg/hintrunner/utils/math_utils_test.go b/pkg/hintrunner/utils/math_utils_test.go index 3c1a0a609..42da85c2f 100644 --- a/pkg/hintrunner/utils/math_utils_test.go +++ b/pkg/hintrunner/utils/math_utils_test.go @@ -46,69 +46,3 @@ func TestDivMod(t *testing.T) { }) } } - -func TestIgcdex(t *testing.T) { - // https://github.com/sympy/sympy/blob/e7fb2714f17b30b83e424448aad0da9e94a4b577/sympy/core/tests/test_numbers.py#L278 - tests := []struct { - name string - a, b *big.Int - expectedX, expectedY, expectedG *big.Int - }{ - { - name: "Case 1", - a: big.NewInt(2), - b: big.NewInt(3), - expectedX: big.NewInt(-1), - expectedY: big.NewInt(1), - expectedG: big.NewInt(1), - }, - { - name: "Case 2", - a: big.NewInt(10), - b: big.NewInt(12), - expectedX: big.NewInt(-1), - expectedY: big.NewInt(1), - expectedG: big.NewInt(2), - }, - { - name: "Case 3", - a: big.NewInt(100), - b: big.NewInt(2004), - expectedX: big.NewInt(-20), - expectedY: big.NewInt(1), - expectedG: big.NewInt(4), - }, - { - name: "Case 4", - a: big.NewInt(0), - b: big.NewInt(0), - expectedX: big.NewInt(0), - expectedY: big.NewInt(1), - expectedG: big.NewInt(0), - }, - { - name: "Case 5", - a: big.NewInt(1), - b: big.NewInt(0), - expectedX: big.NewInt(1), - expectedY: big.NewInt(0), - expectedG: big.NewInt(1), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - actualX, actualY, actualG := Igcdex(tt.a, tt.b) - - if actualX.Cmp(tt.expectedX) != 0 { - t.Errorf("got x: %v, want: %v", actualX, tt.expectedX) - } - if actualY.Cmp(tt.expectedY) != 0 { - t.Errorf("got x: %v, want: %v", actualY, tt.expectedY) - } - if actualG.Cmp(tt.expectedG) != 0 { - t.Errorf("got x: %v, want: %v", actualG, tt.expectedG) - } - }) - } -} diff --git a/pkg/hintrunner/zero/zerohint_ec.go b/pkg/hintrunner/zero/zerohint_ec.go index 011781669..de508956c 100644 --- a/pkg/hintrunner/zero/zerohint_ec.go +++ b/pkg/hintrunner/zero/zerohint_ec.go @@ -257,7 +257,7 @@ func newDivModNSafeDivPlusOneHint() hinter.Hinter { valueBig.Mul(resBig, bBig) valueBig.Sub(valueBig, aBig) - newValueBig, err := secp_utils.SafeDiv(valueBig, nBig) + newValueBig, err := utils.SafeDiv(valueBig, nBig) if err != nil { return err } diff --git a/pkg/hintrunner/zero/zerohint_signature.go b/pkg/hintrunner/zero/zerohint_signature.go index 8af7daf81..205842151 100644 --- a/pkg/hintrunner/zero/zerohint_signature.go +++ b/pkg/hintrunner/zero/zerohint_signature.go @@ -6,6 +6,7 @@ import ( "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" secp_utils "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/utils" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" @@ -372,7 +373,7 @@ func newDivModSafeDivHint() hinter.Hinter { } divisor := new(big.Int).Sub(new(big.Int).Mul(res, b), a) - value, err := secp_utils.SafeDiv(divisor, N) + value, err := utils.SafeDiv(divisor, N) if err != nil { return err } diff --git a/pkg/runner/runner_test.go b/pkg/runner/runner_test.go index d7b3d99cb..9a7274b19 100644 --- a/pkg/runner/runner_test.go +++ b/pkg/runner/runner_test.go @@ -410,6 +410,28 @@ func TestEcOpBuiltin(t *testing.T) { require.NoError(t, err) } +func TestModuloBuiltin(t *testing.T) { + // modulo is located at fp - 3 + // we first write 2048 and 5 to modulo + // then we read the modulo result from add and mul + // runner := createRunner(` + // [ap] = 2048; + // [ap] = [[fp - 3]]; + + // [ap + 1] = 5; + // [ap + 1] = [[fp - 3] + 1]; + // ret; + // `, "small", sn.AddMod, sn.MulMod) + + // err := runner.Run() + // require.NoError(t, err) + + // modulo, ok := runner.vm.Memory.FindSegmentWithBuiltin("add_mod") + // require.True(t, ok) + + // requireEqualSegments(t, createSegment(2048, 5), modulo) +} + func createRunner(code string, layoutName string, builtins ...builtins.BuiltinType) ZeroRunner { program := createProgramWithBuiltins(code, builtins...) diff --git a/pkg/utils/math.go b/pkg/utils/math.go index ea02fd0ca..40b310fc9 100644 --- a/pkg/utils/math.go +++ b/pkg/utils/math.go @@ -1,6 +1,8 @@ package utils import ( + "errors" + "fmt" "math" "math/big" "math/bits" @@ -156,3 +158,75 @@ func Int16FromBigInt(n *big.Int) (int16, bool) { func RightRot(value uint32, n uint32) uint32 { return (value >> n) | ((value & ((1 << n) - 1)) << (32 - n)) } + +func SafeDivUint64(x, y uint64) (uint64, error) { + if y == 0 { + return 0, fmt.Errorf("cannot divide: y division is zero") + } + if x%y != 0 { + return 0, errors.New("cannot divide: x is not divisible by y") + } + return x / y, nil +} + +func Igcdex(a, b *big.Int) (big.Int, big.Int, big.Int) { + // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/core/intfunc.py#L362 + + if a.Cmp(big.NewInt(0)) == 0 && b.Cmp(big.NewInt(0)) == 0 { + return *big.NewInt(0), *big.NewInt(1), *big.NewInt(0) + } + g, x, y := gcdext(a, b) + return x, y, g +} + +func gcdext(a, b *big.Int) (big.Int, big.Int, big.Int) { + // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L125 + + if a.Cmp(big.NewInt(0)) == 0 || b.Cmp(big.NewInt(0)) == 0 { + g := new(big.Int) + if a.Cmp(big.NewInt(0)) == 0 { + g.Abs(b) + } else { + g.Abs(a) + } + + if g.Cmp(big.NewInt(0)) == 0 { + return *big.NewInt(0), *big.NewInt(0), *big.NewInt(0) + } + return *g, *new(big.Int).Div(a, g), *new(big.Int).Div(b, g) + } + + xSign, aSigned := sign(a) + ySign, bSigned := sign(b) + x, r := big.NewInt(1), big.NewInt(0) + y, s := big.NewInt(0), big.NewInt(1) + + for bSigned.Sign() != 0 { + q, c := new(big.Int).DivMod(&aSigned, &bSigned, new(big.Int)) + aSigned = bSigned + bSigned = *c + x, r = r, new(big.Int).Sub(x, new(big.Int).Mul(q, r)) + y, s = s, new(big.Int).Sub(y, new(big.Int).Mul(q, s)) + } + + return aSigned, *new(big.Int).Mul(x, big.NewInt(int64(xSign))), *new(big.Int).Mul(y, big.NewInt(int64(ySign))) +} + +func sign(n *big.Int) (int, big.Int) { + // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L119 + + if n.Sign() < 0 { + return -1, *new(big.Int).Abs(n) + } + return 1, *new(big.Int).Set(n) +} + +func SafeDiv(x, y *big.Int) (big.Int, error) { + if y.Cmp(big.NewInt(0)) == 0 { + return *big.NewInt(0), fmt.Errorf("division by zero") + } + if new(big.Int).Mod(x, y).Cmp(big.NewInt(0)) != 0 { + return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v", x, y) + } + return *new(big.Int).Div(x, y), nil +} diff --git a/pkg/utils/math_test.go b/pkg/utils/math_test.go index 62cd381eb..0a455c072 100644 --- a/pkg/utils/math_test.go +++ b/pkg/utils/math_test.go @@ -1,6 +1,7 @@ package utils import ( + "math/big" "testing" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" @@ -119,3 +120,69 @@ func TestRightRot(t *testing.T) { }) } } + +func TestIgcdex(t *testing.T) { + // https://github.com/sympy/sympy/blob/e7fb2714f17b30b83e424448aad0da9e94a4b577/sympy/core/tests/test_numbers.py#L278 + tests := []struct { + name string + a, b *big.Int + expectedX, expectedY, expectedG *big.Int + }{ + { + name: "Case 1", + a: big.NewInt(2), + b: big.NewInt(3), + expectedX: big.NewInt(-1), + expectedY: big.NewInt(1), + expectedG: big.NewInt(1), + }, + { + name: "Case 2", + a: big.NewInt(10), + b: big.NewInt(12), + expectedX: big.NewInt(-1), + expectedY: big.NewInt(1), + expectedG: big.NewInt(2), + }, + { + name: "Case 3", + a: big.NewInt(100), + b: big.NewInt(2004), + expectedX: big.NewInt(-20), + expectedY: big.NewInt(1), + expectedG: big.NewInt(4), + }, + { + name: "Case 4", + a: big.NewInt(0), + b: big.NewInt(0), + expectedX: big.NewInt(0), + expectedY: big.NewInt(1), + expectedG: big.NewInt(0), + }, + { + name: "Case 5", + a: big.NewInt(1), + b: big.NewInt(0), + expectedX: big.NewInt(1), + expectedY: big.NewInt(0), + expectedG: big.NewInt(1), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actualX, actualY, actualG := Igcdex(tt.a, tt.b) + + if actualX.Cmp(tt.expectedX) != 0 { + t.Errorf("got x: %v, want: %v", actualX, tt.expectedX) + } + if actualY.Cmp(tt.expectedY) != 0 { + t.Errorf("got x: %v, want: %v", actualY, tt.expectedY) + } + if actualG.Cmp(tt.expectedG) != 0 { + t.Errorf("got x: %v, want: %v", actualG, tt.expectedG) + } + }) + } +} diff --git a/pkg/vm/builtins/builtin_runner.go b/pkg/vm/builtins/builtin_runner.go index 77aa50c9a..1b95fe77b 100644 --- a/pkg/vm/builtins/builtin_runner.go +++ b/pkg/vm/builtins/builtin_runner.go @@ -22,6 +22,8 @@ const ( PoseidonType SegmentArenaType RangeCheck96Type + AddModeType + MulModType ) func Runner(name BuiltinType) memory.BuiltinRunner { @@ -44,6 +46,10 @@ func Runner(name BuiltinType) memory.BuiltinRunner { return &EcOp{} case PoseidonType: return &Poseidon{} + case AddModeType: + return &ModBuiltin{modBuiltinType: Add} + case MulModType: + return &ModBuiltin{modBuiltinType: Mul} case SegmentArenaType: panic("Not implemented") default: @@ -101,6 +107,10 @@ func (b BuiltinType) MarshalJSON() ([]byte, error) { return []byte(EcOpName), nil case PoseidonType: return []byte(PoseidonName), nil + case AddModeType: + return []byte("Add" + ModuloName), nil + case MulModType: + return []byte("Mul" + ModuloName), nil case SegmentArenaType: return []byte(SegmentArenaName), nil @@ -133,6 +143,10 @@ func (b *BuiltinType) UnmarshalJSON(data []byte) error { *b = ECOPType case PoseidonName: *b = PoseidonType + case "Add" + ModuloName: + *b = AddModeType + case "Mul" + ModuloName: + *b = MulModType case SegmentArenaName: *b = SegmentArenaType default: diff --git a/pkg/vm/builtins/layouts.go b/pkg/vm/builtins/layouts.go index 19ff71893..dba5f1d76 100644 --- a/pkg/vm/builtins/layouts.go +++ b/pkg/vm/builtins/layouts.go @@ -110,8 +110,6 @@ func getAllSolidityLayout() Layout { }} } -// TODO: Add mul_mod and add_mod builtins -// refer: https://github.com/lambdaclass/cairo-vm/blob/main/vm/src/types/instance_definitions/builtins_instance_def.rs#L168 func getAllCairoLayout() Layout { return Layout{Name: "all_cairo", RcUnits: 8, Builtins: []LayoutBuiltin{ {Runner: &Output{}, Builtin: OutputType}, @@ -123,6 +121,8 @@ func getAllCairoLayout() Layout { {Runner: &Keccak{ratio: 2048, cache: make(map[uint64]fp.Element)}, Builtin: KeccakType}, {Runner: &Poseidon{ratio: 256, cache: make(map[uint64]fp.Element)}, Builtin: PoseidonType}, {Runner: &RangeCheck{ratio: 8, RangeCheckNParts: 6}, Builtin: RangeCheck96Type}, + {Runner: &ModBuiltin{ratio: 128, wordBitLen: 96, batchSize: 1, modBuiltinType: Add}, Builtin: AddModeType}, + {Runner: &ModBuiltin{ratio: 256, wordBitLen: 96, batchSize: 1, modBuiltinType: Mul}, Builtin: MulModType}, }} } diff --git a/pkg/vm/builtins/modulo.go b/pkg/vm/builtins/modulo.go new file mode 100644 index 000000000..4a3a97f87 --- /dev/null +++ b/pkg/vm/builtins/modulo.go @@ -0,0 +1,541 @@ +package builtins + +import ( + "fmt" + "math/big" + + "github.com/NethermindEth/cairo-vm-go/pkg/utils" + + "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" + "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" +) + +const ModuloName = "Mod" + +// These are the offsets in the array, which is used here as ModBuiltinInputs : +// INPUT_NAMES = [ +// +// "p0", +// "p1", +// "p2", +// "p3", +// "values_ptr", +// "offsets_ptr", +// "n", +// +// ] +const VALUES_PTR_OFFSET = 4 +const OFFSETS_PTR_OFFSET = 5 +const N_OFFSET = 6 + +// This is the number of felts in a UInt384 struct +const N_WORDS = 4 + +// number of memory cells per modulo builtin +// 4(felts) + 1(values_ptr) + 1(offsets_ptr) + 1(n) = 7 +const CELLS_PER_MOD = 7 + +// The maximum n value that the function fill_memory accepts +const MAX_N = 100000 + +// Represents a 384-bit unsigned integer d0 + 2**96 * d1 + 2**192 * d2 + 2**288 * d3 +// where each di is in [0, 2**96). +// +// struct UInt384 { +// d0: felt, +// d1: felt, +// d2: felt, +// d3: felt, +// } +// Instead of introducing UInt384, we use [N_WORDS]fp.Element to represent the 384-bit integer. + +type ModBuiltinInputs struct { + // The modulus. + p big.Int + pValues [N_WORDS]fp.Element + // A pointer to input values, the intermediate results and the output. + valuesPtr memory.MemoryAddress + // A pointer to offsets inside the values array, defining the circuit. + // The offsets array should contain 3 * n elements. + offsetsPtr memory.MemoryAddress + // The number of operations to perform. + n uint64 +} + +type ModBuiltinType string + +const ( + Add ModBuiltinType = "Add" + Mul ModBuiltinType = "Mul" +) + +type ModBuiltin struct { + ratio uint64 + // Add | Mul + modBuiltinType ModBuiltinType + // number of bits in a word + wordBitLen uint64 + batchSize uint64 + // shift by the number of bits present in a word + shift big.Int + // powers required to do the corresponding shift + shiftPowers [N_WORDS]big.Int + // k value that bounds p when finding unknown value in fillValue function + kBound *big.Int +} + +func NewModBuiltin(ratio uint64, wordBitLen uint64, batchSize uint64, modBuiltinType ModBuiltinType) *ModBuiltin { + shift := new(big.Int).Lsh(big.NewInt(1), uint(wordBitLen)) + shiftPowers := [N_WORDS]big.Int{} + shiftPowers[0] = *big.NewInt(1) + for i := 1; i < N_WORDS; i++ { + shiftPowers[i].Mul(&shiftPowers[i-1], shift) + } + kBound := big.NewInt(2) + if modBuiltinType == Mul { + kBound = nil + } + return &ModBuiltin{ + ratio: ratio, + modBuiltinType: modBuiltinType, + wordBitLen: wordBitLen, + batchSize: batchSize, + shift: *shift, + shiftPowers: shiftPowers, + kBound: kBound, + } +} + +func (m *ModBuiltin) CheckWrite(segment *memory.Segment, offset uint64, value *memory.MemoryValue) error { + return nil +} + +func (m *ModBuiltin) InferValue(segment *memory.Segment, offset uint64) error { + return fmt.Errorf("can't infer value") +} + +func (m *ModBuiltin) String() string { + return string(m.modBuiltinType) + ModuloName +} + +func (m *ModBuiltin) GetAllocatedSize(segmentUsedSize uint64, vmCurrentStep uint64) (uint64, error) { + return 0, nil +} + +// Reads N_WORDS from memory, starting at address = addr. +// Returns the words and the value if all words are in memory. +// Verifies that all words are integers and are bounded by 2**wordBitLen. +func (m *ModBuiltin) readNWordsValue(memory *memory.Memory, addr memory.MemoryAddress) ([N_WORDS]fp.Element, *big.Int, error) { + var words [N_WORDS]fp.Element + value := new(big.Int).SetInt64(0) + + for i := 0; i < N_WORDS; i++ { + newAddr, err := addr.AddOffset(int16(i)) + if err != nil { + return [N_WORDS]fp.Element{}, nil, err + } + + wordFelt, err := memory.ReadAsElement(newAddr.SegmentIndex, newAddr.Offset) + if err != nil { + return [N_WORDS]fp.Element{}, nil, err + } + + var word big.Int + wordFelt.BigInt(&word) + if word.Cmp(&m.shift) >= 0 { + return [N_WORDS]fp.Element{}, nil, fmt.Errorf("expected integer at address %d:%d to be smaller than 2^%d. Got: %s", newAddr.SegmentIndex, newAddr.Offset, m.wordBitLen, word.String()) + } + + words[i] = wordFelt + value = new(big.Int).Add(value, new(big.Int).Mul(&word, &m.shiftPowers[i])) + } + + return words, value, nil +} + +// Reads the inputs to the builtin (p, p_values, values_ptr, offsets_ptr, n) from the memory at address = addr. +// Returns an instance of ModBuiltinInputs and asserts that it exists in memory. +// If `read_n` is false, avoid reading and validating the value of 'n'. +func (m *ModBuiltin) readInputs(mem *memory.Memory, addr memory.MemoryAddress, read_n bool) (ModBuiltinInputs, error) { + valuesPtrAddr, err := addr.AddOffset(int16(VALUES_PTR_OFFSET)) + if err != nil { + return ModBuiltinInputs{}, err + } + valuesPtr, err := mem.ReadAsAddress(&valuesPtrAddr) + if err != nil { + return ModBuiltinInputs{}, err + } + offsetsPtrAddr, err := addr.AddOffset(int16(OFFSETS_PTR_OFFSET)) + if err != nil { + return ModBuiltinInputs{}, err + } + offsetsPtr, err := mem.ReadAsAddress(&offsetsPtrAddr) + if err != nil { + return ModBuiltinInputs{}, err + } + n := uint64(0) + if read_n { + nFelt, err := mem.ReadAsElement(addr.SegmentIndex, addr.Offset+N_OFFSET) + if err != nil { + return ModBuiltinInputs{}, err + } + n = nFelt.Uint64() + if n < 1 { + return ModBuiltinInputs{}, fmt.Errorf("moduloBuiltin: Expected n >= 1. Got: %d", n) + } + } + pValues, p, err := m.readNWordsValue(mem, addr) + if err != nil { + return ModBuiltinInputs{}, err + } + return ModBuiltinInputs{ + p: *p, + pValues: pValues, + valuesPtr: valuesPtr, + n: n, + offsetsPtr: offsetsPtr, + }, nil +} + +// Fills the inputs to the instances of the builtin given the inputs to the first instance. +func (m *ModBuiltin) fillInputs(mem *memory.Memory, builtinPtr memory.MemoryAddress, inputs ModBuiltinInputs) error { + if inputs.n > MAX_N { + return fmt.Errorf("fill memory max exceeded") + } + + nInstances, err := utils.SafeDivUint64(inputs.n, m.batchSize) + if err != nil { + return err + } + + for instance := 1; instance < int(nInstances); instance++ { + instancePtr, err := builtinPtr.AddOffset(int16(instance * CELLS_PER_MOD)) + if err != nil { + return err + } + + // Filling the 4 values of a UInt384 struct + for i := 0; i < N_WORDS; i++ { + addr, err := instancePtr.AddOffset(int16(i)) + if err != nil { + return err + } + mv := memory.MemoryValueFromFieldElement(&inputs.pValues[i]) + if err := mem.WriteToAddress(&addr, &mv); err != nil { + return err + } + } + + addr, err := instancePtr.AddOffset(VALUES_PTR_OFFSET) + if err != nil { + return err + } + mv := memory.MemoryValueFromMemoryAddress(&inputs.valuesPtr) + if err := mem.WriteToAddress(&addr, &mv); err != nil { + return err + } + + addr, err = instancePtr.AddOffset(OFFSETS_PTR_OFFSET) + if err != nil { + return err + } + newAddr, err := inputs.offsetsPtr.AddOffset(3 * int16(instance) * int16(m.batchSize)) + if err != nil { + return err + } + mv = memory.MemoryValueFromMemoryAddress(&newAddr) + if err := mem.WriteToAddress(&addr, &mv); err != nil { + return err + } + + // This denotes the number of operations left + // n for new instance = original n - batch_size * (number of instances passed) + addr, err = instancePtr.AddOffset(N_OFFSET) + if err != nil { + return err + } + val := fp.NewElement(inputs.n - m.batchSize*uint64(instance)) + mv = memory.MemoryValueFromFieldElement(&val) + if err := mem.WriteToAddress(&addr, &mv); err != nil { + return err + } + } + + return nil +} + +// Copies the first offsets into memory, nCopies times. +func (m *ModBuiltin) fillOffsets(mem *memory.Memory, offsetsPtr memory.MemoryAddress, index, nCopies uint64) error { + if nCopies == 0 { + return nil + } + + for i := 0; i < 3; i++ { + addr, err := offsetsPtr.AddOffset(int16(i)) + if err != nil { + return err + } + + offset, err := mem.ReadAsAddress(&addr) + if err != nil { + return err + } + + for copyI := 0; copyI < int(nCopies); copyI++ { + copyAddr, err := offsetsPtr.AddOffset(int16(3*(index+uint64(copyI)) + uint64(i))) + if err != nil { + return err + } + mv := memory.MemoryValueFromMemoryAddress(&offset) + if err := mem.WriteToAddress(©Addr, &mv); err != nil { + return err + } + } + } + + return nil +} + +// Given a value, writes its n_words to memory, starting at address = addr. +func (m *ModBuiltin) writeNWordsValue(mem *memory.Memory, addr memory.MemoryAddress, value big.Int) error { + for i := 0; i < N_WORDS; i++ { + word := new(big.Int).Mod(&value, &m.shift) + modAddr, err := addr.AddOffset(int16(i)) + if err != nil { + return err + } + mv := memory.MemoryValueFromFieldElement(new(fp.Element).SetBigInt(word)) + if err := mem.WriteToAddress(&modAddr, &mv); err != nil { + return err + } + value.Div(&value, &m.shift) + } + if value.Sign() != 0 { + return fmt.Errorf("writeNWordsValue: value should be zero") + } + return nil +} + +// Fills a value in the values table, if exactly one value is missing. +// Returns true on success or if all values are already known. +// Given known, res, p fillValue tries to compute the minimal integer operand x which +// satisfies the equation op(x,known) = res + k*p for some k in {0,1,...,self.k_bound-1}. +func (m *ModBuiltin) fillValue(mem *memory.Memory, inputs ModBuiltinInputs, index int, op ModBuiltinType) (bool, error) { + addresses := make([]memory.MemoryAddress, 0, 3) + values := make([]*big.Int, 0, 3) + + for i := 0; i < 3; i++ { + addr, err := inputs.offsetsPtr.AddOffset(int16(3*index + i)) + if err != nil { + return false, err + } + offsetFelt, err := mem.ReadAsElement(addr.SegmentIndex, addr.Offset) + if err != nil { + return false, err + } + offset := offsetFelt.Uint64() + addr, err = inputs.valuesPtr.AddOffset(int16(offset)) + if err != nil { + return false, err + } + addresses = append(addresses, addr) + // do not check for error, as the value might not be in memory + _, value, _ := m.readNWordsValue(mem, addr) + values = append(values, value) + } + + a, b, c := values[0], values[1], values[2] + + // 2 ** 384 (max value that can be stored in 4 felts) + intLim := new(big.Int).Lsh(big.NewInt(1), uint(m.wordBitLen)*N_WORDS) + kBound := m.kBound + if kBound == nil { + kBound = new(big.Int).Set(intLim) + } + + switch { + case a != nil && b != nil && c == nil: + var value big.Int + if op == Add { + value = *new(big.Int).Add(a, b) + } else { + value = *new(big.Int).Mul(a, b) + } + // value - (kBound - 1) * p <= intLim - 1 + if new(big.Int).Sub(&value, new(big.Int).Mul((new(big.Int).Sub(kBound, big.NewInt(1))), &inputs.p)).Cmp(new(big.Int).Sub(intLim, big.NewInt(1))) == 1 { + return false, fmt.Errorf("%s builtin: Expected a %s b - %d * p <= %d", m.String(), m.modBuiltinType, kBound.Sub(kBound, big.NewInt(1)), intLim.Sub(intLim, big.NewInt(1))) + } + if value.Cmp(new(big.Int).Mul(kBound, &inputs.p)) < 0 { + value.Mod(&value, &inputs.p) + } else { + value.Sub(&value, new(big.Int).Mul(new(big.Int).Sub(kBound, big.NewInt(1)), &inputs.p)) + } + if err := m.writeNWordsValue(mem, addresses[2], value); err != nil { + return false, err + } + return true, nil + case a != nil && b == nil && c != nil: + var value big.Int + if op == Add { + // Right now only k = 2 is an option, hence as we stated above that x + known can only take values + // from res to res + (k - 1) * p, hence known <= res + p + if a.Cmp(new(big.Int).Add(c, &inputs.p)) > 0 { + return false, fmt.Errorf("%s builtin: addend greater than sum + p: %d > %d + %d", m.String(), a, c, &inputs.p) + } else { + if a.Cmp(c) <= 0 { + value = *new(big.Int).Sub(c, a) + } else { + value = *new(big.Int).Sub(c.Add(c, &inputs.p), a) + } + } + } else { + x, _, gcd := utils.Igcdex(a, &inputs.p) + // if gcd != 1, the known value is 0, in which case the res must be 0 + if gcd.Cmp(big.NewInt(1)) != 0 { + value = *new(big.Int).Div(&inputs.p, &gcd) + } else { + value = *new(big.Int).Mul(c, &x) + value = *value.Mod(&value, &inputs.p) + tmpK, err := utils.SafeDiv(new(big.Int).Sub(new(big.Int).Mul(a, &value), c), &inputs.p) + if err != nil { + return false, err + } + if tmpK.Cmp(kBound) >= 0 { + return false, fmt.Errorf("%s builtin: ((%d * q) - %d) / %d > %d for any q > 0, such that %d * q = %d (mod %d) ", m.String(), a, c, &inputs.p, kBound, a, c, &inputs.p) + } + if tmpK.Cmp(big.NewInt(0)) < 0 { + value = *value.Add(&value, new(big.Int).Mul(&inputs.p, new(big.Int).Div(new(big.Int).Sub(a, new(big.Int).Sub(&tmpK, big.NewInt(1))), a))) + } + } + } + if err := m.writeNWordsValue(mem, addresses[1], value); err != nil { + return false, err + } + return true, nil + case a == nil && b != nil && c != nil: + var value big.Int + if op == Add { + // Right now only k = 2 is an option, hence as we stated above that x + known can only take values + // from res to res + (k - 1) * p, hence known <= res + p + if b.Cmp(new(big.Int).Add(c, &inputs.p)) > 0 { + return false, fmt.Errorf("%s builtin: addend greater than sum + p: %d > %d + %d", m.String(), b, c, &inputs.p) + } else { + if b.Cmp(c) <= 0 { + value = *new(big.Int).Sub(c, b) + } else { + value = *new(big.Int).Sub(c.Add(c, &inputs.p), b) + } + } + } else { + x, _, gcd := utils.Igcdex(b, &inputs.p) + // if gcd != 1, the known value is 0, in which case the res must be 0 + if gcd.Cmp(big.NewInt(1)) != 0 { + value = *new(big.Int).Div(&inputs.p, &gcd) + } else { + value = *new(big.Int).Mul(c, &x) + value = *value.Mod(&value, &inputs.p) + tmpK, err := utils.SafeDiv(new(big.Int).Sub(new(big.Int).Mul(b, &value), c), &inputs.p) + if err != nil { + return false, err + } + if tmpK.Cmp(kBound) >= 0 { + return false, fmt.Errorf("%s builtin: ((%d * q) - %d) / %d > %d for any q > 0, such that %d * q = %d (mod %d) ", m.String(), b, c, &inputs.p, kBound, b, c, &inputs.p) + } + if tmpK.Cmp(big.NewInt(0)) < 0 { + value = *value.Add(&value, new(big.Int).Mul(&inputs.p, new(big.Int).Div(new(big.Int).Sub(b, new(big.Int).Sub(&tmpK, big.NewInt(1))), b))) + } + } + } + if err := m.writeNWordsValue(mem, addresses[0], value); err != nil { + return false, err + } + return true, nil + case a != nil && b != nil && c != nil: + return true, nil + default: + return false, nil + } +} + +// Fills the memory with inputs to the builtin instances based on the inputs to the +// first instance, pads the offsets table to fit the number of operations written in the +// input to the first instance, and calculates missing values in the values table. +// +// The number of operations written to the input of the first instance n should be at +// least n and a multiple of batch_size. Previous offsets are copied to the end of the +// offsets table to make its length 3n'. +func FillMemory(mem *memory.Memory, addModBuiltinAddr memory.MemoryAddress, nAddModsIndex uint64, mulModBuiltinAddr memory.MemoryAddress, nMulModsIndex uint64) error { + if nAddModsIndex > MAX_N { + return fmt.Errorf("AddMod builtin: n must be <= {MAX_N}") + } + if nMulModsIndex > MAX_N { + return fmt.Errorf("MulMod builtin: n must be <= {MAX_N}") + } + + addModBuiltinSegment, ok := mem.FindSegmentWithBuiltin("AddMod") + if !ok { + return fmt.Errorf("AddMod builtin segment doesn't exist") + } + mulModBuiltinSegment, ok := mem.FindSegmentWithBuiltin("MulMod") + if !ok { + return fmt.Errorf("MulMod builtin segment doesn't exist") + } + addModBuiltinRunner, ok := addModBuiltinSegment.BuiltinRunner.(*ModBuiltin) + if !ok { + return fmt.Errorf("addModBuiltinRunner is not a ModBuiltin") + } + mulModBuiltinRunner, ok := mulModBuiltinSegment.BuiltinRunner.(*ModBuiltin) + if !ok { + return fmt.Errorf("mulModBuiltinRunner is not a ModBuiltin") + } + + if addModBuiltinRunner.wordBitLen != mulModBuiltinRunner.wordBitLen { + return fmt.Errorf("AddMod and MulMod wordBitLen mismatch") + } + + addModBuiltinInputs, err := addModBuiltinRunner.readInputs(mem, addModBuiltinAddr, true) + if err != nil { + return err + } + if err := addModBuiltinRunner.fillInputs(mem, addModBuiltinAddr, addModBuiltinInputs); err != nil { + return err + } + if err := addModBuiltinRunner.fillOffsets(mem, addModBuiltinInputs.offsetsPtr, nAddModsIndex, addModBuiltinInputs.n-nAddModsIndex); err != nil { + return err + } + + mulModBuiltinInputs, err := mulModBuiltinRunner.readInputs(mem, mulModBuiltinAddr, true) + if err != nil { + return err + } + if err := mulModBuiltinRunner.fillInputs(mem, mulModBuiltinAddr, mulModBuiltinInputs); err != nil { + return err + } + if err := mulModBuiltinRunner.fillOffsets(mem, mulModBuiltinInputs.offsetsPtr, nMulModsIndex, mulModBuiltinInputs.n-nMulModsIndex); err != nil { + return err + } + + addModIndex, mulModIndex := uint64(0), uint64(0) + for addModIndex < nAddModsIndex { + ok, err := addModBuiltinRunner.fillValue(mem, addModBuiltinInputs, int(addModIndex), Add) + if err != nil { + return err + } + if ok { + addModIndex++ + } + } + + for mulModIndex < nMulModsIndex { + ok, err = mulModBuiltinRunner.fillValue(mem, mulModBuiltinInputs, int(mulModIndex), Mul) + if err != nil { + return err + } + if ok { + mulModIndex++ + } + } + // POTENTIALY: add n_computed_mul_gates features in the future + + return nil +} diff --git a/pkg/vm/builtins/modulo_test.go b/pkg/vm/builtins/modulo_test.go new file mode 100644 index 000000000..722817b72 --- /dev/null +++ b/pkg/vm/builtins/modulo_test.go @@ -0,0 +1,158 @@ +package builtins + +import ( + // "fmt" + "math/big" + "testing" + + "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" + "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" + "github.com/stretchr/testify/require" +) + +/* +Tests whether runner completes a trio a, b, c as the input implies: +If inverse is False it tests whether a = x1, b=x2, c = None will be completed with c = res. +If inverse is True it tests whether c = x1, b = x2, a = None will be completed with a = res. +*/ +func checkResult(runner ModBuiltin, inverse bool, p, x1, x2 big.Int) (*big.Int, error) { + mem := memory.Memory{} + + mem.AllocateBuiltinSegment(&runner) + + offsetsPtr := memory.MemoryAddress{SegmentIndex: 0, Offset: 0} + + for i := 0; i < 3; i++ { + offsetsPtrAddr, err := offsetsPtr.AddOffset(int16(i)) + if err != nil { + return nil, err + } + + mv := memory.MemoryValueFromInt(i * N_WORDS) + if err := mem.WriteToAddress(&offsetsPtrAddr, &mv); err != nil { + return nil, err + } + } + + valuesAddr := memory.MemoryAddress{SegmentIndex: 0, Offset: 24} + + x1Addr, err := valuesAddr.AddOffset(int16(0)) + if err != nil { + return nil, err + } + + x2Addr, err := valuesAddr.AddOffset(int16(N_WORDS)) + if err != nil { + return nil, err + } + err = runner.writeNWordsValue(&mem, x2Addr, x2) + if err != nil { + return nil, err + } + + resAddr, err := valuesAddr.AddOffset(int16(2 * N_WORDS)) + if err != nil { + return nil, err + } + + if inverse { + x1Addr, resAddr = resAddr, x1Addr + } + + err = runner.writeNWordsValue(&mem, x1Addr, x1) + if err != nil { + return nil, err + } + + _, err = runner.fillValue(&mem, ModBuiltinInputs{ + p: p, + pValues: [N_WORDS]fp.Element{}, // not used in fillValue + valuesPtr: valuesAddr, + n: 0, // not used in fillValue + offsetsPtr: offsetsPtr, + }, 0, runner.modBuiltinType) + + if err != nil { + return nil, err + } + + _, OutRes, err := runner.readNWordsValue(&mem, resAddr) + if err != nil { + return nil, err + } + + return OutRes, nil +} + +func TestAddModBuiltinRunnerAddition(t *testing.T) { + runner := NewModBuiltin(1, 3, 1, Add) + res1, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(17), *big.NewInt(40)) + require.NoError(t, err) + require.Equal(t, big.NewInt(57), res1) + res2, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(82), *big.NewInt(31)) + require.NoError(t, err) + require.Equal(t, big.NewInt(46), res2) + res3, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(68), *big.NewInt(69)) + require.NoError(t, err) + require.Equal(t, big.NewInt(70), res3) + res4, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(68), *big.NewInt(0)) + require.NoError(t, err) + require.Equal(t, big.NewInt(1), res4) + _, err = checkResult(*runner, false, *big.NewInt(4094), *big.NewInt(4095), *big.NewInt(4095)) + require.ErrorContains(t, err, "Expected a Add b - 1 * p <= 4095") +} + +func TestAddModBuiltinRunnerSubtraction(t *testing.T) { + runner := NewModBuiltin(1, 3, 1, Add) + res1, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(52), *big.NewInt(38)) + require.NoError(t, err) + require.Equal(t, big.NewInt(14), res1) + res2, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(5), *big.NewInt(68)) + require.NoError(t, err) + require.Equal(t, big.NewInt(4), res2) + res3, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(5), *big.NewInt(0)) + require.NoError(t, err) + require.Equal(t, big.NewInt(5), res3) + res4, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(0), *big.NewInt(5)) + require.NoError(t, err) + require.Equal(t, big.NewInt(62), res4) + _, err = checkResult(*runner, true, *big.NewInt(67), *big.NewInt(70), *big.NewInt(138)) + require.ErrorContains(t, err, "addend greater than sum + p") +} + +func TestMulModBuiltinRunnerMultiplication(t *testing.T) { + runner := NewModBuiltin(1, 3, 1, Mul) + res1, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(11), *big.NewInt(8)) + require.NoError(t, err) + require.Equal(t, big.NewInt(21), res1) + res2, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(68), *big.NewInt(69)) + require.NoError(t, err) + require.Equal(t, big.NewInt(2), res2) + res3, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(525), *big.NewInt(526)) + require.NoError(t, err) + require.Equal(t, big.NewInt(1785), res3) + res4, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(525), *big.NewInt(0)) + require.NoError(t, err) + require.Equal(t, big.NewInt(0), res4) + _, err = checkResult(*runner, false, *big.NewInt(67), *big.NewInt(3777), *big.NewInt(3989)) + require.ErrorContains(t, err, "Expected a Mul b - 4095 * p <= 4095") +} + +func TestMulModBuiltinRunnerDivision(t *testing.T) { + runner := NewModBuiltin(1, 3, 1, Mul) + res1, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(36), *big.NewInt(9)) + require.NoError(t, err) + require.Equal(t, big.NewInt(4), res1) + res2, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(138), *big.NewInt(41)) + require.NoError(t, err) + require.Equal(t, big.NewInt(5), res2) + res3, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(272), *big.NewInt(41)) + require.NoError(t, err) + require.Equal(t, big.NewInt(72), res3) + res4, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(0), *big.NewInt(0)) + require.NoError(t, err) + require.Equal(t, big.NewInt(1), res4) + res5, err := checkResult(*runner, true, *big.NewInt(66), *big.NewInt(6), *big.NewInt(3)) + require.NoError(t, err) + require.Equal(t, big.NewInt(22), res5) +}