diff --git a/include/targets/hip/shared_memory_helper.h b/include/targets/hip/shared_memory_helper.h index bd8d919359..21a61cca45 100644 --- a/include/targets/hip/shared_memory_helper.h +++ b/include/targets/hip/shared_memory_helper.h @@ -14,10 +14,10 @@ namespace quda /** @brief Class which is used to allocate and access shared memory. The shared memory is treated as an array of type T, with the - number of elements given by the static member S::size(). The - offset from the beginning of the total shared memory block is - given by the static member O::shared_mem_size(block), or 0 if O - is void. + number of elements given by the call to the static member + S::size(target::block_dim()). The offset from the beginning of + the total shared memory block is given by the static member + O::shared_mem_size(target::block_dim()), or 0 if O is void. */ template class SharedMemory { @@ -56,6 +56,7 @@ namespace quda return target::dispatch(offset); } + public: static constexpr unsigned int get_offset(dim3 block) { unsigned int o = 0; @@ -63,16 +64,16 @@ namespace quda return o; } - public: static constexpr unsigned int shared_mem_size(dim3 block) { - return get_offset(block) + S::size()*sizeof(T); + return get_offset(block) + S::size(block)*sizeof(T); } /** @brief Constructor for SharedMemory object. */ - constexpr SharedMemory() : data(cache(get_offset(target::block_dim()))), size(S::size()) {} + constexpr SharedMemory() : data(cache(get_offset(target::block_dim()))), + size(S::size(target::block_dim())) {} /** @brief Subscripting operator returning a reference to element.