Skip to content

Commit

Permalink
Implement ModBuiltin (#666)
Browse files Browse the repository at this point in the history
* implement partial functionalities and structs required

* run fmt

* Introduce max func

* Added more functions

* ran fmt

* Add test for modulo

* Added structure for integrating builtin with the vm

* Added useful comments

* Update FillMemory

* Added more comments

* Add test

* Added TODOs

* fixes

* update loop

* Added addmod and mulmod

* added some fixes

* test passes now

* fix error and test

* Make tests pass

* mod subtraction tests

* Added some comments

* Added subtraction tests

* Added recursive case as well

* nit

* Add multiplication test

* refactor math utils

* move one math utils function

* All Tests pass

* nit

* Added mirror functionality

* remove redundant test

* updated checkwrite and infervalue

* nit
  • Loading branch information
Sh0g0-1758 authored Sep 27, 2024
1 parent 9afede6 commit f1f925d
Show file tree
Hide file tree
Showing 12 changed files with 884 additions and 134 deletions.
2 changes: 1 addition & 1 deletion pkg/hintrunner/core/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ func (hint U256InvModN) Execute(vm *VM.VirtualMachine, _ *hinter.HintRunnerConte
n := new(big.Int).Lsh(&N1BigInt, 128)
n.Add(n, &N0BigInt)

_, r, g := u.Igcdex(n, b)
_, r, g := utils.Igcdex(n, b)
mask := new(big.Int).Lsh(big.NewInt(1), 128)
mask.Sub(mask, big.NewInt(1))

Expand Down
65 changes: 2 additions & 63 deletions pkg/hintrunner/utils/math_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"math/big"

"github.com/NethermindEth/cairo-vm-go/pkg/utils"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

Expand Down Expand Up @@ -55,7 +56,7 @@ func AsIntBig(value *big.Int) big.Int {
func Divmod(n, m, p *big.Int) (big.Int, error) {
// https://github.com/starkware-libs/cairo-lang/blob/efa9648f57568aad8f8a13fbf027d2de7c63c2c0/src/starkware/python/math_utils.py#L26

a, _, c := Igcdex(m, p)
a, _, c := utils.Igcdex(m, p)
if c.Cmp(big.NewInt(1)) != 0 {
return *big.NewInt(0), errors.New("no solution exists (gcd(m, p) != 1)")
}
Expand All @@ -65,68 +66,6 @@ func Divmod(n, m, p *big.Int) (big.Int, error) {
return *res, nil
}

func Igcdex(a, b *big.Int) (big.Int, big.Int, big.Int) {
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/core/intfunc.py#L362

if a.Cmp(big.NewInt(0)) == 0 && b.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), *big.NewInt(1), *big.NewInt(0)
}
g, x, y := gcdext(a, b)
return x, y, g
}

func gcdext(a, b *big.Int) (big.Int, big.Int, big.Int) {
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L125

if a.Cmp(big.NewInt(0)) == 0 || b.Cmp(big.NewInt(0)) == 0 {
g := new(big.Int)
if a.Cmp(big.NewInt(0)) == 0 {
g.Abs(b)
} else {
g.Abs(a)
}

if g.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), *big.NewInt(0), *big.NewInt(0)
}
return *g, *new(big.Int).Div(a, g), *new(big.Int).Div(b, g)
}

xSign, aSigned := sign(a)
ySign, bSigned := sign(b)
x, r := big.NewInt(1), big.NewInt(0)
y, s := big.NewInt(0), big.NewInt(1)

for bSigned.Sign() != 0 {
q, c := new(big.Int).DivMod(&aSigned, &bSigned, new(big.Int))
aSigned = bSigned
bSigned = *c
x, r = r, new(big.Int).Sub(x, new(big.Int).Mul(q, r))
y, s = s, new(big.Int).Sub(y, new(big.Int).Mul(q, s))
}

return aSigned, *new(big.Int).Mul(x, big.NewInt(int64(xSign))), *new(big.Int).Mul(y, big.NewInt(int64(ySign)))
}

func sign(n *big.Int) (int, big.Int) {
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L119

if n.Sign() < 0 {
return -1, *new(big.Int).Abs(n)
}
return 1, *new(big.Int).Set(n)
}

func SafeDiv(x, y *big.Int) (big.Int, error) {
if y.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), fmt.Errorf("division by zero")
}
if new(big.Int).Mod(x, y).Cmp(big.NewInt(0)) != 0 {
return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v", x, y)
}
return *new(big.Int).Div(x, y), nil
}

