Commit
Add polynomial/div_by_x_minus_z.cuh.
Showing 1 changed file with 366 additions and 0 deletions.
@@ -0,0 +1,366 @@
// Copyright Supranational LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

#if !defined(__SPPARK_POLYNOMIAL_DIVISION_CUH__) && defined(__CUDACC__)
#define __SPPARK_POLYNOMIAL_DIVISION_CUH__

#include <cassert>
#include <cooperative_groups.h>
#include <ff/shfl.cuh>

template<class fr_t, int BSZ> __global__ __launch_bounds__(BSZ)
void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
{
    struct my {
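        // Warp-level Horner-style scan: on return lane l holds
        // coeff[l] + coeff[l-1]*z_pow + coeff[l-2]*z_pow^2 + ...,
        // while |z_pow| is squared once per round on lane 0 and
        // broadcast, e.g. ending up as z_pow^WARP_SZ for the default
        // |limit|, ready to seed a follow-up scan at a coarser stride.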
        __device__ __forceinline__
        static void madd_up(fr_t& coeff, fr_t& z_pow, uint32_t limit = WARP_SZ)
        {
            const uint32_t laneid = threadIdx.x % WARP_SZ;

            __builtin_assume(limit > 1);

            #pragma unroll 1
            for (uint32_t off = 1; off < limit; off <<= 1) {
                auto temp = shfl_up(coeff, off);
                temp = fr_t::csel(temp, z_pow, laneid != 0);
                z_pow *= temp;  // 0th lane squares z_pow
                temp = coeff + z_pow;
                coeff = fr_t::csel(coeff, temp, laneid < off);
                z_pow = shfl_idx(z_pow, 0);
            }
            /* beware that resulting |z_pow| can be fed to the next madd_up() */
        }

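        // Warp-level prefix product: with |z_lane| = z on every lane,
        // lane l of the first |limit| lanes returns z^(l+1), i.e.
        // ascending powers of z across the warp.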
        __device__ __noinline__
        static fr_t mult_up(fr_t z_lane, uint32_t limit = WARP_SZ)
        {
            const uint32_t laneid = threadIdx.x % WARP_SZ;

            __builtin_assume(limit > 1);

            #pragma unroll 1
            for (uint32_t off = 1; off < limit; off <<= 1) {
                auto temp = shfl_up(z_lane, off);
                temp *= z_lane;
                z_lane = fr_t::csel(z_lane, temp, laneid < off);
            }

            return z_lane;
        }
    };

    assert(blockDim.x%WARP_SZ == 0 && gridDim.x <= blockDim.x);

    const uint32_t tid = threadIdx.x + blockDim.x*blockIdx.x;
    const uint32_t laneid = threadIdx.x % WARP_SZ;
    const uint32_t warpid = threadIdx.x / WARP_SZ;
    const uint32_t nwarps = blockDim.x / WARP_SZ;

    extern __shared__ int xchg_div_by_x_minus_z[];
    fr_t* xchg = reinterpret_cast<decltype(xchg)>(xchg_div_by_x_minus_z);

    /*
     * Calculate ascending powers of |z| in ascending threads across
     * the grid. Since the kernel is invoked cooperatively, gridDim.x
     * is not larger than the number of SMs, which is far below the
     * limit for this part of the implementation, 33*32+1. ["This
     * part" refers to the fact that a stricter limitation is implied
     * elsewhere, gridDim.x <= blockDim.x.]
     */
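    // E.g. with WARP_SZ = 32: after the first mult_up() each lane holds
    // z^(laneid+1); lane 31's z^32 then seeds a scan over the warps, so
    // warp w can pick up z^(32*w) and form z_pow_block = z^(threadIdx.x+1);
    // the grid-level step below extends this to z^(tid+1) per thread.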
    fr_t z_pow = z;
    z_pow = my::mult_up(z_pow);
    fr_t z_pow_warp = z_pow;        // z^(laneid+1)

    fr_t z_pow_block = z_pow_warp;  // z^(threadIdx.x+1)
    z_pow = shfl_idx(z_pow, WARP_SZ-1);
    z_pow = my::mult_up(z_pow, nwarps);
    if (warpid != 0) {
        z_pow_block = shfl_idx(z_pow, warpid - 1);
        z_pow_block *= z_pow_warp;
    }
    fr_t z_top_block = shfl_idx(z_pow, nwarps - 1);

    fr_t z_pow_grid = z_pow_block;  // z^(blockDim.x*blockIdx.x+threadIdx.x+1)
    if (blockIdx.x != 0) {
        z_pow = z_top_block;
        z_pow = my::mult_up(z_pow, min(WARP_SZ, gridDim.x));
        z_pow_grid = shfl_idx(z_pow, (blockIdx.x - 1)%WARP_SZ);
        if (blockIdx.x > WARP_SZ) {
            z_pow = shfl_idx(z_pow, WARP_SZ - 1);
            z_pow = my::mult_up(z_pow, (gridDim.x + WARP_SZ - 1)/WARP_SZ);
            z_pow = shfl_idx(z_pow, (blockIdx.x - 1)/WARP_SZ - 1);
            z_pow_grid *= z_pow;
        }
        z_pow_grid *= z_pow_block;
    }

    // Calculate z_top_block^(laneid+1) and offload it to the shared
    // memory to alleviate register pressure.
    fr_t& z_pow_carry = xchg[max(blockDim.x/WARP_SZ, gridDim.x) + laneid];
    if (gridDim.x > WARP_SZ && warpid == 0)
        z_pow_carry = my::mult_up(z_pow = z_top_block);

#if 0
    auto check = z^(tid+1);
    check -= z_pow_grid;
    assert(check.is_zero());
#endif

    /*
     * Given ∑cᵢ⋅xⁱ the goal is to sum up columns as follows
     *
     * cf ce cd cc cb ca c9 c8 c7 c6 c5 c4 c3 c2 c1 c0
     *    cf ce cd cc cb ca c9 c8 c7 c6 c5 c4 c3 c2 c1 * z
     *       cf ce cd cc cb ca c9 c8 c7 c6 c5 c4 c3 c2 * z^2
     *          cf ce cd cc cb ca c9 c8 c7 c6 c5 c4 c3 * z^3
     *             cf ce cd cc cb ca c9 c8 c7 c6 c5 c4 * z^4
     *                cf ce cd cc cb ca c9 c8 c7 c6 c5 * z^5
     *                   cf ce cd cc cb ca c9 c8 c7 c6 * z^6
     *                      cf ce cd cc cb ca c9 c8 c7 * z^7
     *                         cf ce cd cc cb ca c9 c8 * z^8
     *                            cf ce cd cc cb ca c9 * z^9
     *                               cf ce cd cc cb ca * z^10
     *                                  cf ce cd cc cb * z^11
     *                                     cf ce cd cc * z^12
     *                                        cf ce cd * z^13
     *                                           cf ce * z^14
     *                                              cf * z^15
     *
     * The first element of the output is the remainder and
     * the rest is the quotient.
     */
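    /*
     * In other words, this is a parallel form of classic synthetic
     * division. A minimal sequential equivalent over the reversed view
     * |inout| below (leading coefficient first) would be
     *
     *     fr_t acc = inout[0];
     *     for (size_t i = 1; i < len; i++)
     *         inout[i] = acc = acc*z + inout[i];
     *
     * |rev_ptr_t| provides that reversed view of |d_inout|.
     */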
    class rev_ptr_t {
        fr_t* p;
    public:
        __device__ rev_ptr_t(fr_t* ptr, size_t len) : p(ptr + len - 1) {}
        __device__ fr_t& operator[](size_t i)             { return *(p - i); }
        __device__ const fr_t& operator[](size_t i) const { return *(p - i); }
    };
    rev_ptr_t inout{d_inout, len};
    fr_t coeff, carry_over;
    auto __grid = cooperative_groups::this_grid();

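    // Each iteration reduces a blockDim.x*gridDim.x-sized chunk of
    // coefficients, highest degrees first: a warp-level scan, then a
    // block-level scan via shared memory, then (for multi-block grids)
    // a grid-level scan via global memory, and finally the carry-over
    // from the previous chunk scaled by z^(tid+1).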
    for (size_t chunk = 0; chunk < len; chunk += blockDim.x*gridDim.x) {
        size_t idx = chunk + tid;

        if (sizeof(fr_t) <= 32) {
            if (idx < len)
                coeff = inout[idx];

            my::madd_up(coeff, z_pow = z);

            if (laneid == WARP_SZ-1)
                xchg[warpid] = coeff;

            __syncthreads();

            carry_over = xchg[laneid];

            my::madd_up(carry_over, z_pow, nwarps);

            if (warpid != 0) {
                carry_over = shfl_idx(carry_over, warpid - 1);
                carry_over *= z_pow_warp;
                coeff += carry_over;
            }

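            // Grid-level phase: the last thread of each block publishes
            // the block-wide total to global memory; after the grid-wide
            // sync every non-leading block scans the totals of the blocks
            // before it and folds the result in, scaled by z_pow_block.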
            if (gridDim.x > 1) {
                size_t grid_idx = chunk + blockIdx.x*blockDim.x;
                if (threadIdx.x == blockDim.x-1 && grid_idx < len)
                    inout[grid_idx] = coeff;

                __grid.sync();
                __syncthreads();

                if (blockIdx.x != 0) {
                    grid_idx = chunk + threadIdx.x*blockDim.x;
                    if (threadIdx.x < gridDim.x && grid_idx < len)
                        carry_over = inout[grid_idx];

                    my::madd_up(carry_over, z_pow = z_top_block,
                                min(WARP_SZ, gridDim.x));

                    if (gridDim.x > WARP_SZ) {
                        if (laneid == WARP_SZ-1)
                            xchg[warpid] = carry_over;

                        __syncthreads();

                        fr_t temp = xchg[laneid];

                        my::madd_up(temp, z_pow,
                                    (gridDim.x + WARP_SZ - 1)/WARP_SZ);

                        if (warpid != 0) {
                            temp = shfl_idx(temp, warpid - 1);
                            temp *= (z_pow = z_pow_carry);
                            carry_over += temp;
                        }
                    }

                    if (threadIdx.x < gridDim.x)
                        xchg[threadIdx.x] = carry_over;

                    __syncthreads();

                    carry_over = xchg[blockIdx.x-1];
                    carry_over *= z_pow_block;
                    coeff += carry_over;
                }
            }

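            // inout[chunk - 1] already holds the fully reduced running
            // value of all higher-degree chunks; propagate it into this
            // chunk with the per-thread power z^(tid+1).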
            if (chunk != 0) {
                carry_over = inout[chunk - 1];
                carry_over *= z_pow_grid;
                coeff += carry_over;
            }
        } else {    // ~14KB loop size with 256-bit field, yet unused...
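            // Same pipeline as above, but restructured as a small state
            // machine (|pc|) around a single madd_up() call site, so that
            // the bulky wide-field arithmetic is emitted only once.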
            fr_t acc, z_pow_adjust;

            if (idx < len)
                acc = inout[idx];

            z_pow = z;
            uint32_t limit = WARP_SZ;
            uint32_t adjust = 0;
            int pc = -1;

            do {
                my::madd_up(acc, z_pow, limit);

                if (adjust != 0) {
                    acc = shfl_idx(acc, adjust - 1);
            tail_mul:
                    acc *= z_pow_adjust;
                    coeff += acc;
                }

                switch (++pc) {
                case 0:
                    coeff = acc;

                    if (laneid == WARP_SZ-1)
                        xchg[warpid] = acc;

                    __syncthreads();

                    acc = xchg[laneid];

                    limit = nwarps;
                    adjust = warpid;
                    z_pow_adjust = z_pow_warp;
                    break;
                case 1:
                    if (gridDim.x > 1) {
                        size_t xchg_idx = chunk + blockIdx.x*blockDim.x;
                        if (threadIdx.x == blockDim.x-1 && xchg_idx < len)
                            inout[xchg_idx] = coeff;

                        __grid.sync();
                        __syncthreads();

                        if (blockIdx.x != 0) {
                            xchg_idx = chunk + threadIdx.x*blockDim.x;
                            if (threadIdx.x < gridDim.x && xchg_idx < len)
                                acc = inout[xchg_idx];

                            z_pow = z_top_block;
                            limit = min(WARP_SZ, gridDim.x);
                            adjust = 0;
                        } else {
                            goto final;
                        }
                    } else {
                        goto final;
                    }
                    break;
                case 2:     // blockIdx.x != 0
                    carry_over = coeff;
                    coeff = acc;

                    if (gridDim.x > WARP_SZ) {
                        if (laneid == WARP_SZ-1)
                            xchg[warpid] = acc;

                        __syncthreads();

                        acc = xchg[laneid];

                        limit = (gridDim.x + WARP_SZ - 1)/WARP_SZ;
                        adjust = warpid;
                        z_pow_adjust = z_pow_carry;
                        break;
                    }   // else fall through
                case 3:     // blockIdx.x != 0
                    if (threadIdx.x < gridDim.x)
                        xchg[threadIdx.x] = coeff;

                    __syncthreads();

                    coeff = carry_over;
                    acc = xchg[blockIdx.x-1];
                    z_pow_adjust = z_pow_block;
                    pc = 3;
                    goto tail_mul;
                case 4:
            final:
                    if (chunk == 0) {
                        pc = -1;
                        break;
                    }

                    acc = inout[chunk - 1];
                    z_pow_adjust = z_pow_grid;
                    pc = 4;
                    goto tail_mul;
                default:
                    pc = -1;
                    break;
                }
            } while (pc >= 0);
        }

        if (gridDim.x > 1) {
            __grid.sync();
            __syncthreads();
        }

        if (idx < len)
            inout[idx] = coeff;
    }
}

template<class fr_t, class stream_t>
void div_by_x_minus_z(fr_t d_inout[], size_t len, const fr_t& z,
                      const stream_t& s)
{
    constexpr int BSZ = sizeof(fr_t) <= 16 ? 1024 : 0;

    int gridDim = s.sm_count();
    int blockDim = BSZ;

    if (BSZ == 0) {
        cudaFuncAttributes attr;
        CUDA_OK(cudaFuncGetAttributes(&attr, d_div_by_x_minus_z<fr_t, BSZ>));
        blockDim = attr.maxThreadsPerBlock;
    }

    if (gridDim > blockDim)     // there are no such large GPUs, not for now...
        gridDim = blockDim;

    size_t blocks = (len + blockDim - 1)/blockDim;

    if ((unsigned)gridDim > blocks)
        gridDim = (int)blocks;

    if (gridDim < 3)
        gridDim = 1;

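    // One fr_t slot per warp (or per block total on multi-block grids)
    // for the scan exchanges, plus WARP_SZ slots for |z_pow_carry|.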
    size_t sharedSz = sizeof(fr_t) * max(blockDim/WARP_SZ, gridDim);
    sharedSz += sizeof(fr_t) * WARP_SZ;

    s.launch_coop(d_div_by_x_minus_z<fr_t, BSZ>, {gridDim, blockDim, sharedSz},
                  d_inout, len, z);
}
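/*
 * Usage sketch (assuming sppark's util/gpu_t.cuh, whose gpu_t is meant
 * to satisfy the stream_t interface above with sm_count() and
 * launch_coop(); names are illustrative):
 *
 *     #include <util/gpu_t.cuh>
 *
 *     gpu_t& gpu = select_gpu();
 *     dev_ptr_t<fr_t> d_coeffs{len};          // device buffer
 *     gpu.HtoD(&d_coeffs[0], coeffs, len);    // constant term first
 *     div_by_x_minus_z(&d_coeffs[0], len, z, gpu);
 *     gpu.DtoH(coeffs, &d_coeffs[0], len);    // remainder in coeffs[0]
 */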
#endif