From d2dfc2d873e688d5ce5fc655a00dc19ae88c7c43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carmen=20Irene=20Cabrera=20Rodr=C3=ADguez?= <49727740+cicr99@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:35:31 +0100 Subject: [PATCH] Hint reference parsing (#625) * add current failing tests * add new grammar * fix parsing of number as string * fix grammar to operate sub correctly * add explanation comments --------- Co-authored-by: Shourya Goel --- ...keccak_uint256s.starknet_with_keccak.cairo | 0 pkg/hintrunner/hinter/operand.go | 15 +- pkg/hintrunner/hinter/operand_test.go | 18 +- pkg/hintrunner/zero/hintparser.go | 441 +++++++++++------- pkg/hintrunner/zero/hintparser_test.go | 47 +- pkg/utils/math.go | 27 ++ 6 files changed, 348 insertions(+), 200 deletions(-) rename integration_tests/{cairo_zero_hint_tests_in_progress => cairo_zero_hint_tests}/keccak_uint256s.starknet_with_keccak.cairo (100%) diff --git a/integration_tests/cairo_zero_hint_tests_in_progress/keccak_uint256s.starknet_with_keccak.cairo b/integration_tests/cairo_zero_hint_tests/keccak_uint256s.starknet_with_keccak.cairo similarity index 100% rename from integration_tests/cairo_zero_hint_tests_in_progress/keccak_uint256s.starknet_with_keccak.cairo rename to integration_tests/cairo_zero_hint_tests/keccak_uint256s.starknet_with_keccak.cairo diff --git a/pkg/hintrunner/hinter/operand.go b/pkg/hintrunner/hinter/operand.go index a0663f837..3d57d9d25 100644 --- a/pkg/hintrunner/hinter/operand.go +++ b/pkg/hintrunner/hinter/operand.go @@ -147,12 +147,13 @@ type Operator uint8 const ( Add Operator = iota Mul + Sub ) type BinaryOp struct { Operator Operator - Lhs CellRefer - Rhs ResOperander // (except DoubleDeref and BinaryOp) + Lhs ResOperander + Rhs ResOperander } func (bop BinaryOp) String() string { @@ -160,13 +161,9 @@ func (bop BinaryOp) String() string { } func (bop BinaryOp) Resolve(vm *VM.VirtualMachine) (mem.MemoryValue, error) { - lhsAddr, err := bop.Lhs.Get(vm) + lhs, err := bop.Lhs.Resolve(vm) if err != nil { - return mem.UnknownValue, fmt.Errorf("get lhs address %s: %w", bop.Lhs, err) - } - lhs, err := vm.Memory.ReadFromAddress(&lhsAddr) - if err != nil { - return mem.UnknownValue, fmt.Errorf("read lhs address %s: %w", lhsAddr, err) + return mem.UnknownValue, fmt.Errorf("resolve lhs operand %s: %w", lhs, err) } rhs, err := bop.Rhs.Resolve(vm) @@ -221,7 +218,7 @@ func (v DoubleDeref) ApplyApTracking(hint, ref zero.ApTracking) Reference { } func (v BinaryOp) ApplyApTracking(hint, ref zero.ApTracking) Reference { - v.Lhs = v.Lhs.ApplyApTracking(hint, ref).(CellRefer) + v.Lhs = v.Lhs.ApplyApTracking(hint, ref).(ResOperander) v.Rhs = v.Rhs.ApplyApTracking(hint, ref).(ResOperander) return v } diff --git a/pkg/hintrunner/hinter/operand_test.go b/pkg/hintrunner/hinter/operand_test.go index 55d4fcb1d..be386e5e6 100644 --- a/pkg/hintrunner/hinter/operand_test.go +++ b/pkg/hintrunner/hinter/operand_test.go @@ -127,19 +127,20 @@ func TestResolveAddOp(t *testing.T) { memory.MemoryValueFromInt(30), ) - // lhs + // Lhs var ap ApCellRef = 7 + lhs := Deref{ap} // Rhs var fp FpCellRef = 20 - deref := Deref{fp} + rhs := Deref{fp} operator := Add bop := BinaryOp{ Operator: operator, - Lhs: ap, - Rhs: deref, + Lhs: lhs, + Rhs: rhs, } res, err := bop.Resolve(vm) @@ -164,19 +165,20 @@ func TestResolveMulOp(t *testing.T) { memory.MemoryValueFromInt(5), ) - // lhs + // Lhs var ap ApCellRef = 7 + lhs := Deref{ap} // Rhs var fp FpCellRef = 20 - deref := Deref{fp} + rhs := Deref{fp} operator := Mul bop := BinaryOp{ Operator: operator, - Lhs: ap, - Rhs: deref, + Lhs: lhs, + Rhs: rhs, } res, err := bop.Resolve(vm) diff --git a/pkg/hintrunner/zero/hintparser.go b/pkg/hintrunner/zero/hintparser.go index fe3b28e19..40ae81551 100644 --- a/pkg/hintrunner/zero/hintparser.go +++ b/pkg/hintrunner/zero/hintparser.go @@ -2,31 +2,41 @@ package zero import ( "fmt" + "math/big" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" "github.com/alecthomas/participle/v2" + "github.com/alecthomas/participle/v2/lexer" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) -var parser *participle.Parser[IdentifierExp] = participle.MustBuild[IdentifierExp](participle.UseLookahead(10)) - -// Possible cases extracted from https://github.com/lambdaclass/cairo-vm_in_go/blob/main/pkg/hints/hint_utils/hint_reference.go#L41 -// Immediate: cast(number, type) -// Reference no deref 1 offset: cast(reg + off, type) -// Reference no deref 2 offsets: cast(reg + off1 + off2, type) -// Reference with deref 1 offset: cast([reg + off1], type) -// Reference with deref 2 offsets: cast([reg + off1] + off2, type) -// Two references with deref: cast([reg + off1] + [reg + off2], type) -// Reference off omitted: cast(reg, type) -// Reference with deref off omitted: cast([reg], type) -// Reference with deref 2 offsets off1 omitted: cast([reg] + off2, type) -// 2 dereferences off1 omitted: cast([reg] + [reg + off2], type) -// 2 dereferences off2 omitted: cast([reg + off1] + [reg], type) -// 2 dereferences both offs omitted: cast([reg] + [reg], type) -// 2 dereferences with multiplication: cast([reg + off1] * [reg + off2], felt) -// Reference no dereference 2 offsets - + : cast(reg - off1 + off2, type) - -// Note: The same cases apply with an external dereference. Example: [cast(number, type)] +// Hint references follow the format "cast(, )". It also allows an +// external dereference such as "[cast(, )]". The in +// the hint reference is interpreted as an arithmetic expression, so the root of the +// grammar defined in this file would be `arithExp` +// +// Grammar: +// arithExp => term (('+'|'-') term)* +// term => exp | prodExp +// prodExp => exp '*' exp +// exp => cellRef | deref | dderef | int +// cellRef => ('ap'|'fp') ('+'|'-') int +// deref => [cellRef] +// dderef => [deref ('+'|'-') int] + +var ( + basicLexer = lexer.MustSimple([]lexer.SimpleRule{ + {"Number", `\d+`}, + {"Ident", `[a-zA-Z_]\w*`}, + {"Punct", `[-[!@#$%^&*()+_={}\|:;"'<,>.?/]|]`}, + {"whitespace", `[ \t]+`}, + }) + parser = participle.MustBuild[IdentifierExp]( + participle.Lexer(basicLexer), + participle.UseLookahead(20), + ) +) type IdentifierExp struct { DerefCastExp *DerefCastExp `@@ |` @@ -38,61 +48,77 @@ type DerefCastExp struct { } type CastExp struct { - ValueExpr *Expression `"cast" "(" @@ ","` - CastType []string `@Ident ("." @Ident)* ("*")? ("*")? ")"` + ValueExp *ArithExp `"cast" "(" @@ ","` + CastType []string `@Ident ("." @Ident)* ("*")? ("*")? ")"` +} + +type ArithExp struct { + TermExp *TermExp `@@` + AddExp []AddExp `@@*` +} + +type AddExp struct { + Operator string `@("+" | "-")` + TermExp *TermExp `@@` +} + +type TermExp struct { + ProdExp *ProdExp `@@ |` + Exp *Expression `@@` +} + +type ProdExp struct { + LeftExp *Expression `@@` + Operator string `"*"` + RightExp *Expression `@@` } type Expression struct { - BinOpExp *BinOpExp `@@ |` - CellRefExp *CellRefExp `"(" @@ ")" | @@ |` - DerefExp *DerefExp `@@` + DDerefExp *DDerefExp `@@ |` + DerefExp *DerefExp `@@ |` + CellRefExp *CellRefExp `@@ |` + IntExp *OffsetExp `@@` } -type CellRefExp struct { +// CellRefSimple represents the structure of a CellRef in its natural form. +// A CellRefSimple cannot be an Expression by itself if it has an offset, +// since the parser will interpret this as a sum of terms instead. +// That's why CellRefExp is also defined. Notice that in the case where there +// is an offset, the whole expression is expected to be enclosed in parenthesis. +type CellRefSimple struct { RegisterOffset *RegisterOffset `@@ |` Register string `@("ap" | "fp")` } +type CellRefExp struct { + RegisterOffset *RegisterOffset `"(" @@ ")" |` + Register string `@("ap" | "fp")` +} + type RegisterOffset struct { Register string `@("ap" | "fp")` Operator string `@("+" | "-")` Offset *OffsetExp `@@` } -type DerefExp struct { - CellRefExp *CellRefExp `"[" @@ "]"` -} - -type BinOpExp struct { - LeftExp *LeftExp `@@` - Operator string `@("+" | "*")` - RightExp *RightExp `@@` -} - type OffsetExp struct { - Number *int `@Int |` - NegNumber *int `"(" "-" @Int ")"` + Number string `@Number |` + NegNumber string `"(" "-" @Number ")"` } -type LeftExp struct { - CellRefExp *RegisterOffset `@@ |` - DerefExp *DerefExp `@@` +type DerefExp struct { + CellRefExp *CellRefSimple `"[" @@ "]"` } -type RightExp struct { - DerefExp *DerefExp `@@ |` +type DerefOffsetExp struct { + DerefExp *DerefExp `@@` + Operator string `@("+" | "-")` Offset *OffsetExp `@@` } -type DerefOffset struct { - Deref hinter.Deref - Op hinter.Operator - Offset *int -} -type DerefDeref struct { - LeftDeref hinter.Deref - Op hinter.Operator - RightDeref hinter.Deref +type DDerefExp struct { + DerefOffsetExp *DerefOffsetExp `"[" @@ "]" |` + DerefExp *DerefExp `"[" @@ "]"` } // AST Functionality @@ -108,7 +134,7 @@ func (expression IdentifierExp) Evaluate() (hinter.Reference, error) { } func (expression DerefCastExp) Evaluate() (hinter.Reference, error) { - value, err := expression.CastExp.ValueExpr.Evaluate() + value, err := expression.CastExp.ValueExp.Evaluate() if err != nil { return nil, err } @@ -122,73 +148,159 @@ func (expression DerefCastExp) Evaluate() (hinter.Reference, error) { Offset: 0, }, nil - case DerefOffset: - return hinter.DoubleDeref{ - Deref: hinter.Deref{ - Deref: result.Deref.Deref, - }, - Offset: int16(*result.Offset), - }, - nil + case hinter.BinaryOp: + if left, ok := result.Lhs.(hinter.Deref); ok { + if right, ok := result.Rhs.(hinter.Immediate); ok { + if offset, ok := utils.Int16FromFelt((*fp.Element)(&right)); ok { + return hinter.DoubleDeref{ + Deref: left, + Offset: offset, + }, + nil + } + } + } + return nil, fmt.Errorf("invalid binary operation inside a deref") default: - return nil, fmt.Errorf("unexpected identifier value") + return nil, fmt.Errorf("unexpected deref expression") } } func (expression CastExp) Evaluate() (hinter.Reference, error) { - value, err := expression.ValueExpr.Evaluate() + return expression.ValueExp.Evaluate() +} + +func (expression ArithExp) Evaluate() (hinter.Reference, error) { + leftExp, err := expression.TermExp.Evaluate() if err != nil { return nil, err } - switch result := value.(type) { - case hinter.CellRefer: - return result, nil - case hinter.Deref: - return result, nil - case DerefOffset: - rhsFelt := fp.NewElement(uint64(*result.Offset)) - return hinter.BinaryOp{ - Operator: result.Op, - Lhs: result.Deref.Deref, - // TODO: why we're not using something like f.NewElement here? - Rhs: hinter.Immediate(rhsFelt), - }, nil - case DerefDeref: - return hinter.BinaryOp{ - Operator: result.Op, - Lhs: result.LeftDeref.Deref, - Rhs: result.RightDeref, - }, nil + if leftResult, ok := leftExp.(hinter.CellRefer); ok { + // Binary Operation does not support CellRef in the left hand side + // so the expression has to follow the pattern: + // reg + off + off + ... + off + for _, term := range expression.AddExp { + rightExp, err := term.TermExp.Evaluate() + if err != nil { + return nil, err + } + rightResult, ok := rightExp.(hinter.Immediate) + if !ok { + return nil, fmt.Errorf("invalid arithmetic expression") + } + + off, ok := utils.Int16FromFelt((*fp.Element)(&rightResult)) + if !ok { + return nil, fmt.Errorf("invalid arithmetic expression") + } + + if term.Operator == "-" { + off = -off + } + + switch cellRef := leftResult.(type) { + case hinter.ApCellRef: + oldOffset := int16(cellRef) + leftResult = hinter.ApCellRef(off + oldOffset) + continue + case hinter.FpCellRef: + oldOffset := int16(cellRef) + leftResult = hinter.FpCellRef(off + oldOffset) + continue + } + } + return leftResult, nil + } else { + for _, term := range expression.AddExp { + rightExp, err := term.TermExp.Evaluate() + if err != nil { + return nil, err + } + + op, err := parseOperator(term.Operator) + if err != nil { + return nil, err + } + + if leftResult, ok := leftExp.(hinter.ResOperander); ok { + if rightResult, ok := rightExp.(hinter.ResOperander); ok { + leftExp = hinter.BinaryOp{ + Operator: op, + Lhs: leftResult, + Rhs: rightResult, + } + continue + } + } + return nil, fmt.Errorf("invalid arithmetic expression") + } + return leftExp, nil + } + +} + +func (expression TermExp) Evaluate() (hinter.Reference, error) { + switch { + case expression.ProdExp != nil: + return expression.ProdExp.Evaluate() + case expression.Exp != nil: + return expression.Exp.Evaluate() default: return nil, fmt.Errorf("unexpected identifier value") } } -func (expression Expression) Evaluate() (any, error) { +func (expression ProdExp) Evaluate() (hinter.Reference, error) { + leftExp, err := expression.LeftExp.Evaluate() + if err != nil { + return nil, err + } + rightExp, err := expression.RightExp.Evaluate() + if err != nil { + return nil, err + } + + if leftOp, ok := leftExp.(hinter.ResOperander); ok { + if rightOp, ok := rightExp.(hinter.ResOperander); ok { + return hinter.BinaryOp{ + Operator: hinter.Mul, + Lhs: leftOp, + Rhs: rightOp, + }, nil + } + } + return nil, fmt.Errorf("unexpected product expression") +} + +func (expression Expression) Evaluate() (hinter.Reference, error) { switch { + case expression.IntExp != nil: + intExp, err := expression.IntExp.Evaluate() + if err != nil { + return nil, err + } + return hinter.Immediate(*new(fp.Element).SetBigInt(intExp)), nil case expression.CellRefExp != nil: return expression.CellRefExp.Evaluate() case expression.DerefExp != nil: return expression.DerefExp.Evaluate() - case expression.BinOpExp != nil: - return expression.BinOpExp.Evaluate() + case expression.DDerefExp != nil: + return expression.DDerefExp.Evaluate() default: return nil, fmt.Errorf("unexpected expression value") } } -func (expression RegisterOffset) Evaluate() (any, error) { - offsetValue, _ := expression.Offset.Evaluate() - offset := int16(*offsetValue) - if expression.Operator == "-" { - offset = -offset +func (expression CellRefSimple) Evaluate() (hinter.CellRefer, error) { + if expression.RegisterOffset != nil { + return expression.RegisterOffset.Evaluate() } - return EvaluateRegister(expression.Register, offset) + return EvaluateRegister(expression.Register, 0) } -func (expression CellRefExp) Evaluate() (any, error) { +func (expression CellRefExp) Evaluate() (hinter.CellRefer, error) { if expression.RegisterOffset != nil { return expression.RegisterOffset.Evaluate() } @@ -196,6 +308,19 @@ func (expression CellRefExp) Evaluate() (any, error) { return EvaluateRegister(expression.Register, 0) } +func (expression RegisterOffset) Evaluate() (hinter.CellRefer, error) { + offsetValue, _ := expression.Offset.Evaluate() + offset, ok := utils.Int16FromBigInt(offsetValue) + if !ok { + return nil, fmt.Errorf("offset does not fit in int16") + } + if expression.Operator == "-" { + offset = -offset + } + + return EvaluateRegister(expression.Register, offset) +} + func EvaluateRegister(register string, offset int16) (hinter.CellRefer, error) { switch register { case "ap": @@ -207,111 +332,69 @@ func EvaluateRegister(register string, offset int16) (hinter.CellRefer, error) { } } -func (expression OffsetExp) Evaluate() (*int, error) { +func (expression OffsetExp) Evaluate() (*big.Int, error) { switch { - case expression.Number != nil: - return expression.Number, nil - case expression.NegNumber != nil: - negNumber := -*expression.NegNumber - return &negNumber, nil + case expression.Number != "": + bigIntValue, ok := new(big.Int).SetString(expression.Number, 10) + if !ok { + return nil, fmt.Errorf("expected a number") + } + return bigIntValue, nil + case expression.NegNumber != "": + bigIntValue, ok := new(big.Int).SetString(expression.NegNumber, 10) + if !ok { + return nil, fmt.Errorf("expected a number") + } + negNumber := bigIntValue.Neg(bigIntValue) + return negNumber, nil default: return nil, fmt.Errorf("expected a number") } } -func (expression DerefExp) Evaluate() (any, error) { - cellRefExp, err := expression.CellRefExp.Evaluate() +func (expression DerefExp) Evaluate() (hinter.Deref, error) { + cellRef, err := expression.CellRefExp.Evaluate() if err != nil { - return nil, err - } - cellRef, ok := cellRefExp.(hinter.CellRefer) - if !ok { - return nil, fmt.Errorf("expected a CellRefer expression but got %s", cellRefExp) + return hinter.Deref{}, err } return hinter.Deref{Deref: cellRef}, nil } -func (expression BinOpExp) Evaluate() (any, error) { - leftExp, err := expression.LeftExp.Evaluate() - if err != nil { - return nil, err - } - - rightExp, err := expression.RightExp.Evaluate() - if err != nil { - return nil, err - } - - operation, err := parseOperator(expression.Operator) - if err != nil { - return nil, err - } - - switch lResult := leftExp.(type) { - case hinter.CellRefer: - // Right now we assume that there is no expression like `reg - off1 * off2`, - // but if there are, we would need to come up with an idea how to handle it. - // Right now we only cover `off1 + off2` expressions here. - offset, ok := rightExp.(*int) - if !ok { - return nil, fmt.Errorf("invalid type operation") +func (expression DDerefExp) Evaluate() (hinter.DoubleDeref, error) { + switch { + case expression.DerefExp != nil: + derefExp, err := expression.DerefExp.Evaluate() + if err != nil { + return hinter.DoubleDeref{}, err } - offsetValue := int16(*offset) - - var cellRefOffset int16 - switch register := lResult.(type) { - case hinter.ApCellRef: - cellRefOffset = int16(register) - case hinter.FpCellRef: - cellRefOffset = int16(register) + return hinter.DoubleDeref{ + Deref: derefExp, + Offset: 0, + }, nil + case expression.DerefOffsetExp != nil: + derefExp, err := expression.DerefOffsetExp.DerefExp.Evaluate() + if err != nil { + return hinter.DoubleDeref{}, err } - - offsetValue = offsetValue + cellRefOffset - switch lResult.(type) { - case hinter.ApCellRef: - return hinter.ApCellRef(offsetValue), nil - case hinter.FpCellRef: - return hinter.FpCellRef(offsetValue), nil + offsetValue, err := expression.DerefOffsetExp.Offset.Evaluate() + if err != nil { + return hinter.DoubleDeref{}, err } - - case hinter.Deref: - switch rResult := rightExp.(type) { - case hinter.Deref: - return DerefDeref{ - lResult, - operation, - rResult, - }, nil - case *int: - return DerefOffset{ - lResult, - operation, - rResult, - }, nil + offset, ok := utils.Int16FromBigInt(offsetValue) + if !ok { + return hinter.DoubleDeref{}, fmt.Errorf("offset does not fit in int16") } - } - - return nil, fmt.Errorf("invalid binary operation") -} - -func (expression LeftExp) Evaluate() (any, error) { - switch { - case expression.CellRefExp != nil: - return expression.CellRefExp.Evaluate() - case expression.DerefExp != nil: - return expression.DerefExp.Evaluate() - } - return nil, fmt.Errorf("unexpected left expression in binary operation") -} + if expression.DerefOffsetExp.Operator == "-" { + offset = -offset + } + return hinter.DoubleDeref{ + Deref: derefExp, + Offset: offset, + }, nil -func (expression RightExp) Evaluate() (any, error) { - switch { - case expression.DerefExp != nil: - return expression.DerefExp.Evaluate() - case expression.Offset != nil: - return expression.Offset.Evaluate() + default: + return hinter.DoubleDeref{}, fmt.Errorf("unexpected double deref expression") } - return nil, fmt.Errorf("unexpected right expression in binary operation") } func ParseIdentifier(value string) (hinter.Reference, error) { diff --git a/pkg/hintrunner/zero/hintparser_test.go b/pkg/hintrunner/zero/hintparser_test.go index e8588ddef..260fc3516 100644 --- a/pkg/hintrunner/zero/hintparser_test.go +++ b/pkg/hintrunner/zero/hintparser_test.go @@ -42,7 +42,9 @@ func TestHintParser(t *testing.T) { ExpectedCellRefer: nil, ExpectedResOperander: hinter.BinaryOp{ Operator: hinter.Add, - Lhs: hinter.ApCellRef(2), + Lhs: hinter.Deref{ + Deref: hinter.ApCellRef(2), + }, Rhs: hinter.Deref{ Deref: hinter.ApCellRef(0), }, @@ -53,7 +55,9 @@ func TestHintParser(t *testing.T) { ExpectedCellRefer: nil, ExpectedResOperander: hinter.BinaryOp{ Operator: hinter.Mul, - Lhs: hinter.ApCellRef(-5), + Lhs: hinter.Deref{ + Deref: hinter.ApCellRef(-5), + }, Rhs: hinter.Deref{ Deref: hinter.ApCellRef(-1), }, @@ -64,8 +68,43 @@ func TestHintParser(t *testing.T) { ExpectedCellRefer: nil, ExpectedResOperander: hinter.BinaryOp{ Operator: hinter.Mul, - Lhs: hinter.ApCellRef(0), - Rhs: hinter.Immediate{18446744073709551521, 18446744073709551615, 18446744073709551615, 576460752303421872}, + Lhs: hinter.Deref{ + Deref: hinter.ApCellRef(0), + }, + Rhs: hinter.Immediate{18446744073709551521, 18446744073709551615, 18446744073709551615, 576460752303421872}, + }, + }, + { + Parameter: "cast(2389472938759290879897, felt)", + ExpectedCellRefer: nil, + ExpectedResOperander: hinter.Immediate(*feltString("2389472938759290879897")), + }, + { + Parameter: "cast([[ap + 2] + (-5)], felt)", + ExpectedCellRefer: nil, + ExpectedResOperander: hinter.DoubleDeref{ + Deref: hinter.Deref{ + Deref: hinter.ApCellRef(2), + }, + Offset: int16(-5), + }, + }, + { + Parameter: "cast([fp + (-4)] * 18, felt)", + ExpectedCellRefer: nil, + ExpectedResOperander: hinter.BinaryOp{ + Operator: hinter.Mul, + Lhs: hinter.Deref{ + Deref: hinter.FpCellRef(-4), + }, + Rhs: hinter.Immediate(*feltInt64(18)), + }, + }, + { + Parameter: "[cast(ap - 0 + (-1), felt*)]", + ExpectedCellRefer: nil, + ExpectedResOperander: hinter.Deref{ + Deref: hinter.ApCellRef(-1), }, }, } diff --git a/pkg/utils/math.go b/pkg/utils/math.go index 4f76a616b..ea02fd0ca 100644 --- a/pkg/utils/math.go +++ b/pkg/utils/math.go @@ -1,6 +1,7 @@ package utils import ( + "math" "math/big" "math/bits" @@ -126,6 +127,32 @@ func FeltDivRem(a, b *fp.Element) (div fp.Element, rem fp.Element) { return div, rem } +func Int16FromFelt(n *fp.Element) (int16, bool) { + bigN := n.BigInt(new(big.Int)) + return Int16FromBigInt(bigN) +} + +func Int16FromBigInt(n *big.Int) (int16, bool) { + mod := fp.Modulus() + negN := new(big.Int).Sub(mod, n) + maxInt16 := new(big.Int).SetInt64(int64(math.MaxInt16)) + + var result int64 + if n.Cmp(negN) == 1 { + if negN.Cmp(maxInt16) == 1 { + return 0, false + } + result = -negN.Int64() + } else { + if n.Cmp(maxInt16) == 1 { + return 0, false + } + result = n.Int64() + } + + return int16(result), true +} + func RightRot(value uint32, n uint32) uint32 { return (value >> n) | ((value & ((1 << n) - 1)) << (32 - n)) }