diff --git a/include/kernels/gauge_stout.cuh b/include/kernels/gauge_stout.cuh
index 4577e66fcd..712f191f83 100644
--- a/include/kernels/gauge_stout.cuh
+++ b/include/kernels/gauge_stout.cuh
@@ -135,7 +135,8 @@ namespace quda
}
Link U, Q;
- ThreadLocalCache Stap;
+ //ThreadLocalCache Stap;
+ ThreadLocalCache Stap;
ThreadLocalCache Rect; // offset by Stap type to ensure non-overlapping allocations
// This function gets stap = S_{mu,nu} i.e., the staple of length 3,
diff --git a/include/kernels/gauge_utils.cuh b/include/kernels/gauge_utils.cuh
index 48c7e6c1cc..ded8c9377a 100644
--- a/include/kernels/gauge_utils.cuh
+++ b/include/kernels/gauge_utils.cuh
@@ -19,6 +19,7 @@ namespace quda
// matrix+matrix = 18 floating-point ops
// => Total number of floating point ops per function call
// dims * (2*18 + 4*198) = dims*828
+ using computeStapleOps = thread_array;
template
__host__ __device__ inline void computeStaple(const Arg &arg, const int *x, const Int *X, const int parity, const int nu, Staple &staple, const int dir_ignore)
{
@@ -94,6 +95,7 @@ namespace quda
// matrix+matrix = 18 floating-point ops
// => Total number of floating point ops per function call
// dims * (8*18 + 28*198) = dims*5688
+ using computeStapleRectangleOps = thread_array;
template
__host__ __device__ inline void computeStapleRectangle(const Arg &arg, const int *x, const Int *X, const int parity, const int nu,
Staple &staple, Rectangle &rectangle, const int dir_ignore)
diff --git a/include/kernels/gauge_wilson_flow.cuh b/include/kernels/gauge_wilson_flow.cuh
index ae28956112..22864ce1b0 100644
--- a/include/kernels/gauge_wilson_flow.cuh
+++ b/include/kernels/gauge_wilson_flow.cuh
@@ -72,7 +72,8 @@ namespace quda
// This function gets stap = S_{mu,nu} i.e., the staple of length 3,
// and the 1x2 and 2x1 rectangles of length 5. From the following paper:
// https://arxiv.org/abs/0801.1165
- ThreadLocalCache Stap;
+ //ThreadLocalCache Stap;
+ ThreadLocalCache Stap;
ThreadLocalCache Rect; // offset by Stap type to ensure non-overlapping allocations
computeStapleRectangle(arg, x, arg.E, parity, dir, Stap, Rect, Arg::wflow_dim);
Z = arg.coeff1x1 * static_cast(Stap) + arg.coeff2x1 * static_cast(Rect);
diff --git a/include/targets/cuda/thread_array.h b/include/targets/cuda/thread_array.h
index 8237fcb87d..1c4d7f3244 100644
--- a/include/targets/cuda/thread_array.h
+++ b/include/targets/cuda/thread_array.h
@@ -11,6 +11,7 @@
namespace quda
{
template struct thread_array : array {
+ static constexpr unsigned int shared_mem_size(dim3 block) { return 0; }
};
} // namespace quda
diff --git a/include/targets/generic/thread_array.h b/include/targets/generic/thread_array.h
index 0e641a11df..d513394cfc 100644
--- a/include/targets/generic/thread_array.h
+++ b/include/targets/generic/thread_array.h
@@ -20,6 +20,8 @@ namespace quda
array &array_;
public:
+ using Smem::shared_mem_size;
+
__device__ __host__ constexpr thread_array() : array_(sharedMem()[target::thread_idx_linear<3>()])
{
array_ = array(); // call default constructor