func IsQuadResidue(x *fp.Element) bool {
// Implementation adapted from sympy implementation which can be found here :
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/ntheory/residue_ntheory.py#L689
Expand Down
66 changes: 0 additions & 66 deletions pkg/hintrunner/utils/math_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,69 +46,3 @@ func TestDivMod(t *testing.T) {
})
}
}

func TestIgcdex(t *testing.T) {
// https://github.com/sympy/sympy/blob/e7fb2714f17b30b83e424448aad0da9e94a4b577/sympy/core/tests/test_numbers.py#L278
tests := []struct {
name string
a, b *big.Int
expectedX, expectedY, expectedG *big.Int
}{
{
name: "Case 1",
a: big.NewInt(2),
b: big.NewInt(3),
expectedX: big.NewInt(-1),
expectedY: big.NewInt(1),
expectedG: big.NewInt(1),
},
{
name: "Case 2",
a: big.NewInt(10),
b: big.NewInt(12),
expectedX: big.NewInt(-1),
expectedY: big.NewInt(1),
expectedG: big.NewInt(2),
},
{
name: "Case 3",
a: big.NewInt(100),
b: big.NewInt(2004),
expectedX: big.NewInt(-20),
expectedY: big.NewInt(1),
expectedG: big.NewInt(4),
},
{
name: "Case 4",
a: big.NewInt(0),
b: big.NewInt(0),
expectedX: big.NewInt(0),
expectedY: big.NewInt(1),
expectedG: big.NewInt(0),
},
{
name: "Case 5",
a: big.NewInt(1),
b: big.NewInt(0),
expectedX: big.NewInt(1),
expectedY: big.NewInt(0),
expectedG: big.NewInt(1),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actualX, actualY, actualG := Igcdex(tt.a, tt.b)

if actualX.Cmp(tt.expectedX) != 0 {
t.Errorf("got x: %v, want: %v", actualX, tt.expectedX)
}
if actualY.Cmp(tt.expectedY) != 0 {
t.Errorf("got x: %v, want: %v", actualY, tt.expectedY)
}
if actualG.Cmp(tt.expectedG) != 0 {
t.Errorf("got x: %v, want: %v", actualG, tt.expectedG)
}
})
}
}
2 changes: 1 addition & 1 deletion pkg/hintrunner/zero/zerohint_ec.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func newDivModNSafeDivPlusOneHint() hinter.Hinter {
valueBig.Mul(resBig, bBig)
valueBig.Sub(valueBig, aBig)

newValueBig, err := secp_utils.SafeDiv(valueBig, nBig)
newValueBig, err := utils.SafeDiv(valueBig, nBig)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/hintrunner/zero/zerohint_signature.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
secp_utils "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/utils"
"github.com/NethermindEth/cairo-vm-go/pkg/utils"
VM "github.com/NethermindEth/cairo-vm-go/pkg/vm"
"github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins"
mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
Expand Down Expand Up @@ -372,7 +373,7 @@ func newDivModSafeDivHint() hinter.Hinter {
}

divisor := new(big.Int).Sub(new(big.Int).Mul(res, b), a)
value, err := secp_utils.SafeDiv(divisor, N)
value, err := utils.SafeDiv(divisor, N)
if err != nil {
return err
}
Expand Down
22 changes: 22 additions & 0 deletions pkg/runner/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,28 @@ func TestEcOpBuiltin(t *testing.T) {
require.NoError(t, err)
}

func TestModuloBuiltin(t *testing.T) {
// modulo is located at fp - 3
// we first write 2048 and 5 to modulo
// then we read the modulo result from add and mul
// runner := createRunner(`
// [ap] = 2048;
// [ap] = [[fp - 3]];

// [ap + 1] = 5;
// [ap + 1] = [[fp - 3] + 1];
// ret;
// `, "small", sn.AddMod, sn.MulMod)

// err := runner.Run()
// require.NoError(t, err)

// modulo, ok := runner.vm.Memory.FindSegmentWithBuiltin("add_mod")
// require.True(t, ok)

// requireEqualSegments(t, createSegment(2048, 5), modulo)
}

func createRunner(code string, layoutName string, builtins ...builtins.BuiltinType) ZeroRunner {
program := createProgramWithBuiltins(code, builtins...)

Expand Down
74 changes: 74 additions & 0 deletions pkg/utils/math.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package utils

import (
"errors"
"fmt"
"math"
"math/big"
"math/bits"
Expand Down Expand Up @@ -156,3 +158,75 @@ func Int16FromBigInt(n *big.Int) (int16, bool) {
func RightRot(value uint32, n uint32) uint32 {
return (value >> n) | ((value & ((1 << n) - 1)) << (32 - n))
}

func SafeDivUint64(x, y uint64) (uint64, error) {
if y == 0 {
return 0, fmt.Errorf("cannot divide: y division is zero")
}
if x%y != 0 {
return 0, errors.New("cannot divide: x is not divisible by y")
}
return x / y, nil
}

func Igcdex(a, b *big.Int) (big.Int, big.Int, big.Int) {
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/core/intfunc.py#L362

if a.Cmp(big.NewInt(0)) == 0 && b.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), *big.NewInt(1), *big.NewInt(0)
}
g, x, y := gcdext(a, b)
return x, y, g
}

func gcdext(a, b *big.Int) (big.Int, big.Int, big.Int) {
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L125

if a.Cmp(big.NewInt(0)) == 0 || b.Cmp(big.NewInt(0)) == 0 {
g := new(big.Int)
if a.Cmp(big.NewInt(0)) == 0 {
g.Abs(b)
} else {
g.Abs(a)
}

if g.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), *big.NewInt(0), *big.NewInt(0)
}
return *g, *new(big.Int).Div(a, g), *new(big.Int).Div(b, g)
}

xSign, aSigned := sign(a)
ySign, bSigned := sign(b)
x, r := big.NewInt(1), big.NewInt(0)
y, s := big.NewInt(0), big.NewInt(1)

for bSigned.Sign() != 0 {
q, c := new(big.Int).DivMod(&aSigned, &bSigned, new(big.Int))
aSigned = bSigned
bSigned = *c
x, r = r, new(big.Int).Sub(x, new(big.Int).Mul(q, r))
y, s = s, new(big.Int).Sub(y, new(big.Int).Mul(q, s))
}

return aSigned, *new(big.Int).Mul(x, big.NewInt(int64(xSign))), *new(big.Int).Mul(y, big.NewInt(int64(ySign)))
}

