Skip to content

Commit

Permalink
llvm: define our own addcarry/subborrow which properly optimize on x8…
Browse files Browse the repository at this point in the history
…6 (but not ARM see llvm/llvm-project#102062)
  • Loading branch information
mratsim committed Aug 13, 2024
1 parent a76cfd8 commit b415418
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 77 deletions.
72 changes: 40 additions & 32 deletions constantine/math_compiler/impl_fields_sat.nim
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,13 @@ import
#
# and while using @llvm.usub.with.overflow.i64 allows ARM64 to solve the missing optimization
# it is also missed on AMDGPU (or nvidia)
#
# And implementing them with i256 / i384 is similarly tricky
# https://github.com/llvm/llvm-project/issues/102868

const SectionName = "ctt.fields"

proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, MM, carry: ValueRef) =
proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array, carry: ValueRef) =
## If a >= Modulus: r <- a-M
## else: r <- a
##
Expand All @@ -84,30 +87,28 @@ proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, MM, c
##
## To be used when the final substraction can
## also overflow the limbs (a 2^256 order of magnitude modulus stored in n words of total max size 2^256)
let t = asy.makeArray(fd.fieldTy)

let r = asy.asArray(rr, fd.fieldTy)
let M = asy.load2(fd.intBufTy, MM, "M")

let noCarry = asy.br.`not`(carry, "notcarry")
# Mask: contains 0xFFFF or 0x0000
let (_, mask) = asy.br.subborrow(fd.zero, fd.zero, carry)

# Now substract the modulus, and test a < M
# (underflow) with the last borrow.
# On x86 at least, LLVM can fuse sub and icmp into sub-with-borrow
# if this is inline the caller https://github.com/llvm/llvm-project/issues/102868
let a_minus_M = asy.br.sub(a, M, "a_minus_M")
let borrow = asy.br.icmp(kULT, a, M, "borrow")

# Cases:
# No carry after a+b, no borrow after a-M -> return a-M
# carry after a+b, will borrow after a-M (last bit lost) -> return a-M
# carry after a+b, no borrow after a-M -> return a-M
# No carry after a+b, borrow after a-M -> return a
let ctl = asy.br.`or`(noCarry, borrow, "in_range")
let t = asy.br.select(ctl, a, a_minus_M)
# (underflow) with the last borrow
var b: ValueRef
(b, t[0]) = asy.br.subborrow(a[0], M[0], fd.zero_i1)
for i in 1 ..< fd.numWords:
(b, t[i]) = asy.br.subborrow(a[i], M[i], b)

# If it underflows here, it means that it was
# smaller than the modulus and we don't need `scratch`
(b, _) = asy.br.subborrow(mask, fd.zero, b)

for i in 0 ..< fd.numWords:
t[i] = asy.br.select(b, a[i], t[i])

asy.store(r, t)

proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, MM: ValueRef) =
proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array) =
## If a >= Modulus: r <- a-M
## else: r <- a
##
Expand All @@ -116,18 +117,18 @@ proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, MM: Va
##
## To be used when the modulus does not use the full bitwidth of the storing words
## (say using 255 bits for the modulus out of 256 available in words)

let r = asy.asArray(rr, fd.fieldTy)
let M = asy.load2(fd.intBufTy, MM, "M")
let t = asy.makeArray(fd.fieldTy)

# Now substract the modulus, and test a < M
# (underflow) with the last borrow
# On x86 at least, LLVM can fuse sub and icmp into sub-with-borrow
let a_minus_M = asy.br.sub(a, M, "a_minus_M")
let borrow = asy.br.icmp(kULT, a, M, "borrow")
var b: ValueRef
(b, t[0]) = asy.br.subborrow(a[0], M[0], fd.zero_i1)
for i in 1 ..< fd.numWords:
(b, t[i]) = asy.br.subborrow(a[i], M[i], b)

# If it underflows here a was smaller than the modulus, which is what we want
let t = asy.br.select(borrow, a, a_minus_M)
for i in 0 ..< fd.numWords:
t[i] = asy.br.select(b, a[i], t[i])

