From 637b2bfd9a04e039fa9d1205bdc530fb377eeeea Mon Sep 17 00:00:00 2001 From: Ilia Vlasov Date: Tue, 29 Aug 2023 15:12:59 +0100 Subject: [PATCH] Feature: constant time overflow checking while computing offsets (#29) --- pkg/safemath/safemath.go | 39 +++++++++++++++++++++++++++++++++++ pkg/safemath/safemath_test.go | 35 +++++++++++++++++++++++++++++++ pkg/vm/vm.go | 25 +++++++++++++++------- 3 files changed, 92 insertions(+), 7 deletions(-) create mode 100644 pkg/safemath/safemath.go create mode 100644 pkg/safemath/safemath_test.go diff --git a/pkg/safemath/safemath.go b/pkg/safemath/safemath.go new file mode 100644 index 000000000..98599fc2b --- /dev/null +++ b/pkg/safemath/safemath.go @@ -0,0 +1,39 @@ +package safemath + +import "math/bits" + +// Takes a uint64 and an int16 and outputs their addition as well +// as the ocurrence of an overflow or underflow. +// +// This is a constant-time version of the following function: +// +// func SafeOffset(x uint64, y int16) (res uint64, isOverflow bool) { +// res = x + uint64(y) +// if y < 0 { +// isOverflow = res >= x +// } else { +// isOverflow = res < x +// } +// return +// } +// +// This shows better results because the final bytecode +// doesn't contain any conditional jump instructions +// making it easier for a processor to pipeline the function. +func SafeOffset(x uint64, y int16) (res uint64, isOverflow bool) { + enlargedY := uint64(y) + // I'll leave proving that this is correct as an exercise for the reader :) + res = x + enlargedY + // Why does this work? + // Let's proceed by cases on the most significant bit of (MSB(x)). + // If MSB(x) == 1 and enlargedY < 0 (MSB(enlargedY) == 1) then overflow doesn't happen. + // Let's consider the case enlargedY >= 0 (MSB(enlargedY) == 0). + // In that case we can only wrap up by going to the begining of uint64 range making the MSB(res) = 0. + // This is the second disjunct of the disjunctive formula. + // + // In the same fashion the case MSB(x) == 0 and MSB(enlargedY) == 1 is reasoned about. + // + // Finally, we boil everything down to MSBs by rotating and anding with ...000001. + isOverflow = bits.RotateLeft64((^x&enlargedY&res)|(x & ^enlargedY & ^res), 1)&0x1 != 0 + return +} diff --git a/pkg/safemath/safemath_test.go b/pkg/safemath/safemath_test.go new file mode 100644 index 000000000..538c03435 --- /dev/null +++ b/pkg/safemath/safemath_test.go @@ -0,0 +1,35 @@ +package safemath + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOffsetNeg(t *testing.T) { + res, isOverflow := SafeOffset(1215, -3) + assert.Equal(t, uint64(1212), res) + assert.False(t, isOverflow) +} + +func TestOffsetPos(t *testing.T) { + res, isOverflow := SafeOffset(7, 11) + assert.Equal(t, uint64(18), res) + assert.False(t, isOverflow) +} + +func TestOffsetLeftOverflow(t *testing.T) { + _, isOverflow := SafeOffset(4, -10) + assert.True(t, isOverflow) +} + +func TestOffsetRightOverflow(t *testing.T) { + _, isOverflow := SafeOffset(^uint64(0), 1) + assert.True(t, isOverflow) +} + +func TestOffsetRightNoOverflow(t *testing.T) { + res, isOverflow := SafeOffset(^uint64(0), -12) + assert.Equal(t, uint64(18446744073709551603), res) + assert.False(t, isOverflow) +} diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index dfe3c7eb5..70ff455cb 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -3,6 +3,7 @@ package vm import ( "fmt" + safemath "github.com/NethermindEth/cairo-vm-go/pkg/safemath" mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) @@ -156,8 +157,11 @@ func (vm *VirtualMachine) getCellDst(instruction *Instruction) (*mem.Cell, error dstRegister = vm.Context.Fp } - // todo(rodro): fix this math - return vm.MemoryManager.Memory.Peek(executionSegment, dstRegister+uint64(instruction.OffDest)) + addr, isOverflow := safemath.SafeOffset(dstRegister, instruction.OffDest) + if isOverflow { + return nil, fmt.Errorf("integer overflow while appying offset: 0x%x %d", dstRegister, instruction.OffDest) + } + return vm.MemoryManager.Memory.Peek(executionSegment, addr) } func (vm *VirtualMachine) getCellOp0(instruction *Instruction) (*mem.Cell, error) { @@ -167,9 +171,12 @@ func (vm *VirtualMachine) getCellOp0(instruction *Instruction) (*mem.Cell, error } else { op0Register = vm.Context.Fp } - // todo(rodro): fix this math - offset := op0Register + uint64(instruction.OffOp0) - return vm.MemoryManager.Memory.Peek(executionSegment, offset) + + addr, isOverflow := safemath.SafeOffset(op0Register, instruction.OffOp0) + if isOverflow { + return nil, fmt.Errorf("integer overflow while appying offset: 0x%x %d", op0Register, instruction.OffOp0) + } + return vm.MemoryManager.Memory.Peek(executionSegment, addr) } func (vm *VirtualMachine) getCellOp1(instruction *Instruction, op0Cell *mem.Cell) (*mem.Cell, error) { @@ -189,8 +196,12 @@ func (vm *VirtualMachine) getCellOp1(instruction *Instruction, op0Cell *mem.Cell case ApPlusOffOp1: op1Address = mem.CreateMemoryAddress(executionSegment, vm.Context.Ap) } - // todo(rodro): fix this math - op1Address.Offset += uint64(instruction.OffOp1) + + addr, isOverflow := safemath.SafeOffset(op1Address.Offset, instruction.OffOp1) + if isOverflow { + return nil, fmt.Errorf("integer overflow while appying offset: 0x%x %d", op1Address.Offset, instruction.OffOp1) + } + op1Address.Offset = addr return vm.MemoryManager.Memory.PeekFromAddress(op1Address) }