func sign(n *big.Int) (int, big.Int) {
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L119

if n.Sign() < 0 {
return -1, *new(big.Int).Abs(n)
}
return 1, *new(big.Int).Set(n)
}

func SafeDiv(x, y *big.Int) (big.Int, error) {
if y.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), fmt.Errorf("division by zero")
}
if new(big.Int).Mod(x, y).Cmp(big.NewInt(0)) != 0 {
return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v", x, y)
}
return *new(big.Int).Div(x, y), nil
}
67 changes: 67 additions & 0 deletions pkg/utils/math_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package utils

import (
"math/big"
"testing"

"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
Expand Down Expand Up @@ -119,3 +120,69 @@ func TestRightRot(t *testing.T) {
})
}
}

func TestIgcdex(t *testing.T) {
// https://github.com/sympy/sympy/blob/e7fb2714f17b30b83e424448aad0da9e94a4b577/sympy/core/tests/test_numbers.py#L278
tests := []struct {
name string
a, b *big.Int
expectedX, expectedY, expectedG *big.Int
}{
{
name: "Case 1",
a: big.NewInt(2),
b: big.NewInt(3),
expectedX: big.NewInt(-1),
expectedY: big.NewInt(1),
expectedG: big.NewInt(1),
},
{
name: "Case 2",
a: big.NewInt(10),
b: big.NewInt(12),
expectedX: big.NewInt(-1),
expectedY: big.NewInt(1),
expectedG: big.NewInt(2),
},
{
name: "Case 3",
a: big.NewInt(100),
b: big.NewInt(2004),
expectedX: big.NewInt(-20),
expectedY: big.NewInt(1),
expectedG: big.NewInt(4),
},
{
name: "Case 4",
a: big.NewInt(0),
b: big.NewInt(0),
expectedX: big.NewInt(0),
expectedY: big.NewInt(1),
expectedG: big.NewInt(0),
},
{
name: "Case 5",
a: big.NewInt(1),
b: big.NewInt(0),
expectedX: big.NewInt(1),
expectedY: big.NewInt(0),
expectedG: big.NewInt(1),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actualX, actualY, actualG := Igcdex(tt.a, tt.b)

if actualX.Cmp(tt.expectedX) != 0 {
t.Errorf("got x: %v, want: %v", actualX, tt.expectedX)
}
if actualY.Cmp(tt.expectedY) != 0 {
t.Errorf("got x: %v, want: %v", actualY, tt.expectedY)
}
if actualG.Cmp(tt.expectedG) != 0 {
t.Errorf("got x: %v, want: %v", actualG, tt.expectedG)
}
})
}
}
Loading

0 comments on commit f1f925d

Please sign in to comment.