asy.store(r, t)

Expand All @@ -143,19 +144,26 @@ proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) =
asy.void_t, toTypes([r, a, b, M]),
{kHot}):

let (r, aa, bb, M) = llvmParams
tagParameter(1, "sret")

let (rr, aa, bb, MM) = llvmParams

# Pointers are opaque in LLVM now
let a = asy.load2(fd.intBufTy, aa, "a")
let b = asy.load2(fd.intBufTy, bb, "b")
let r = asy.asArray(rr, fd.fieldTy)
let a = asy.asArray(aa, fd.fieldTy)
let b = asy.asArray(bb, fd.fieldTy)
let M = asy.asArray(MM, fd.fieldTy)

let apb = asy.br.add(a, b, "a_plus_b")
let apb = asy.makeArray(fd.fieldTy)
var c: ValueRef
(c, apb[0]) = asy.br.addcarry(a[0], b[0], fd.zero_i1)
for i in 1 ..< fd.numWords:
(c, apb[i]) = asy.br.addcarry(a[i], b[i], c)

if fd.spareBits >= 1:
asy.finalSubNoOverflow(fd, r, apb, M)
else:
let carry = asy.br.icmp(kUlt, apb, b, "overflow")
asy.finalSubMayOverflow(fd, r, apb, M, carry)
asy.finalSubMayOverflow(fd, r, apb, M, c)

asy.br.retVoid()

Expand Down
37 changes: 12 additions & 25 deletions constantine/math_compiler/ir.nim
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

import
constantine/platforms/bithacks,
constantine/platforms/llvm/llvm,
std/[tables, macros]
constantine/platforms/llvm/[llvm, super_instructions],
std/tables

# ############################################################
#
Expand Down Expand Up @@ -112,6 +112,7 @@ proc new*(T: type Assembler_LLVM, backend: Backend, moduleName: cstring): Assemb
result.attrs[kInline] = result.ctx.createAttr("inlinehint")
result.attrs[kAlwaysInline] = result.ctx.createAttr("alwaysinline")
result.attrs[kNoInline] = result.ctx.createAttr("noinline")
result.attrs[kNoInline] = result.ctx.createAttr("sret")

