Skip to content

Commit

Permalink
fix overlapping shared mem
Browse files Browse the repository at this point in the history
  • Loading branch information
jcosborn committed Sep 7, 2023
1 parent 12c5c98 commit 5ce1230
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 2 deletions.
3 changes: 2 additions & 1 deletion include/kernels/gauge_stout.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ namespace quda
}

Link U, Q;
ThreadLocalCache<Link> Stap;
//ThreadLocalCache<Link> Stap;
ThreadLocalCache<Link,0,computeStapleRectangleOps> Stap;
ThreadLocalCache<Link,0,decltype(Stap)> Rect; // offset by Stap type to ensure non-overlapping allocations

// This function gets stap = S_{mu,nu} i.e., the staple of length 3,
Expand Down
2 changes: 2 additions & 0 deletions include/kernels/gauge_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace quda
// matrix+matrix = 18 floating-point ops
// => Total number of floating point ops per function call
// dims * (2*18 + 4*198) = dims*828
//
// computeStapleOps: alias for thread_array<int, 4>.
// NOTE(review): presumably names the thread_array storage associated with
// computeStaple below, analogous to computeStapleRectangleOps — confirm against callers.
using computeStapleOps = thread_array<int, 4>;
template <typename Arg, typename Staple, typename Int>
__host__ __device__ inline void computeStaple(const Arg &arg, const int *x, const Int *X, const int parity, const int nu, Staple &staple, const int dir_ignore)
{
Expand Down Expand Up @@ -94,6 +95,7 @@ namespace quda
// matrix+matrix = 18 floating-point ops
// => Total number of floating point ops per function call
// dims * (8*18 + 28*198) = dims*5688
//
// computeStapleRectangleOps: alias for thread_array<int, 4>.
// Callers pass this type as the third template argument of ThreadLocalCache
// (e.g. ThreadLocalCache<Link,0,computeStapleRectangleOps> in gauge_stout.cuh
// and gauge_wilson_flow.cuh).
// NOTE(review): presumably this lets the cache be offset past the storage used
// by computeStapleRectangle so the allocations do not overlap — confirm.
using computeStapleRectangleOps = thread_array<int, 4>;
template <typename Arg, typename Staple, typename Rectangle, typename Int>
__host__ __device__ inline void computeStapleRectangle(const Arg &arg, const int *x, const Int *X, const int parity, const int nu,
Staple &staple, Rectangle &rectangle, const int dir_ignore)
Expand Down
3 changes: 2 additions & 1 deletion include/kernels/gauge_wilson_flow.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ namespace quda
// This function gets stap = S_{mu,nu} i.e., the staple of length 3,
// and the 1x2 and 2x1 rectangles of length 5. From the following paper:
// https://arxiv.org/abs/0801.1165
ThreadLocalCache<Link> Stap;
//ThreadLocalCache<Link> Stap;
ThreadLocalCache<Link,0,computeStapleRectangleOps> Stap;
ThreadLocalCache<Link,0,decltype(Stap)> Rect; // offset by Stap type to ensure non-overlapping allocations
computeStapleRectangle(arg, x, arg.E, parity, dir, Stap, Rect, Arg::wflow_dim);
Z = arg.coeff1x1 * static_cast<const Link &>(Stap) + arg.coeff2x1 * static_cast<const Link &>(Rect);
Expand Down
1 change: 1 addition & 0 deletions include/targets/cuda/thread_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
namespace quda
{
/**
   @brief CUDA-target thread_array: a thin wrapper that derives from
   array<T, n> and adds only a shared-memory size query.
   @tparam T Element type
   @tparam n Number of elements
 */
template <typename T, int n> struct thread_array : array<T, n> {
  /**
     @brief Report the shared-memory size requested by this type.
     The block-dimension argument is accepted but not used; the result
     is always zero for this target.
     NOTE(review): presumably zero because this variant does not place
     its storage in shared memory — confirm.
     @return 0
   */
  static constexpr unsigned int shared_mem_size(dim3) { return 0u; }
};
} // namespace quda

Expand Down
2 changes: 2 additions & 0 deletions include/targets/generic/thread_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ namespace quda
array<T, n> &array_;

public:
using Smem::shared_mem_size;

__device__ __host__ constexpr thread_array() : array_(sharedMem()[target::thread_idx_linear<3>()])
{
array_ = array<T, n>(); // call default constructor
Expand Down

0 comments on commit 5ce1230

Please sign in to comment.