
Commit

update remaining shared memory uses
jcosborn committed Aug 16, 2023
1 parent 3fee9f9 commit 72cc52b
Showing 5 changed files with 63 additions and 23 deletions.
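
The kernel changes below follow one pattern: shared-memory uses that previously declared raw __shared__ arrays, or passed a runtime dim3 to the SharedMemoryCache constructor, now describe the block shape through a compile-time "dims" policy type supplied as a second template parameter to SharedMemoryCache. The two target-header changes (helpers.h and shared_memory_cache_helper.h) add the supporting DimsStatic policy and the load-guard that this pattern needs. A minimal sketch of the pattern, with an illustrative policy rather than any of the actual QUDA definitions:

// Hypothetical policy: any type with a static constexpr dims(dim3) member
// that maps the launch block shape onto the shared-memory tile shape.
struct PadX {
  static constexpr dim3 dims(dim3 block)
  {
    block.x += 1; // pad x, e.g. to reduce bank conflicts
    return block;
  }
};

// before: tile shape computed at run time and passed to the constructor
//   SharedMemoryCache<real> cache({target::block_dim().x + 1, target::block_dim().y, 1});
// after: tile shape derived at compile time from the policy
//   SharedMemoryCache<real, PadX> cache;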
11 changes: 10 additions & 1 deletion include/kernels/block_transpose.cuh
@@ -47,6 +47,14 @@ namespace quda
constexpr BlockTransposeKernel(const Arg &arg) : arg(arg) { }
static constexpr const char *filename() { return KERNEL_FILE; }

struct Dims {
static constexpr dim3 dims(dim3 block) {
block.x += 1;
block.z = 1;
return block;
}
};

/**
@brief Transpose between the two different orders of batched colorspinor fields:
- B: nVec -> spatial/N -> spin/color -> N, where N is for that in floatN
@@ -60,7 +68,8 @@ namespace quda
int parity = parity_color / Arg::nColor;
using color_spinor_t = ColorSpinor<typename Arg::real, 1, Arg::nSpin>;

SharedMemoryCache<color_spinor_t> cache({target::block_dim().x + 1, target::block_dim().y, 1});
//SharedMemoryCache<color_spinor_t> cache({target::block_dim().x + 1, target::block_dim().y, 1});
SharedMemoryCache<color_spinor_t, Dims> cache;

int x_offset = target::block_dim().x * target::block_idx().x;
int v_offset = target::block_dim().y * target::block_idx().y;
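For block_transpose.cuh the new Dims policy reproduces the shape previously passed to the cache constructor: one extra element in x (a common padding to reduce shared-memory bank conflicts in transposes) and a single z slice. A quick illustrative check of the policy, using an assumed 32x8x2 launch block rather than anything taken from the kernel:

// Dims::dims pads x by one and collapses z, so a 32x8x2 block maps to 33x8x1.
dim3 tile = Dims::dims(dim3(32, 8, 2)); // tile.x == 33, tile.y == 8, tile.z == 1
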
18 changes: 14 additions & 4 deletions include/kernels/coarse_op_kernel.cuh
@@ -10,6 +10,7 @@
#include <matrix_tile.cuh>
#include <target_device.h>
#include <kernel.h>
#include <shared_memory_cache_helper.h>

namespace quda {

@@ -1387,14 +1388,21 @@ namespace quda {
};

template <> struct storeCoarseSharedAtomic_impl<true> {
template <typename Arg> using CacheT =
complex<storeType>[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4][Arg::coarseSpin][Arg::coarseSpin];
template <typename Arg> using Cache = SharedMemoryCache<CacheT<Arg>,DimsStatic<2,1,1>>;

template <typename VUV, typename Pack, typename Arg>
inline __device__ void operator()(VUV &vuv, bool isDiagonal, int coarse_x_cb, int coarse_parity, int i0, int j0, int parity, const Pack &pack, const Arg &arg)
{
using real = typename Arg::Float;
using TileType = typename Arg::vuvTileType;
const int dim_index = arg.dim_index % arg.Y_atomic.geometry;
__shared__ complex<storeType> X[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4][Arg::coarseSpin][Arg::coarseSpin];
__shared__ complex<storeType> Y[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4][Arg::coarseSpin][Arg::coarseSpin];
//__shared__ complex<storeType> X[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4][Arg::coarseSpin][Arg::coarseSpin];
//__shared__ complex<storeType> Y[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4][Arg::coarseSpin][Arg::coarseSpin];
Cache<Arg> cache;
auto &X = cache.data()[0];
auto &Y = cache.data()[1];

int x_ = coarse_x_cb % arg.aggregates_per_block;
int tx = virtualThreadIdx(arg);
@@ -1416,7 +1424,8 @@ namespace quda {
}
}

__syncthreads();
//__syncthreads();
cache.sync();

#pragma unroll
for (int i = 0; i < TileType::M; i++) {
@@ -1445,7 +1454,8 @@ namespace quda {
}
}

__syncthreads();
//__syncthreads();
cache.sync();

if (tx < Arg::coarseSpin*Arg::coarseSpin && (parity == 0 || arg.parity_flip == 1) ) {

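In coarse_op_kernel.cuh the two statically sized __shared__ arrays X and Y become two slots of a single SharedMemoryCache: the element type CacheT<Arg> is the whole multi-dimensional complex array, and DimsStatic<2,1,1> fixes the cache to exactly two such elements, so cache.data()[0] and cache.data()[1] take over the roles of X and Y, while cache.sync() replaces the bare __syncthreads() calls. A reduced sketch of the idea, with placeholder element type and extents rather than the QUDA ones:

// Placeholder stand-ins for the kernel's parameters.
struct cplx { float re, im; };
constexpr int nSpin = 2, nColor = 4;

// One cache element is an entire accumulation array (a C array type),
// so a cache shaped 2x1x1 holds exactly two of them.
using CacheT = cplx[nColor][nColor][4][nSpin][nSpin];

// Cache cache;                // SharedMemoryCache<CacheT, DimsStatic<2,1,1>>
// auto &X = cache.data()[0];  // replaces __shared__ ... X[...]
// auto &Y = cache.data()[1];  // replaces __shared__ ... Y[...]
// ... accumulate into X and Y ...
// cache.sync();               // replaces __syncthreads()

Because the element type is a C array, it cannot be returned by value from load(); the maybeT guard added to shared_memory_cache_helper.h (the last file below) keeps those unused load paths from being instantiated.
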
29 changes: 17 additions & 12 deletions include/kernels/color_spinor_pack.cuh
@@ -171,27 +171,32 @@ namespace quda {
}
};

template <bool is_native>
struct DimsPadX {
static constexpr dim3 dims(dim3 block) {
if (is_native) block.x = ((block.x + device::warp_size() - 1) / device::warp_size()) * device::warp_size();
return block;
}
};

template <> struct site_max<true> {
template <typename Arg>
struct DimsPadX {
static constexpr int Ms = spins_per_thread<true>(Arg::nSpin);
static constexpr int Mc = colors_per_thread<true>(Arg::nColor);
static constexpr int color_spin_threads = (Arg::nSpin/Ms) * (Arg::nColor/Mc);
static constexpr dim3 dims(dim3 block) {
if (Arg::is_native) block.x = ((block.x + device::warp_size() - 1) / device::warp_size()) * device::warp_size();
block.y = color_spin_threads; // state the y block since we know it at compile time
return block;
}
};

template <typename Arg> __device__ inline auto operator()(typename Arg::real thread_max, Arg &)
{
using real = typename Arg::real;
constexpr int Ms = spins_per_thread<true>(Arg::nSpin);
constexpr int Mc = colors_per_thread<true>(Arg::nColor);
constexpr int color_spin_threads = (Arg::nSpin/Ms) * (Arg::nColor/Mc);
//constexpr int Ms = spins_per_thread<true>(Arg::nSpin);
//constexpr int Mc = colors_per_thread<true>(Arg::nColor);
//constexpr int color_spin_threads = (Arg::nSpin/Ms) * (Arg::nColor/Mc);
constexpr int color_spin_threads = DimsPadX<Arg>::color_spin_threads;
//auto block = target::block_dim();
// pad the shared block size to avoid bank conflicts for native ordering
//if (Arg::is_native) block.x = ((block.x + device::warp_size() - 1) / device::warp_size()) * device::warp_size();
//block.y = color_spin_threads; // state the y block since we know it at compile time
//SharedMemoryCache<real> cache(block);
SharedMemoryCache<real, DimsPadX<Arg::is_native>> cache;
SharedMemoryCache<real, DimsPadX<Arg>> cache;
cache.save(thread_max);
cache.sync();
real this_site_max = static_cast<real>(0);
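In color_spinor_pack.cuh the DimsPadX policy is now templated on Arg rather than just a bool, so in addition to rounding block.x up to a multiple of the warp size for native field order it can fix block.y to the compile-time color_spin_threads count, replacing the run-time block computation previously done inside operator(). A small sketch of the x-padding arithmetic, using assumed example values for the warp size and block width:

// Round x up to a multiple of the warp size (helps avoid bank conflicts
// with native field ordering). With warp_size = 32 and block.x = 37:
constexpr unsigned int warp_size = 32;                               // assumed value
unsigned int x = 37;                                                 // assumed launch block.x
unsigned int padded = ((x + warp_size - 1) / warp_size) * warp_size; // padded == 64
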
7 changes: 7 additions & 0 deletions include/targets/generic/helpers.h
@@ -31,6 +31,13 @@ namespace quda
}
};

template <int x, int y, int z>
struct DimsStatic {
static constexpr dim3 dims(dim3 block) {
return dim3(x,y,z);
}
};

/**
@brief Uniform helper for exposing type T, whether we are dealing
with an instance of T or some wrapper of T
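The new DimsStatic helper in helpers.h ignores the launch block entirely and returns a fixed compile-time shape; it is what the coarse_op_kernel.cuh change above uses to request exactly two cache elements. A minimal illustrative use (the input block here is an arbitrary assumed value):

// DimsStatic<2,1,1>::dims(...) returns {2, 1, 1} regardless of the launch block,
// so SharedMemoryCache<T, DimsStatic<2,1,1>> allocates exactly 2 elements of T.
dim3 fixed = DimsStatic<2, 1, 1>::dims(dim3(64, 4, 2)); // fixed == {2, 1, 1}
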
21 changes: 15 additions & 6 deletions include/targets/generic/shared_memory_cache_helper.h
@@ -56,6 +56,9 @@ namespace quda
// The number of elements of type atom_t that we break T into for optimal shared-memory access
static constexpr int n_element = sizeof(T) / sizeof(atom_t);

// used to avoid instantiation of load functions if unused, in case T is not a valid return type (e.g. C array)
template <typename dummy = void> using maybeT = std::conditional_t<std::is_same_v<dummy,void>,T,void>;

const dim3 block;
const int stride;

@@ -70,7 +73,8 @@ namespace quda
for (int i = 0; i < n_element; i++) smem()[i * stride + j] = tmp[i];
}

__device__ __host__ inline T load_detail(int x, int y, int z) const
template <typename dummy = void>
__device__ __host__ inline maybeT<dummy> load_detail(int x, int y, int z) const
{
atom_t tmp[n_element];
int j = (z * block.y + y) * block.x + x;
@@ -182,7 +186,8 @@ namespace quda
@param[in] z The z index to use
@return The value at coordinates (x,y,z)
*/
__device__ __host__ inline T load(int x = -1, int y = -1, int z = -1) const
template <typename dummy = void>
__device__ __host__ inline maybeT<dummy> load(int x = -1, int y = -1, int z = -1) const
{
auto tid = target::thread_idx();
x = (x == -1) ? tid.x : x;
@@ -196,7 +201,8 @@ namespace quda
@param[in] x The x index to use
@return The value at coordinates (x,y,z)
*/
__device__ __host__ inline T load_x(int x = -1) const
template <typename dummy = void>
__device__ __host__ inline maybeT<dummy> load_x(int x = -1) const
{
auto tid = target::thread_idx();
x = (x == -1) ? tid.x : x;
@@ -208,7 +214,8 @@ namespace quda
@param[in] y The y index to use
@return The value at coordinates (x,y,z)
*/
__device__ __host__ inline T load_y(int y = -1) const
template <typename dummy = void>
__device__ __host__ inline maybeT<dummy> load_y(int y = -1) const
{
auto tid = target::thread_idx();
y = (y == -1) ? tid.y : y;
@@ -220,7 +227,8 @@ namespace quda
@param[in] z The z index to use
@return The value at coordinates (x,y,z)
*/
__device__ __host__ inline T load_z(int z = -1) const
template <typename dummy = void>
__device__ __host__ inline maybeT<dummy> load_z(int z = -1) const
{
auto tid = target::thread_idx();
z = (z == -1) ? tid.z : z;
@@ -236,7 +244,8 @@ namespace quda
@brief Cast operator to allow cache objects to be used where T
is expected
*/
__device__ __host__ operator T() const { return load(); }
template <typename dummy = void>
__device__ __host__ operator maybeT<dummy>() const { return load(); }

/**
@brief Assignment operator to allow cache objects to be used on
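The maybeT alias in shared_memory_cache_helper.h makes the return type of load(), its variants, and the conversion operator depend on a dummy template parameter, so those members are only checked when they are actually instantiated. That matters for the coarse_op_kernel.cuh change above, where the cached type is a C array and could not legally be returned by value; as long as such a cache is accessed only through data(), save() and sync(), the problematic signatures are never generated. A stripped-down sketch of the same technique outside QUDA, using a toy wrapper type:

#include <type_traits>

template <typename T> struct Wrapper {
  T *buf;

  // Dependent alias: resolves to T only when the member using it is instantiated.
  template <typename dummy = void>
  using maybeT = std::conditional_t<std::is_same_v<dummy, void>, T, void>;

  template <typename dummy = void>
  maybeT<dummy> load() const { return buf[0]; } // by-value return of T

  T *data() const { return buf; } // always usable, even when T is an array type
};

// With T = float[4], Wrapper<float[4]>::data() compiles fine and load() is simply
// never instantiated, so a function returning float[4] by value is never required.
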
