This repository has been archived by the owner on Apr 28, 2023. It is now read-only.

[Cuda Codegen] Emit launch bounds #526

Open · wants to merge 1 commit into master
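For context, __launch_bounds__ tells nvcc the maximum number of threads per block a kernel will be launched with, so the compiler can budget registers for that bound rather than for the device maximum. A minimal illustration (the saxpy kernel and the bound of 128 are made up for this example, not taken from the PR):

// Illustration only: __launch_bounds__(128) promises nvcc that
// blockDim.x * blockDim.y * blockDim.z will not exceed 128 at launch,
// letting it allocate registers for at most 128 threads per block.
__global__ __launch_bounds__(128) void saxpy(int n, float a, const float* x, float* y) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    y[i] = a * x[i] + y[i];
  }
}

The change below emits exactly this annotation, using the product of the three mapped block sizes as the bound.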
14 changes: 11 additions & 3 deletions tc/core/polyhedral/cuda/codegen.cc
@@ -153,9 +153,17 @@ void emitArgs(stringstream& ss, const Scop& scop) {
 void emitKernelSignature(
     stringstream& ss,
     const std::string& specializedName,
-    const Scop& scop) {
+    const Scop& scop,
+    const Block& block) {
   TC_CHECK_NE(specializedName, "") << "name not provided";
-  ss << "__global__ void " << specializedName << "(";
+  auto b0 = block.view[0];
+  b0 = b0 == 0 ? 1 : b0;
+  auto b1 = block.view[1];
+  b1 = b1 == 0 ? 1 : b1;
+  auto b2 = block.view[2];
+  b1 = b2 == 0 ? 1 : b2;
Contributor (review comment on the line above):
This should be b2 instead of b1.
However, I would suggest you remove this special handling of 0.
ss << "__global__ __launch_bounds__(" << b0 * b1 * b2 << ") void "
<< specializedName << "(";
emitArgs(ss, scop);
ss << ") {" << endl;
}
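For reference, here is a sketch of the signature emission with both of the reviewer's suggestions applied (assign to b2, drop the special handling of 0). This is illustrative only, not the PR as submitted, and it assumes each block.view[i] carries a positive mapped size, which is what dropping the 0 handling relies on:

// Sketch only, with the reviewer's fixes applied.
auto b0 = block.view[0];
auto b1 = block.view[1];
auto b2 = block.view[2]; // the PR as submitted mistakenly assigns this to b1
ss << "__global__ __launch_bounds__(" << b0 * b1 * b2 << ") void "
   << specializedName << "(";

With the 0-to-1 clamping as the PR writes it, a kernel whose block dimensions are all unmapped gets a bound of 1 * 1 * 1 = 1, which is what the updated test expectations below encode.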
@@ -753,7 +761,7 @@ string emitCudaKernel(
   }
 
   stringstream ss;
-  emitKernelSignature(ss, specializedName, scop);
+  emitKernelSignature(ss, specializedName, scop, mscop.numThreads);
   emitThreadIdInit(ss, mscop);
   emitTensorViews(ss, scop.halide.outputs, paramValues);
   emitTensorViews(ss, scop.halide.inputs, paramValues);
4 changes: 2 additions & 2 deletions test/test_cuda_mapper.cc
@@ -451,7 +451,7 @@ def fun(float(N, N) A) -> (O)
   auto res = std::get<0>(mscop->codegen(specializedName));
 
   string expected(
-      R"RES(__global__ void kernel_anon(int32 N, float32* pO, const float32* pA) {
+      R"RES(__global__ __launch_bounds__(1) void kernel_anon(int32 N, float32* pO, const float32* pA) {
 int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
 float32 (*O)[N] = reinterpret_cast<float32 (*)[N]>(pO);
@@ -480,7 +480,7 @@ def fun(float(N, N) A, float(N, N) B, float(N) C) -> (O)
   auto res = std::get<0>(mscop->codegen(specializedName));
 
   string expected =
-      R"RES(__global__ void kernel_anon(int32 N, float32* pO, const float32* pA, const float32* pB, const float32* pC) {
+      R"RES(__global__ __launch_bounds__(1) void kernel_anon(int32 N, float32* pO, const float32* pA, const float32* pB, const float32* pC) {
 int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
 float32 (*O)[512] = reinterpret_cast<float32 (*)[512]>(pO);