diff --git a/tc/core/polyhedral/cuda/codegen.cc b/tc/core/polyhedral/cuda/codegen.cc
index 230623eff..452b1460a 100644
--- a/tc/core/polyhedral/cuda/codegen.cc
+++ b/tc/core/polyhedral/cuda/codegen.cc
@@ -153,9 +153,17 @@ void emitArgs(stringstream& ss, const Scop& scop) {
 void emitKernelSignature(
     stringstream& ss,
     const std::string& specializedName,
-    const Scop& scop) {
+    const Scop& scop,
+    const Block& block) {
   TC_CHECK_NE(specializedName, "") << "name not provided";
-  ss << "__global__ void " << specializedName << "(";
+  // An unused mapped dimension is reported as 0; clamp each extent to at
+  // least 1 so the product passed to __launch_bounds__ is the real number
+  // of threads per block (never 0, which would be ill-formed CUDA).
+  auto b0 = block.view[0];
+  b0 = b0 == 0 ? 1 : b0;
+  auto b1 = block.view[1];
+  b1 = b1 == 0 ? 1 : b1;
+  auto b2 = block.view[2];
+  b2 = b2 == 0 ? 1 : b2;
+  ss << "__global__ __launch_bounds__(" << b0 * b1 * b2 << ") void "
+     << specializedName << "(";
   emitArgs(ss, scop);
   ss << ") {" << endl;
 }
@@ -753,7 +761,7 @@ string emitCudaKernel(
   }
 
   stringstream ss;
-  emitKernelSignature(ss, specializedName, scop);
+  emitKernelSignature(ss, specializedName, scop, mscop.numThreads);
   emitThreadIdInit(ss, mscop);
   emitTensorViews(ss, scop.halide.outputs, paramValues);
   emitTensorViews(ss, scop.halide.inputs, paramValues);
diff --git a/test/test_cuda_mapper.cc b/test/test_cuda_mapper.cc
index 1e3d6ec69..4a0c9971b 100644
--- a/test/test_cuda_mapper.cc
+++ b/test/test_cuda_mapper.cc
@@ -451,7 +451,7 @@ def fun(float(N, N) A) -> (O)
   auto res = std::get<0>(mscop->codegen(specializedName));
 
   string expected(
-      R"RES(__global__ void kernel_anon(int32 N, float32* pO, const float32* pA) {
+      R"RES(__global__ __launch_bounds__(1) void kernel_anon(int32 N, float32* pO, const float32* pA) {
   int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
   int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
   float32 (*O)[N] = reinterpret_cast(pO);
@@ -480,7 +480,7 @@ def fun(float(N, N) A, float(N, N) B, float(N) C) -> (O)
   auto res = std::get<0>(mscop->codegen(specializedName));
 
   string expected =
-      R"RES(__global__ void kernel_anon(int32 N, float32* pO, const float32* pA, const float32* pB, const float32* pC) {
+      R"RES(__global__ __launch_bounds__(1) void kernel_anon(int32 N, float32* pO, const float32* pA, const float32* pB, const float32* pC) {
   int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
   int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
   float32 (*O)[512] = reinterpret_cast(pO);