diff --git a/constantine/math_compiler/impl_fields_sat.nim b/constantine/math_compiler/impl_fields_sat.nim index 1dedfa44..a7e66545 100644 --- a/constantine/math_compiler/impl_fields_sat.nim +++ b/constantine/math_compiler/impl_fields_sat.nim @@ -72,10 +72,13 @@ import # # and while using @llvm.usub.with.overflow.i64 allows ARM64 to solve the missing optimization # it is also missed on AMDGPU (or nvidia) +# +# And implementing them with i256 / i384 is similarly tricky +# https://github.com/llvm/llvm-project/issues/102868 const SectionName = "ctt.fields" -proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, MM, carry: ValueRef) = +proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array, carry: ValueRef) = ## If a >= Modulus: r <- a-M ## else: r <- a ## @@ -84,30 +87,28 @@ proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, MM, c ## ## To be used when the final substraction can ## also overflow the limbs (a 2^256 order of magnitude modulus stored in n words of total max size 2^256) + let t = asy.makeArray(fd.fieldTy) - let r = asy.asArray(rr, fd.fieldTy) - let M = asy.load2(fd.intBufTy, MM, "M") - - let noCarry = asy.br.`not`(carry, "notcarry") + # Mask: contains 0xFFFF or 0x0000 + let (_, mask) = asy.br.subborrow(fd.zero, fd.zero, carry) # Now substract the modulus, and test a < M - # (underflow) with the last borrow. - # On x86 at least, LLVM can fuse sub and icmp into sub-with-borrow - # if this is inline the caller https://github.com/llvm/llvm-project/issues/102868 - let a_minus_M = asy.br.sub(a, M, "a_minus_M") - let borrow = asy.br.icmp(kULT, a, M, "borrow") - - # Cases: - # No carry after a+b, no borrow after a-M -> return a-M - # carry after a+b, will borrow after a-M (last bit lost) -> return a-M - # carry after a+b, no borrow after a-M -> return a-M - # No carry after a+b, borrow after a-M -> return a - let ctl = asy.br.`or`(noCarry, borrow, "in_range") - let t = asy.br.select(ctl, a, a_minus_M) + # (underflow) with the last borrow + var b: ValueRef + (b, t[0]) = asy.br.subborrow(a[0], M[0], fd.zero_i1) + for i in 1 ..< fd.numWords: + (b, t[i]) = asy.br.subborrow(a[i], M[i], b) + + # If it underflows here, it means that it was + # smaller than the modulus and we don't need `scratch` + (b, _) = asy.br.subborrow(mask, fd.zero, b) + + for i in 0 ..< fd.numWords: + t[i] = asy.br.select(b, a[i], t[i]) asy.store(r, t) -proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, MM: ValueRef) = +proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array) = ## If a >= Modulus: r <- a-M ## else: r <- a ## @@ -116,18 +117,18 @@ proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, MM: Va ## ## To be used when the modulus does not use the full bitwidth of the storing words ## (say using 255 bits for the modulus out of 256 available in words) - - let r = asy.asArray(rr, fd.fieldTy) - let M = asy.load2(fd.intBufTy, MM, "M") + let t = asy.makeArray(fd.fieldTy) # Now substract the modulus, and test a < M # (underflow) with the last borrow - # On x86 at least, LLVM can fuse sub and icmp into sub-with-borrow - let a_minus_M = asy.br.sub(a, M, "a_minus_M") - let borrow = asy.br.icmp(kULT, a, M, "borrow") + var b: ValueRef + (b, t[0]) = asy.br.subborrow(a[0], M[0], fd.zero_i1) + for i in 1 ..< fd.numWords: + (b, t[i]) = asy.br.subborrow(a[i], M[i], b) # If it underflows here a was smaller than the modulus, which is what we want - let t = asy.br.select(borrow, a, a_minus_M) + for i in 0 ..< fd.numWords: + t[i] = asy.br.select(b, a[i], t[i]) asy.store(r, t) @@ -143,19 +144,26 @@ proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) = asy.void_t, toTypes([r, a, b, M]), {kHot}): - let (r, aa, bb, M) = llvmParams + tagParameter(1, "sret") + + let (rr, aa, bb, MM) = llvmParams # Pointers are opaque in LLVM now - let a = asy.load2(fd.intBufTy, aa, "a") - let b = asy.load2(fd.intBufTy, bb, "b") + let r = asy.asArray(rr, fd.fieldTy) + let a = asy.asArray(aa, fd.fieldTy) + let b = asy.asArray(bb, fd.fieldTy) + let M = asy.asArray(MM, fd.fieldTy) - let apb = asy.br.add(a, b, "a_plus_b") + let apb = asy.makeArray(fd.fieldTy) + var c: ValueRef + (c, apb[0]) = asy.br.addcarry(a[0], b[0], fd.zero_i1) + for i in 1 ..< fd.numWords: + (c, apb[i]) = asy.br.addcarry(a[i], b[i], c) if fd.spareBits >= 1: asy.finalSubNoOverflow(fd, r, apb, M) else: - let carry = asy.br.icmp(kUlt, apb, b, "overflow") - asy.finalSubMayOverflow(fd, r, apb, M, carry) + asy.finalSubMayOverflow(fd, r, apb, M, c) asy.br.retVoid() diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index a29a38db..51acb1fa 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -8,8 +8,8 @@ import constantine/platforms/bithacks, - constantine/platforms/llvm/llvm, - std/[tables, macros] + constantine/platforms/llvm/[llvm, super_instructions], + std/tables # ############################################################ # @@ -112,6 +112,7 @@ proc new*(T: type Assembler_LLVM, backend: Backend, moduleName: cstring): Assemb result.attrs[kInline] = result.ctx.createAttr("inlinehint") result.attrs[kAlwaysInline] = result.ctx.createAttr("alwaysinline") result.attrs[kNoInline] = result.ctx.createAttr("noinline") + result.attrs[kNoInline] = result.ctx.createAttr("sret") # ############################################################ # @@ -194,6 +195,10 @@ proc configureField*(ctx: ContextRef, result.bits = modBits result.spareBits = uint8(next_multiple_wordsize - modBits) +proc definePrimitives*(asy: Assembler_LLVM, fd: FieldDescriptor) = + asy.ctx.def_addcarry(asy.module, asy.ctx.int1_t(), fd.wordTy) + asy.ctx.def_subborrow(asy.module, asy.ctx.int1_t(), fd.wordTy) + proc wordTy*(fd: FieldDescriptor, value: SomeInteger) = constInt(fd.wordTy, value) @@ -249,11 +254,11 @@ proc makeArray*(asy: Assembler_LLVM, elemTy: TypeRef, len: uint32): Array = proc `[]`*(a: Array, index: SomeInteger): ValueRef {.inline.}= # First dereference the array pointer with 0, then access the `index` - let pelem = a.builder.getElementPtr2_InBounds(a.arrayTy, a.p, [ValueRef constInt(a.int32_t, 0), ValueRef constInt(a.int32_t, uint64 index)]) + let pelem = a.builder.getElementPtr2_InBounds(a.arrayTy, a.buf, [ValueRef constInt(a.int32_t, 0), ValueRef constInt(a.int32_t, uint64 index)]) a.builder.load2(a.elemTy, pelem) proc `[]=`*(a: Array, index: SomeInteger, val: ValueRef) {.inline.}= - let pelem = a.builder.getElementPtr2_InBounds(a.arrayTy, a.p, [ValueRef constInt(a.int32_t, 0), ValueRef constInt(a.int32_t, uint64 index)]) + let pelem = a.builder.getElementPtr2_InBounds(a.arrayTy, a.buf, [ValueRef constInt(a.int32_t, 0), ValueRef constInt(a.int32_t, uint64 index)]) a.builder.store(val, pelem) proc store*(asy: Assembler_LLVM, dst: Array, src: Array) {.inline.}= @@ -385,10 +390,6 @@ proc setPublic(asy: Assembler_LLVM, fn: ValueRef) = # # Hopefully the compiler will remove the unnecessary lod/store/register movement, especially when inlining. -proc toTypes*[N: static int](v: array[N, ValueRef]): array[N, TypeRef] = - for i in 0 ..< v.len: - result[i] = v[i].getTypeOf() - proc wrapTypesForFnCall[N: static int]( asy: AssemblerLLVM, paramTypes: array[N, TypeRef] @@ -432,23 +433,6 @@ proc wrapTypesForFnCall[N: static int]( result.wrapped[i] = paramTypes[i] result.src[i] = paramTypes[i] -macro unpackParams[N: static int]( - br: BuilderRef, - paramsTys: tuple[wrapped, src: array[N, TypeRef]]): untyped = - ## Unpack function parameters. - ## - ## The new function basic block MUST be setup before calling unpackParams. - ## - ## In the future we may automatically unwrap types. - - result = nnkPar.newTree() - for i in 0 ..< N: - result.add quote do: - # let tySrc = `paramsTys`.src[`i`] - # let tyCC = `paramsTys`.wrapped[`i`] - let fn = `br`.getCurrentFunction() - fn.getParam(uint32 `i`) - proc addAttributes(asy: Assembler_LLVM, fn: ValueRef, attrs: set[AttrKind]) = for attr in attrs: fn.addAttribute(kAttrFnIndex, asy.attrs[attr]) @@ -485,6 +469,9 @@ template llvmFnDef[N: static int]( savedLoc = blck let llvmParams {.inject.} = unpackParams(asy.br, paramsTys) + template tagParameter(idx: int, attr: string) {.inject.} = + let a = asy.ctx.createAttr(attr) + fn.addAttribute(cint idx, a) body if internal: diff --git a/constantine/platforms/abis/llvm_abi.nim b/constantine/platforms/abis/llvm_abi.nim index 64368152..81b19b20 100644 --- a/constantine/platforms/abis/llvm_abi.nim +++ b/constantine/platforms/abis/llvm_abi.nim @@ -299,7 +299,7 @@ proc getIntTypeWidth*(ty: TypeRef): uint32 {.importc: "LLVMGetIntTypeWidth".} proc struct_t*( ctx: ContextRef, elemTypes: openArray[TypeRef], - packed: LlvmBool): TypeRef {.wrapOpenArrayLenType: cuint, importc: "LLVMStructTypeInContext".} + packed = LlvmBool(false)): TypeRef {.wrapOpenArrayLenType: cuint, importc: "LLVMStructTypeInContext".} proc array_t*(elemType: TypeRef, elemCount: uint32): TypeRef {.importc: "LLVMArrayType".} proc vector_t*(elemType: TypeRef, elemCount: uint32): TypeRef {.importc: "LLVMVectorType".} ## Create a SIMD vector type (for SSE, AVX or Neon for example) @@ -309,6 +309,7 @@ proc pointerType(elementType: TypeRef; addressSpace: cuint): TypeRef {.used, imp proc getElementType*(arrayOrVectorTy: TypeRef): TypeRef {.importc: "LLVMGetElementType".} proc getArrayLength*(arrayTy: TypeRef): uint64 {.importc: "LLVMGetArrayLength2".} proc getNumElements*(structTy: TypeRef): cuint {.importc: "LLVMCountStructElementTypes".} +proc getVectorSize*(vecTy: TypeRef): cuint {.importc: "LLVMGetVectorSize".} # Functions # ------------------------------------------------------------ @@ -648,6 +649,8 @@ proc addGlobal*(module: ModuleRef, ty: TypeRef, name: cstring): ValueRef {.impor proc setGlobal*(globalVar: ValueRef, constantVal: ValueRef) {.importc: "LLVMSetInitializer".} proc setImmutable*(globalVar: ValueRef, immutable = LlvmBool(true)) {.importc: "LLVMSetGlobalConstant".} +proc getGlobalParent*(global: ValueRef): ModuleRef {.importc: "LLVMGetGlobalParent".} + proc setLinkage*(global: ValueRef, linkage: Linkage) {.importc: "LLVMSetLinkage".} proc setVisibility*(global: ValueRef, vis: Visibility) {.importc: "LLVMSetVisibility".} proc setAlignment*(v: ValueRef, bytes: cuint) {.importc: "LLVMSetAlignment".} @@ -683,6 +686,15 @@ proc constArray*( constantVals: openArray[ValueRef], ): ValueRef {.wrapOpenArrayLenType: cuint, importc: "LLVMConstArray".} +# Undef & Poison +# ------------------------------------------------------------ +# https://llvm.org/devmtg/2020-09/slides/Lee-UndefPoison.pdf + +proc poison*(ty: TypeRef): ValueRef {.importc: "LLVMGetPoison".} +proc undef*(ty: TypeRef): ValueRef {.importc: "LLVMGetUndef".} + + + # ############################################################ # # IR builder @@ -826,7 +838,7 @@ proc alloca*(builder: BuilderRef, ty: TypeRef, name: cstring = ""): ValueRef {.i proc allocaArray*(builder: BuilderRef, ty: TypeRef, length: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildArrayAlloca".} proc extractValue*(builder: BuilderRef, aggVal: ValueRef, index: uint32, name: cstring = ""): ValueRef {.importc: "LLVMBuildExtractValue".} -proc insertValue*(builder: BuilderRef, aggVal: ValueRef, eltVal: ValueRef, index: uint32, name: cstring = ""): ValueRef {.discardable, importc: "LLVMBuildInsertValue".} +proc insertValue*(builder: BuilderRef, aggVal: ValueRef, eltVal: ValueRef, index: uint32, name: cstring = ""): ValueRef {.importc: "LLVMBuildInsertValue".} proc getElementPtr2*( builder: BuilderRef, diff --git a/constantine/platforms/llvm/llvm.nim b/constantine/platforms/llvm/llvm.nim index ccaeea96..e87e7989 100644 --- a/constantine/platforms/llvm/llvm.nim +++ b/constantine/platforms/llvm/llvm.nim @@ -7,6 +7,7 @@ # at your option. This file may not be copied, modified, or distributed except according to those terms. import constantine/platforms/abis/llvm_abi {.all.} +import std/macros export llvm_abi # ############################################################ @@ -155,6 +156,9 @@ proc getContext*(builder: BuilderRef): ContextRef = # https://github.com/llvm/llvm-project/issues/59875 builder.getCurrentFunction().getTypeOf().getContext() +proc getCurrentModule*(builder: BuilderRef): ModuleRef = + builder.getCurrentFunction().getGlobalParent() + # Types # ------------------------------------------------------------ @@ -181,6 +185,27 @@ proc function_t*(returnType: TypeRef, paramTypes: openArray[TypeRef]): TypeRef { proc createAttr*(ctx: ContextRef, name: openArray[char]): AttributeRef = ctx.toAttr(name.toAttrId()) +proc toTypes*[N: static int](v: array[N, ValueRef]): array[N, TypeRef] = + for i in 0 ..< v.len: + result[i] = v[i].getTypeOf() + +macro unpackParams*[N: static int]( + br: BuilderRef, + paramsTys: tuple[wrapped, src: array[N, TypeRef]]): untyped = + ## Unpack function parameters. + ## + ## The new function basic block MUST be setup before calling unpackParams. + ## + ## In the future we may automatically unwrap types. + + result = nnkPar.newTree() + for i in 0 ..< N: + result.add quote do: + # let tySrc = `paramsTys`.src[`i`] + # let tyCC = `paramsTys`.wrapped[`i`] + let fn = `br`.getCurrentFunction() + fn.getParam(uint32 `i`) + # Values # ------------------------------------------------------------ diff --git a/constantine/platforms/llvm/super_instructions.nim b/constantine/platforms/llvm/super_instructions.nim index 7cf1e7a7..db678e95 100644 --- a/constantine/platforms/llvm/super_instructions.nim +++ b/constantine/platforms/llvm/super_instructions.nim @@ -79,31 +79,131 @@ proc hi(bld: BuilderRef, val: ValueRef, baseTy: TypeRef, oversize: uint32, prefi return hi -proc addcarry*(bld: BuilderRef, a, b, carryIn: ValueRef): tuple[carryOut, r: ValueRef] = +const SectionName = "ctt.superinstructions" + +proc getInstrName(baseName: string, ty: TypeRef): string = + var w, v: int # Wordsize and vector size + if ty.getTypeKind() == tkInteger: + w = int ty.getIntTypeWidth() + v = 1 + elif ty.getTypeKind() == tkVector: + v = int ty.getVectorSize() + w = int ty.getElementType().getIntTypeWidth() + else: + doAssert false, "Invalid input type: " & $ty + + return baseName & + (if v != 1: "_v" & $v else: "_") & + "u" & $w + +template defSuperInstruction[N: static int]( + module: ModuleRef, baseName: string, + returnType: TypeRef, + paramTypes: array[N, TypeRef], + body: untyped) = + ## Boilerplate for super instruction definition + ## Creates a magic `llvmParams` variable to tuple-destructure + ## to access the inputs + ## and `br` for building the instructions + let ty = paramTypes[0] + let name = baseName.getInstrName(ty) + + let ctx = module.getContext() + let br {.inject.} = ctx.createBuilder() + defer: br.dispose() + + var fn = module.getFunction(cstring name) + if fn.pointer.isNil(): + let fnTy = function_t(returnType, paramTypes) + fn = module.addFunction(cstring name, fnTy) + let blck = ctx.appendBasicBlock(fn) + br.positionAtEnd(blck) + + let llvmParams {.inject.} = unpackParams(br, (paramTypes, paramTypes)) + template tagParameter(idx: int, attr: string) {.inject, used.} = + let a = asy.ctx.createAttr(cstring attr) + fn.addAttribute(cint idx, a) + body + + fn.setFnCallConv(Fast) + fn.setLinkage(linkInternal) + fn.setSection(SectionName) + fn.addAttribute(kAttrFnIndex, ctx.createAttr("alwaysinline")) + +proc def_addcarry*(ctx: ContextRef, m: ModuleRef, carryTy, wordTy: TypeRef) = + ## Define (carryOut, result) <- a+b+carryIn + + let retType = ctx.struct_t([carryTy, wordTy]) + let inType = [wordTy, wordTy, carryTy] + + m.defSuperInstruction("addcarry", retType, inType): + let (a, b, carryIn) = llvmParams + + let add = br.add(a, b, name = "a_plus_b") + let carry0 = br.icmp(kULT, add, b, name = "carry0") + let cIn = br.zext(carryIn, wordTy, name = "carryIn") + let adc = br.add(cIn, add, name = "a_plus_b_plus_cIn") + let carry1 = br.icmp(kULT, adc, add, name = "carry1") + let carryOut = br.`or`(carry0, carry1, name = "carryOut") + + var ret = br.insertValue(poison(retType), adc, 1, "lo") + ret = br.insertValue(ret, carryOut, 0, "ret") + br.ret(ret) + +proc addcarry*(br: BuilderRef, a, b, carryIn: ValueRef): tuple[carryOut, r: ValueRef] = ## (cOut, result) <- a+b+cIn let ty = a.getTypeOf() + let tyC = carryIn.getTypeOf() + let name = "addcarry".getInstrName(ty) - let add = bld.add(a, b, name = "adc01_") - let carry0 = bld.icmp(kULT, add, b, name = "adc01c_") - let cIn = bld.zext(carryIn, ty, name = "adc2_") - let adc = bld.add(cIn, add, name = "adc_") - let carry1 = bld.icmp(kULT, adc, add, name = "adc012c_") - let carryOut = bld.`or`(carry0, carry1, name = "cOut_") + let fn = br.getCurrentModule().getFunction(cstring name) + doAssert not fn.pointer.isNil, "Function '" & name & "' does not exist in the module\n" - return (carryOut, adc) + let retTy = br.getContext().struct_t([tyC, ty]) + let fnTy = function_t(retTy, [ty, ty, tyC]) + let adc = br.call2(fnTy, fn, [a, b, carryIn], name = "adc") + adc.setInstrCallConv(Fast) + let lo = br.extractValue(adc, 1, name = "adcLo") + let cOut = br.extractValue(adc, 0, name = "adcC") + return (cOut, lo) -proc subborrow*(bld: BuilderRef, a, b, borrowIn: ValueRef): tuple[borrowOut, r: ValueRef] = - ## (bOut, result) <- a-b-bIn - let ty = a.getTypeOf() - let sub = bld.sub(a, b, name = "sbb01_") - let borrow0 = bld.icmp(kULT, a, b, name = "sbb01b_") - let bIn = bld.zext(borrowIn, ty, name = "sbb2_") - let sbb = bld.sub(sub, bIn, name = "sbb_") - let borrow1 = bld.icmp(kULT, sub, bIn, name = "sbb012b_") - let borrowOut = bld.`or`(borrow0, borrow1, name = "bOut_") +proc def_subborrow*(ctx: ContextRef, m: ModuleRef, borrowTy, wordTy: TypeRef) = + ## Define (borrowOut, result) <- a-b-borrowIn + + let retType = ctx.struct_t([borrowTy, wordTy]) + let inType = [wordTy, wordTy, borrowTy] + + m.defSuperInstruction("subborrow", retType, inType): + let (a, b, borrowIn) = llvmParams + + let sub = br.sub(a, b, name = "a_minus_b") + let borrow0 = br.icmp(kULT, a, b, name = "borrow0") + let bIn = br.zext(borrowIn, wordTy, name = "borrowIn") + let sbb = br.sub(sub, bIn, name = "sbb") + let borrow1 = br.icmp(kULT, sub, bIn, name = "borrow1") + let borrowOut = br.`or`(borrow0, borrow1, name = "borrowOut") - return (borrowOut, sbb) + var ret = br.insertValue(poison(retType), sbb, 1, "lo") + ret = br.insertValue(ret, borrowOut, 0, "ret") + br.ret(ret) + +proc subborrow*(br: BuilderRef, a, b, borrowIn: ValueRef): tuple[borrowOut, r: ValueRef] = + ## (cOut, result) <- a+b+cIn + let ty = a.getTypeOf() + let tyC = borrowIn.getTypeOf() + let name = "subborrow".getInstrName(ty) + + let fn = br.getCurrentModule().getFunction(cstring name) + doAssert not fn.pointer.isNil, "Function '" & name & "' does not exist in the module\n" + + let retTy = br.getContext().struct_t([tyC, ty]) + let fnTy = function_t(retTy, [ty, ty, tyC]) + let sbb = br.call2(fnTy, fn, [a, b, borrowIn], name = "sbb") + sbb.setInstrCallConv(Fast) + let lo = br.extractValue(sbb, 1, name = "sbbLo") + let bOut = br.extractValue(sbb, 0, name = "sbbB") + return (bOut, lo) proc mulExt*(bld: BuilderRef, a, b: ValueRef): tuple[hi, lo: ValueRef] = ## Extended precision multiplication diff --git a/research/codegen/x86_poc.nim b/research/codegen/x86_poc.nim index 848944ee..a460a822 100644 --- a/research/codegen/x86_poc.nim +++ b/research/codegen/x86_poc.nim @@ -46,6 +46,8 @@ proc t_field_add() = F[0], F[1], F[2], v = 1, w = 64) + asy.definePrimitives(fd) + discard asy.genFpAdd(fd) echo "========================================="