Skip to content

Commit

Permalink
Feature: constant time overflow checking while computing offsets (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
ElijahVlasov authored Aug 29, 2023
1 parent f0a851d commit 637b2bf
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 7 deletions.
39 changes: 39 additions & 0 deletions pkg/safemath/safemath.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package safemath

import "math/bits"

// Takes a uint64 and an int16 and outputs their addition as well
// as the ocurrence of an overflow or underflow.
//
// This is a constant-time version of the following function:
//
// func SafeOffset(x uint64, y int16) (res uint64, isOverflow bool) {
// res = x + uint64(y)
// if y < 0 {
// isOverflow = res >= x
// } else {
// isOverflow = res < x
// }
// return
// }
//
// This shows better results because the final bytecode
// doesn't contain any conditional jump instructions
// making it easier for a processor to pipeline the function.
func SafeOffset(x uint64, y int16) (res uint64, isOverflow bool) {
enlargedY := uint64(y)
// I'll leave proving that this is correct as an exercise for the reader :)
res = x + enlargedY
// Why does this work?
// Let's proceed by cases on the most significant bit of (MSB(x)).
// If MSB(x) == 1 and enlargedY < 0 (MSB(enlargedY) == 1) then overflow doesn't happen.
// Let's consider the case enlargedY >= 0 (MSB(enlargedY) == 0).
// In that case we can only wrap up by going to the begining of uint64 range making the MSB(res) = 0.
// This is the second disjunct of the disjunctive formula.
//
// In the same fashion the case MSB(x) == 0 and MSB(enlargedY) == 1 is reasoned about.
//
// Finally, we boil everything down to MSBs by rotating and anding with ...000001.
isOverflow = bits.RotateLeft64((^x&enlargedY&res)|(x & ^enlargedY & ^res), 1)&0x1 != 0
return
}
35 changes: 35 additions & 0 deletions pkg/safemath/safemath_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package safemath

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestOffsetNeg(t *testing.T) {
res, isOverflow := SafeOffset(1215, -3)
assert.Equal(t, uint64(1212), res)
assert.False(t, isOverflow)
}

func TestOffsetPos(t *testing.T) {
res, isOverflow := SafeOffset(7, 11)
assert.Equal(t, uint64(18), res)
assert.False(t, isOverflow)
}

func TestOffsetLeftOverflow(t *testing.T) {
_, isOverflow := SafeOffset(4, -10)
assert.True(t, isOverflow)
}

func TestOffsetRightOverflow(t *testing.T) {
_, isOverflow := SafeOffset(^uint64(0), 1)
assert.True(t, isOverflow)
}

func TestOffsetRightNoOverflow(t *testing.T) {
res, isOverflow := SafeOffset(^uint64(0), -12)
assert.Equal(t, uint64(18446744073709551603), res)
assert.False(t, isOverflow)
}
25 changes: 18 additions & 7 deletions pkg/vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package vm
import (
"fmt"

safemath "github.com/NethermindEth/cairo-vm-go/pkg/safemath"
mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)
Expand Down Expand Up @@ -156,8 +157,11 @@ func (vm *VirtualMachine) getCellDst(instruction *Instruction) (*mem.Cell, error
dstRegister = vm.Context.Fp
}

// todo(rodro): fix this math
return vm.MemoryManager.Memory.Peek(executionSegment, dstRegister+uint64(instruction.OffDest))
addr, isOverflow := safemath.SafeOffset(dstRegister, instruction.OffDest)
if isOverflow {
return nil, fmt.Errorf("integer overflow while appying offset: 0x%x %d", dstRegister, instruction.OffDest)
}
return vm.MemoryManager.Memory.Peek(executionSegment, addr)
}

func (vm *VirtualMachine) getCellOp0(instruction *Instruction) (*mem.Cell, error) {
Expand All @@ -167,9 +171,12 @@ func (vm *VirtualMachine) getCellOp0(instruction *Instruction) (*mem.Cell, error
} else {
op0Register = vm.Context.Fp
}
// todo(rodro): fix this math
offset := op0Register + uint64(instruction.OffOp0)
return vm.MemoryManager.Memory.Peek(executionSegment, offset)

addr, isOverflow := safemath.SafeOffset(op0Register, instruction.OffOp0)
if isOverflow {
return nil, fmt.Errorf("integer overflow while appying offset: 0x%x %d", op0Register, instruction.OffOp0)
}
return vm.MemoryManager.Memory.Peek(executionSegment, addr)
}

func (vm *VirtualMachine) getCellOp1(instruction *Instruction, op0Cell *mem.Cell) (*mem.Cell, error) {
Expand All @@ -189,8 +196,12 @@ func (vm *VirtualMachine) getCellOp1(instruction *Instruction, op0Cell *mem.Cell
case ApPlusOffOp1:
op1Address = mem.CreateMemoryAddress(executionSegment, vm.Context.Ap)
}
// todo(rodro): fix this math
op1Address.Offset += uint64(instruction.OffOp1)

addr, isOverflow := safemath.SafeOffset(op1Address.Offset, instruction.OffOp1)
if isOverflow {
return nil, fmt.Errorf("integer overflow while appying offset: 0x%x %d", op1Address.Offset, instruction.OffOp1)
}
op1Address.Offset = addr

return vm.MemoryManager.Memory.PeekFromAddress(op1Address)
}
Expand Down

0 comments on commit 637b2bf

Please sign in to comment.