Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nvidia MSM proof of concept (serial) #480

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
cacb22d
wrap `execCudaImpl` macro logic in a block
Vindaar Nov 5, 2024
d4e640c
add more EC Jac operations to helper templates
Vindaar Nov 5, 2024
22c0565
do not quit on failure in NvidiaAssembler destructor
Vindaar Nov 5, 2024
c1257ac
add CurveDescriptor fields for LLVM type for Fr, scalars for MSM
Vindaar Nov 5, 2024
43e4d19
[LLVM] add `isPointerTy` helper to determine if type is a pointer
Vindaar Nov 5, 2024
6cd3ca8
[tests] add sanity test for adding neutral EC element to EC sum
Vindaar Nov 5, 2024
d8b21c5
store EC order bit width in CurveDescriptor
Vindaar Nov 5, 2024
7a786ef
make `store` for `ValueRef` safer by checking for pointer-ness
Vindaar Nov 5, 2024
9ee8fe5
forbid `=copy` on Array, likely *not* what user wants
Vindaar Nov 5, 2024
494e4ca
allow access read/write of `Array` using `ValueRef`
Vindaar Nov 5, 2024
ec28afc
add `FieldScalar`, `FieldScalarArray`, `EcAffArray`, `EcAffArray`
Vindaar Nov 5, 2024
21fb88c
extend doc string of `compile` taking a string
Vindaar Nov 5, 2024
cf095cf
add ConstantValue, MutableValue wrappers around ValueRef
Vindaar Nov 5, 2024
5d7f03d
add `llvmFor` macro that produces code for a for loop in LLVM
Vindaar Nov 5, 2024
9a0f8eb
add helpers for arithmetic, boolean logic for ValueRef, M/CValue
Vindaar Nov 5, 2024
c67548c
add `llvmIf` to generate code for if statements
Vindaar Nov 5, 2024
0b28232
add `to` type conversion helper which extends/truncates int types
Vindaar Nov 5, 2024
085b233
use `llvmForCountdown` in `genFpNsqrRt` instead of fixed countdown logic
Vindaar Nov 5, 2024
bdf667d
add `getWindowAt` helper required for baseline MSM implementation
Vindaar Nov 5, 2024
44ce9df
add serial MSM implementation for Nvidia using bucket method
Vindaar Nov 5, 2024
6eb0c60
[tests] add mini test case for MSM on Nvidia
Vindaar Nov 5, 2024
83e603a
whoops, revert local change to test CT error on `=copy`
Vindaar Nov 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions constantine/math_compiler/codegen_nvidia.nim
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,17 @@ export
# Cuda Driver API
# ------------------------------------------------------------

template check*(status: CUresult) =
template check*(status: CUresult, quitOnFailure = true) =
## Check the status code of a CUDA operation
## On failure, print a stack trace and the error code to stderr.
## Exit the program afterwards only if `quitOnFailure` is true.

let code = status # ensure that the input expression is evaluated once only

if code != CUDA_SUCCESS:
writeStackTrace()
stderr.write(astToStr(status) & " " & $instantiationInfo() & " exited with error: " & $code & '\n')
quit 1
if quitOnFailure:
quit 1 # NOTE: this hides exceptions if they are thrown!

func cuModuleLoadData*(module: var CUmodule, sourceCode: openArray[char]): CUresult {.inline.}=
cuModuleLoadData(module, sourceCode[0].unsafeAddr)
Expand Down Expand Up @@ -448,6 +450,9 @@ proc execCudaImpl(jitFn, res, inputs: NimNode): NimNode =
x[0]
)
)
result = quote do:
block:
`result`