# ############################################################
#
Expand Down Expand Up @@ -194,6 +195,10 @@ proc configureField*(ctx: ContextRef,
result.bits = modBits
result.spareBits = uint8(next_multiple_wordsize - modBits)

proc definePrimitives*(asy: Assembler_LLVM, fd: FieldDescriptor) =
asy.ctx.def_addcarry(asy.module, asy.ctx.int1_t(), fd.wordTy)
asy.ctx.def_subborrow(asy.module, asy.ctx.int1_t(), fd.wordTy)

proc wordTy*(fd: FieldDescriptor, value: SomeInteger) =
constInt(fd.wordTy, value)

Expand Down Expand Up @@ -249,11 +254,11 @@ proc makeArray*(asy: Assembler_LLVM, elemTy: TypeRef, len: uint32): Array =

proc `[]`*(a: Array, index: SomeInteger): ValueRef {.inline.}=
# First dereference the array pointer with 0, then access the `index`
let pelem = a.builder.getElementPtr2_InBounds(a.arrayTy, a.p, [ValueRef constInt(a.int32_t, 0), ValueRef constInt(a.int32_t, uint64 index)])
let pelem = a.builder.getElementPtr2_InBounds(a.arrayTy, a.buf, [ValueRef constInt(a.int32_t, 0), ValueRef constInt(a.int32_t, uint64 index)])
a.builder.load2(a.elemTy, pelem)

proc `[]=`*(a: Array, index: SomeInteger, val: ValueRef) {.inline.}=
let pelem = a.builder.getElementPtr2_InBounds(a.arrayTy, a.p, [ValueRef constInt(a.int32_t, 0), ValueRef constInt(a.int32_t, uint64 index)])
let pelem = a.builder.getElementPtr2_InBounds(a.arrayTy, a.buf, [ValueRef constInt(a.int32_t, 0), ValueRef constInt(a.int32_t, uint64 index)])
a.builder.store(val, pelem)

proc store*(asy: Assembler_LLVM, dst: Array, src: Array) {.inline.}=
Expand Down Expand Up @@ -385,10 +390,6 @@ proc setPublic(asy: Assembler_LLVM, fn: ValueRef) =
#
# Hopefully the compiler will remove the unnecessary lod/store/register movement, especially when inlining.

proc toTypes*[N: static int](v: array[N, ValueRef]): array[N, TypeRef] =
for i in 0 ..< v.len:
result[i] = v[i].getTypeOf()

proc wrapTypesForFnCall[N: static int](
asy: AssemblerLLVM,
paramTypes: array[N, TypeRef]
Expand Down Expand Up @@ -432,23 +433,6 @@ proc wrapTypesForFnCall[N: static int](
result.wrapped[i] = paramTypes[i]
result.src[i] = paramTypes[i]

macro unpackParams[N: static int](
br: BuilderRef,
paramsTys: tuple[wrapped, src: array[N, TypeRef]]): untyped =
## Unpack function parameters.
##
## The new function basic block MUST be setup before calling unpackParams.
##
## In the future we may automatically unwrap types.

result = nnkPar.newTree()
for i in 0 ..< N:
result.add quote do:
# let tySrc = `paramsTys`.src[`i`]
# let tyCC = `paramsTys`.wrapped[`i`]
let fn = `br`.getCurrentFunction()
fn.getParam(uint32 `i`)

proc addAttributes(asy: Assembler_LLVM, fn: ValueRef, attrs: set[AttrKind]) =
for attr in attrs:
fn.addAttribute(kAttrFnIndex, asy.attrs[attr])
Expand Down Expand Up @@ -485,6 +469,9 @@ template llvmFnDef[N: static int](
savedLoc = blck

let llvmParams {.inject.} = unpackParams(asy.br, paramsTys)
template tagParameter(idx: int, attr: string) {.inject.} =
let a = asy.ctx.createAttr(attr)
fn.addAttribute(cint idx, a)
body

if internal:
Expand Down
16 changes: 14 additions & 2 deletions constantine/platforms/abis/llvm_abi.nim
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ proc getIntTypeWidth*(ty: TypeRef): uint32 {.importc: "LLVMGetIntTypeWidth".}
proc struct_t*(
ctx: ContextRef,
elemTypes: openArray[TypeRef],
packed: LlvmBool): TypeRef {.wrapOpenArrayLenType: cuint, importc: "LLVMStructTypeInContext".}
packed = LlvmBool(false)): TypeRef {.wrapOpenArrayLenType: cuint, importc: "LLVMStructTypeInContext".}
proc array_t*(elemType: TypeRef, elemCount: uint32): TypeRef {.importc: "LLVMArrayType".}
proc vector_t*(elemType: TypeRef, elemCount: uint32): TypeRef {.importc: "LLVMVectorType".}
## Create a SIMD vector type (for SSE, AVX or Neon for example)
Expand All @@ -309,6 +309,7 @@ proc pointerType(elementType: TypeRef; addressSpace: cuint): TypeRef {.used, imp
proc getElementType*(arrayOrVectorTy: TypeRef): TypeRef {.importc: "LLVMGetElementType".}
proc getArrayLength*(arrayTy: TypeRef): uint64 {.importc: "LLVMGetArrayLength2".}
proc getNumElements*(structTy: TypeRef): cuint {.importc: "LLVMCountStructElementTypes".}
proc getVectorSize*(vecTy: TypeRef): cuint {.importc: "LLVMGetVectorSize".}

# Functions
# ------------------------------------------------------------
Expand Down Expand Up @@ -648,6 +649,8 @@ proc addGlobal*(module: ModuleRef, ty: TypeRef, name: cstring): ValueRef {.impor
proc setGlobal*(globalVar: ValueRef, constantVal: ValueRef) {.importc: "LLVMSetInitializer".}
proc setImmutable*(globalVar: ValueRef, immutable = LlvmBool(true)) {.importc: "LLVMSetGlobalConstant".}

proc getGlobalParent*(global: ValueRef): ModuleRef {.importc: "LLVMGetGlobalParent".}

proc setLinkage*(global: ValueRef, linkage: Linkage) {.importc: "LLVMSetLinkage".}
proc setVisibility*(global: ValueRef, vis: Visibility) {.importc: "LLVMSetVisibility".}
proc setAlignment*(v: ValueRef, bytes: cuint) {.importc: "LLVMSetAlignment".}
Expand Down Expand Up @@ -683,6 +686,15 @@ proc constArray*(
constantVals: openArray[ValueRef],
): ValueRef {.wrapOpenArrayLenType: cuint, importc: "LLVMConstArray".}

# Undef & Poison
# ------------------------------------------------------------
# https://llvm.org/devmtg/2020-09/slides/Lee-UndefPoison.pdf

proc poison*(ty: TypeRef): ValueRef {.importc: "LLVMGetPoison".}
proc undef*(ty: TypeRef): ValueRef {.importc: "LLVMGetUndef".}



# ############################################################
#
# IR builder
Expand Down Expand Up @@ -826,7 +838,7 @@ proc alloca*(builder: BuilderRef, ty: TypeRef, name: cstring = ""): ValueRef {.i
proc allocaArray*(builder: BuilderRef, ty: TypeRef, length: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildArrayAlloca".}

proc extractValue*(builder: BuilderRef, aggVal: ValueRef, index: uint32, name: cstring = ""): ValueRef {.importc: "LLVMBuildExtractValue".}
proc insertValue*(builder: BuilderRef, aggVal: ValueRef, eltVal: ValueRef, index: uint32, name: cstring = ""): ValueRef {.discardable, importc: "LLVMBuildInsertValue".}
proc insertValue*(builder: BuilderRef, aggVal: ValueRef, eltVal: ValueRef, index: uint32, name: cstring = ""): ValueRef {.importc: "LLVMBuildInsertValue".}

proc getElementPtr2*(
builder: BuilderRef,
Expand Down
25 changes: 25 additions & 0 deletions constantine/platforms/llvm/llvm.nim
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import constantine/platforms/abis/llvm_abi {.all.}
import std/macros
export llvm_abi

# ############################################################
Expand Down Expand Up @@ -155,6 +156,9 @@ proc getContext*(builder: BuilderRef): ContextRef =
# https://github.com/llvm/llvm-project/issues/59875
builder.getCurrentFunction().getTypeOf().getContext()

proc getCurrentModule*(builder: BuilderRef): ModuleRef =
builder.getCurrentFunction().getGlobalParent()

# Types
# ------------------------------------------------------------

Expand All @@ -181,6 +185,27 @@ proc function_t*(returnType: TypeRef, paramTypes: openArray[TypeRef]): TypeRef {
proc createAttr*(ctx: ContextRef, name: openArray[char]): AttributeRef =
ctx.toAttr(name.toAttrId())

proc toTypes*[N: static int](v: array[N, ValueRef]): array[N, TypeRef] =
for i in 0 ..< v.len:
result[i] = v[i].getTypeOf()

macro unpackParams*[N: static int](
br: BuilderRef,
paramsTys: tuple[wrapped, src: array[N, TypeRef]]): untyped =
## Unpack function parameters.
##
## The new function basic block MUST be setup before calling unpackParams.
##
## In the future we may automatically unwrap types.

result = nnkPar.newTree()
for i in 0 ..< N:
result.add quote do:
# let tySrc = `paramsTys`.src[`i`]
# let tyCC = `paramsTys`.wrapped[`i`]
let fn = `br`.getCurrentFunction()
fn.getParam(uint32 `i`)

# Values
# ------------------------------------------------------------

Expand Down
Loading

0 comments on commit b415418

Please sign in to comment.