macro execCuda*(jitFn: CUfunction,
res: typed,
Expand Down Expand Up @@ -513,8 +518,18 @@ type

proc `=destroy`*(nv: NvidiaAssemblerObj) =
## XXX: Need to also call the finalizer for `asy` in the future!
check nv.cuMod.cuModuleUnload()
check nv.cuCtx.cuCtxDestroy()
# NOTE: In the destructor we don't want to quit on a `check` failure.
# The reason is that if we throw an exception with an `NvidiaAssembler`
# in scope, it will trigger the destructor here (with a likely invalid
# state in the CUDA module / context). However, in this case
# we will crash anyway and would just end up hiding the actual cause of
# the error.
# In the unlikely case that all CUDA operations worked correctly up
# to this point, but then fail to unload, we currently ignore this
# as a failure mode.
# Hopefully we find a better solution in the future.
check nv.cuMod.cuModuleUnload(), quitOnFailure = false
check nv.cuCtx.cuCtxDestroy(), quitOnFailure = false
`=destroy`(nv.asy)

proc initNvAsm*[Name: static Algebra](field: type FF[Name], wordSize: int = 32, backend = bkNvidiaPTX): NvidiaAssembler =
Expand Down Expand Up @@ -571,7 +586,8 @@ proc initNvAsm*[Name: static Algebra](field: type EC_ShortW_Jac[Fp[Name], G1], w
Fp[Name].getModulus().toHex(),
v = 1, w = wordSize,
coef_a = Fp[Name].Name.getCoefA(),
coef_B = Fp[Name].Name.getCoefB()
coef_B = Fp[Name].Name.getCoefB(),
curveOrderBitWidth = Fr[Name].bits()
)
result.fd = result.cd.fd
result.asy.definePrimitives(result.cd)
Expand All @@ -580,6 +596,19 @@ proc compile*(nv: NvidiaAssembler, kernName: string): CUfunction =
## Overload of `compile` below.
## Call this version if you have manually used the Assembler_LLVM object
## to build instructions and have a kernel name you wish to compile.
##
## Use this overload if your generator function does not match the `FieldFnGenerator` or
## `CurveFnGenerator` signatures. This is useful if your function requires additional
## arguments that are compile time values in the context of LLVM.
##
## Example:
##
## ```nim
## let nv = initNvAsm(EC, wordSize)
## let kernel = nv.compile(nv.asy.genEcMSM(nv.cd, 3, 1000)) # window size, num. points
## ```
## where `genEcMSM` returns the name of the kernel.

let ptx = nv.asy.codegenNvidiaPTX(nv.sm) # convert to PTX

# GPU exec
Expand Down
29 changes: 29 additions & 0 deletions constantine/math_compiler/impl_curves_ops_affine.nim
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ const SectionName = "ctt.curves_affine"
type
EcPointAff* {.borrow: `.`.} = distinct Array

# Copy hook for `EcPointAff`, disabled through the `error` pragma: any use
# is reported with the message below, which points to `dst.store(src)`.
proc `=copy`*(m: var EcPointAff, x: EcPointAff) {.error: "Copying an EcPointAff is not allowed. " &
"You likely want to copy the LLVM value. Use `dst.store(src)` instead.".}

proc asEcPointAff*(br: BuilderRef, arrayPtr: ValueRef, arrayTy: TypeRef): EcPointAff =
  ## Builds an elliptic curve point in Affine coordinates on top of the
  ## array referenced by `arrayPtr`, using the builder `br` directly.
  ##
  ## `arrayTy` is an `array[FieldTy, 2]` where `FieldTy` itself is an array of
  ## `array[WordTy, NumWords]`.
  EcPointAff(br.asArray(arrayPtr, arrayTy))

proc asEcPointAff*(asy: Assembler_LLVM, arrayPtr: ValueRef, arrayTy: TypeRef): EcPointAff =
## Constructs an elliptic curve point in Affine coordinates from an array pointer.
##
Expand Down Expand Up @@ -54,6 +64,25 @@ proc store*(dst: EcPointAff, src: EcPointAff) =
store(dst.getX(), src.getX())
store(dst.getY(), src.getY())

# Array of EC points in affine coordinates.
# Distinct wrapper around `Array`; field access is borrowed from the base type.
type EcAffArray* {.borrow: `.`.} = distinct Array

# Copy hook for `EcAffArray`, disabled through the `error` pragma: any use
# is reported with the message below, which points to `dst.store(src)`.
proc `=copy`(m: var EcAffArray, x: EcAffArray) {.error: "Copying an EcAffArray is not allowed. " &
"You likely want to copy the LLVM value. Use `dst.store(src)` instead.".}

proc `[]`*(a: EcAffArray, index: SomeInteger | ValueRef): EcPointAff =
  ## Returns the element at `index` wrapped as an `EcPointAff`.
  ##
  ## The element pointer is obtained from the underlying `Array` via `getPtr`
  ## and interpreted with the array's element type `a.elemTy`.
  let elemPtr = distinctBase(a).getPtr(index)
  a.builder.asEcPointAff(elemPtr, a.elemTy)
proc `[]=`*(a: EcAffArray, index: SomeInteger | ValueRef, val: EcPointAff) =
  ## Assigns `val` to the slot at `index` by forwarding the buffer of `val`
  ## to the `[]=` of the underlying `Array`.
  distinctBase(a)[index] = val.buf

proc asEcAffArray*(asy: Assembler_LLVM, cd: CurveDescriptor, a: ValueRef, num: int): EcAffArray =
  ## Reinterprets the value `a` as an array holding `num` EC points in
  ## Affine coordinates (element type `cd.curveTyAff`).
  EcAffArray(asy.br.asArray(a, array_t(cd.curveTyAff, num)))

proc initEcAffArray*(asy: Assembler_LLVM, cd: CurveDescriptor, num: int): EcAffArray =
  ## Creates a fresh `EcAffArray` with room for `num` elements of type
  ## `cd.curveTyAff`, using `asy.makeArray`.
  EcAffArray(asy.makeArray(array_t(cd.curveTyAff, num)))

template declEllipticAffOps*(asy: Assembler_LLVM, cd: CurveDescriptor): untyped =
## This template can be used to make operations on `Field` elements
## more convenient.
Expand Down
43 changes: 42 additions & 1 deletion constantine/math_compiler/impl_curves_ops_jacobian.nim
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ const SectionName = "ctt.curves_jacobian"
type
EcPointJac* {.borrow: `.`.} = distinct Array

# Copy hook for `EcPointJac`, disabled through the `error` pragma: any use
# is reported with the message below, which points to `dst.store(src)`.
proc `=copy`(m: var EcPointJac, x: EcPointJac) {.error: "Copying an EcPointJac is not allowed. " &
"You likely want to copy the LLVM value. Use `dst.store(src)` instead.".}

proc asEcPointJac*(br: BuilderRef, arrayPtr: ValueRef, arrayTy: TypeRef): EcPointJac =
  ## Builds an elliptic curve point in Jacobian coordinates on top of the
  ## array referenced by `arrayPtr`, using the builder `br` directly.
  ##
  ## `arrayTy` is an `array[FieldTy, 3]` where `FieldTy` itself is an array of
  ## `array[WordTy, NumWords]`.
  EcPointJac(br.asArray(arrayPtr, arrayTy))

proc asEcPointJac*(asy: Assembler_LLVM, arrayPtr: ValueRef, arrayTy: TypeRef): EcPointJac =
## Constructs an elliptic curve point in Jacobian coordinates from an array pointer.
##
Expand Down Expand Up @@ -57,17 +67,48 @@ proc store*(dst: EcPointJac, src: EcPointJac) =
store(dst.getY(), src.getY())
store(dst.getZ(), src.getZ())

# Array of EC points in Jacobian coordinates.
# Distinct wrapper around `Array`; field access is borrowed from the base type.
type EcJacArray* {.borrow: `.`.} = distinct Array

# Copy hook for `EcJacArray`, disabled through the `error` pragma: any use
# is reported with the message below, which points to `dst.store(src)`.
proc `=copy`(m: var EcJacArray, x: EcJacArray) {.error: "Copying an EcJacArray is not allowed. " &
"You likely want to copy the LLVM value. Use `dst.store(src)` instead.".}

proc `[]`*(a: EcJacArray, index: SomeInteger | ValueRef): EcPointJac =
  ## Returns the element at `index` wrapped as an `EcPointJac`.
  ##
  ## The element pointer is obtained from the underlying `Array` via `getPtr`
  ## and interpreted with the array's element type `a.elemTy`.
  let elemPtr = distinctBase(a).getPtr(index)
  a.builder.asEcPointJac(elemPtr, a.elemTy)
proc `[]=`*(a: EcJacArray, index: SomeInteger | ValueRef, val: EcPointJac) =
  ## Assigns `val` to the slot at `index` by forwarding the buffer of `val`
  ## to the `[]=` of the underlying `Array`.
  distinctBase(a)[index] = val.buf

proc asEcJacArray*(asy: Assembler_LLVM, cd: CurveDescriptor, a: ValueRef, num: int): EcJacArray =
  ## Reinterprets the value `a` as an array holding `num` EC points in
  ## Jacobian coordinates (element type `cd.curveTy`).
  EcJacArray(asy.br.asArray(a, array_t(cd.curveTy, num)))

proc initEcJacArray*(asy: Assembler_LLVM, cd: CurveDescriptor, num: int): EcJacArray =
  ## Creates a fresh `EcJacArray` with room for `num` elements of type
  ## `cd.curveTy`, using `asy.makeArray`.
  EcJacArray(asy.makeArray(array_t(cd.curveTy, num)))

template declEllipticJacOps*(asy: Assembler_LLVM, cd: CurveDescriptor): untyped =
## This template can be used to make operations on elliptic curve points
## in Jacobian coordinates (`EcPointJac`) more convenient.
## XXX: extend to include all ops
# Setters
template setNeutral(x: EcPointJac): untyped = asy.setNeutral(cd, x.buf)

# Boolean checks
template isNeutral(res, x: EcPointJac): untyped = asy.isNeutral(cd, res, x.buf)
template isNeutral(x: EcPointJac): untyped =
var res = asy.br.alloca(asy.ctx.int1_t())
asy.isNeutral(cd, res, x.buf)
res

# Mutating assignment ops
template sum(res, x, y: EcPointJac): untyped = asy.sum(cd, res.buf, x.buf, y.buf)
template `+=`(x, y: EcPointJac): untyped = x.sum(x, y)
template mixedSum(res, x: EcPointJac, y: EcPointAff): untyped = asy.mixedSum(cd, res.buf, x.buf, y.buf)
template `+=`(x: EcPointJac, y: EcPointAff): untyped = x.mixedSum(x, y)

# Arithmetic mutations
template double(res, x: EcPointJac): untyped = asy.double(cd, res.buf, x.buf)
template double(x: EcPointJac): untyped = x.double(x)

# Conditional ops
template ccopy(x, y: EcPointJac, c): untyped = asy.ccopy(cd, x.buf, y.buf, derefBool c)

Expand Down
44 changes: 44 additions & 0 deletions constantine/math_compiler/impl_fields_ops.nim
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
constantine/platforms/bithacks, # for log2_vartime
constantine/platforms/llvm/[llvm, asm_nvidia],
./ir,
./impl_fields_globals,
Expand Down Expand Up @@ -725,3 +726,46 @@ proc scalarMul*(asy: Assembler_LLVM, fd: FieldDescriptor, a: ValueRef, b: int) =
asy.br.retVoid()

asy.callFn(name, [a])

proc getWindowAt*(asy: Assembler_LLVM, cd: CurveDescriptor, r, c, bI, wI: ValueRef) {.used.} =
## Generate an internal field `getWindowAt` function
## with signature
## void name(BaseType r, FieldType c, int bitIndex, int windowSize)
##
## The generated function reads the word of the scalar `c` that contains bit
## `bitIndex`, shifts it down to that bit position, masks the result with
## `(1 shl windowSize) - 1` and stores it into `r`. If the window extends past
## the end of that word and a next word exists, the bits of the next word
## are merged in before masking.
##
## Parameters:
## - `r`: destination the extracted window is stored to (tagged `sret`)
## - `c`: the scalar to read from, accessed as a `FieldScalar`
## - `bI`: bit index at which the window starts
## - `wI`: window size in bits
##
## After the definition, the function is invoked through `asy.callFn`
## with `[r, c, bI, wI]`.
let name = cd.fd.name & "_getWindowAt"
asy.llvmInternalFnDef(
name, SectionName,
asy.void_t, toTypes([r, c, bI, wI]),
{kHot}):
tagParameter(1, "sret")

# Operations for numbers as `ValueRef`
declNumberOps(asy, cd.fd)

let (ri, ci, bitIndex, windowSize) = llvmParams
# NOTE(review): `rA` is not referenced below; the result is stored through `ri` directly.
let rA = asy.asFieldScalar(cd, ri)
let cA = asy.asFieldScalar(cd, ci)
let fd = cd.fd

# Nim values
# `SlotShift` is log2 of the word size `fd.w`, so `bitIndex shr SlotShift`
# selects the word holding the bit, and `bitIndex and WordMask` is the offset
# inside that word (assumes `fd.w` is a power of two — TODO confirm).
let SlotShift = log2_vartime(fd.w.uint32)
let WordMask = fd.w - 1
let WindowMask = (1 shl windowSize) - 1 # LLVM

# LLVM values
let slot = bitIndex shr SlotShift
let word = cA[slot] # word in limbs
let pos = bitIndex and WordMask # position in the word

# This is constant-time, the branch does not depend on secret data.
llvmIf(asy): # transforms an `if` statement body into llvm conditional branches
if pos + windowSize > fd.w and slot+1 < fd.numWords:
# Read next word as well
let x = ((word shr pos) or (cA[slot+1] shl (fd.w - pos))) and WindowMask
asy.store(ri, x)
else:
# The window fits in the current word, or there is no next word to read.
let x = (word shr pos) and WindowMask
asy.store(ri, x)

asy.br.retVoid()

asy.callFn(name, [r, c, bI, wI])
96 changes: 96 additions & 0 deletions constantine/math_compiler/impl_msm_nvidia.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
constantine/platforms/llvm/[llvm, asm_nvidia],
constantine/platforms/[primitives],
./ir,
./impl_fields_globals,
./impl_fields_dispatch,
./impl_fields_ops,
./impl_curves_ops_affine,
./impl_curves_ops_jacobian,
std / typetraits # for distinctBase

## Section name used for `llvmInternalFnDef`
const SectionName = "ctt.msm_nvidia"

proc msm*(asy: Assembler_LLVM, cd: CurveDescriptor, r, coefs, points: ValueRef,
c, N: int) {.used.} =
## Inner implementation of MSM, for static dispatch over c, the bucket bit length
## This is a straightforward simple translation of BDLO12, section 4
##
## Entirely serial implementation!
##
## Important note: The coefficients given to this procedure must be in canonical
## representation instead of Montgomery representation! Thus, you cannot pass
## values of type `Fr[Curve]` directly, as they are internally stored in Montgomery
## rep. Convert to a `BigInt` using `fromField`.
##
## Parameters:
## - `r`: destination of the result, accessed as an `EcPointJac` (tagged `sret`)
## - `coefs`: the coefficients, accessed as a `FieldScalarArray` of `N` elements
## - `points`: the EC points, accessed as an `EcAffArray` of `N` elements
## - `c`: window size in bits; a Nim `int`, fixed while generating the code
## - `N`: number of (coefficient, point) pairs; a Nim `int`, fixed likewise
##
## After the definition, the function is invoked through `asy.callFn`
## with `[r, coefs, points]`.
let name = cd.name & "_msm_impl"
asy.llvmInternalFnDef(
name, SectionName,
asy.void_t, toTypes([r, coefs, points]),
{kHot}):
tagParameter(1, "sret")

# Inject templates for convenient access
declFieldOps(asy, cd.fd)
declEllipticJacOps(asy, cd)
declEllipticAffOps(asy, cd)
declNumberOps(asy, cd.fd)

let (ri, coefsIn, pointsIn) = llvmParams
let rA = asy.asEcPointJac(cd, ri)
let cs = asy.asFieldScalarArray(cd, coefsIn, N) # coefficients
let Ps = asy.asEcAffArray(cd, pointsIn, N) # EC points
# Prologue
# --------
let numBuckets = 1 shl c - 1 # bucket 0 is unused
# Number of `c`-bit windows needed to cover `cd.orderBitWidth` bits.
let numWindows = cd.orderBitWidth.int.ceilDiv_vartime(c)

# One partial result per window, plus the buckets reused for every window.
let miniMSMs = asy.initEcJacArray(cd, numWindows)
let buckets = asy.initEcJacArray(cd, numBuckets)

# Algorithm
# ---------
# NOTE(review): `cNonZero` is not referenced anywhere below.
var cNonZero = asy.initMutVal(cd.fd.wordTy)
asy.llvmFor w, 0, numWindows - 1, true:
# Place our points in the bucket selected by the bit pattern of their
# coefficient in the current window of size c.
# First, reset all buckets to the neutral element.
asy.llvmFor i, 0, numBuckets - 1, true:
buckets[i].setNeutral()

# 1. Bucket accumulation. Cost: n - (2ᶜ-1) => n points in 2ᶜ-1 buckets, first point per bucket is just copied
asy.llvmFor j, 0, N-1, true:
var b = asy.initMutVal(cd.fd.wordTy)
let w0 = asy.initConstVal(0, cd.fd.wordTy)
# b <- the `c`-bit window of coefficient `j` starting at bit `w * c`
asy.getWindowAt(cd, b.buf, cs[j].buf, asy.to(w, cd.fd.wordTy) * c, constInt(cd.fd.wordTy, c))
llvmIf(asy):
# A window value of 0 adds nothing; a value `b` maps to bucket index `b-1`.
if b != w0:
buckets[b-1] += Ps[j]

# 2. Bucket reduction, starting from the highest bucket.
var accumBuckets = asy.newEcPointJac(cd)
var miniMSM = asy.newEcPointJac(cd)
accumBuckets.store(buckets[numBuckets-1])
miniMSM.store(buckets[numBuckets-1])

asy.llvmFor k, numBuckets-2, 0, false:
accumBuckets += buckets[k] # Stores S₈ then S₈+S₇ then S₈+S₇+S₆ then ...
miniMSM += accumBuckets # Stores S₈ then [2]S₈+S₇ then [3]S₈+[2]S₇+S₆ then ...

miniMSMs[w].store(miniMSM)

# 3. Final reduction: start from the highest window, then for each lower
# window double the accumulator `c` times and add that window's result.
rA.store(miniMSMs[numWindows-1])
asy.llvmFor w, numWindows-2, 0, false:
# NOTE(review): no explicit direction argument here, unlike the other
# `llvmFor` loops — presumably it defaults to counting up; confirm.
asy.llvmFor j, 0, c-1:
rA.double()
rA += miniMSMs[w]

asy.br.retVoid()

asy.callFn(name, [r, coefs, points])
Loading
Loading