[CUDA] Allow async copy for fp16
nirvedhmeshram committed Apr 21, 2022
Parent: 0686f2b · Commit: 6085c4e
Showing 2 changed files with 103 additions and 1 deletion.
iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorToGPU.cpp (4 additions, 1 deletion)
@@ -29,8 +29,11 @@ static void createAsyncGroups(func::FuncOp funcOp) {
     if (!read || read.getVectorType() != writeOp.getVectorType() ||
         !read.isDimInBounds(0) || !read.getPermutationMap().isMinorIdentity())
       return WalkResult::advance();
+    // Todo (nirvedhmeshram): Check if the numelement check can be relaxed to 8
+    // for FP16
     if (read.getVectorType().getNumElements() > 4 ||
-        !read.getVectorType().getElementType().isF32())
+        !(read.getVectorType().getElementType().isF32() ||
+          read.getVectorType().getElementType().isF16()))
       return WalkResult::advance();
     copyToSharedMem.insert(writeOp);
     return WalkResult::advance();
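This relaxes the element-type check so that f16 copies, not just f32, can be turned into async copy groups; the 4-element cap still applies, so an f16 copy moves 8 bytes per transaction (hence the trailing 8 in the nvvm.cp.async.shared.global checks below), and the TODO notes the cap could likely grow to 8 elements for f16, i.e. 16 bytes, the largest cp.async size. A minimal sketch of the kind of pattern the pass now accepts, with illustrative names and shapes that are not taken from the commit:

  %pad = arith.constant 0.000000e+00 : f16
  // Read 4 contiguous f16 values (8 bytes) from global memory with an
  // in-bounds, minor-identity access, as the pass requires...
  %v = vector.transfer_read %global[%i, %j], %pad {in_bounds = [true]}
      : memref<2048x1024xf16>, vector<4xf16>
  // ...then write them to workgroup (shared) memory. Such a pair can now be
  // grouped into an asynchronous copy that eventually lowers to the
  // nvvm.cp.async.shared.global ops checked in the test below.
  vector.transfer_write %v, %shared[%k, %l] {in_bounds = [true]}
      : vector<4xf16>, memref<64x32xf16, 3>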
iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir (99 additions, 0 deletions)
@@ -475,6 +475,105 @@ hal.executable @mma_fused {

// -----

#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
hal.executable @mma_fused_fp16 {
hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> {
hal.executable.entry_point public @_large_aligned_dispatch_0 ordinal(0) layout(#hal.executable.layout<push_constants = 0, sets = [#hal.descriptor_set.layout<0, bindings = [#hal.descriptor_set.binding<0, storage_buffer>, #hal.descriptor_set.binding<1, storage_buffer>, #hal.descriptor_set.binding<2, storage_buffer>]>]>)
builtin.module {
func.func @_large_aligned_dispatch_0() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%c2048 = arith.constant 2048 : index
%c512 = arith.constant 512 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:2048x1024xf16>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:1024x512xf16>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:2048x512xf16>
%di = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<readonly:2048x512xf16>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 1024], strides = [1, 1]
: !flow.dispatch.tensor<readonly:2048x1024xf16> -> tensor<2048x1024xf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1024, 512], strides = [1, 1]
: !flow.dispatch.tensor<readonly:1024x512xf16> -> tensor<1024x512xf16>
%d = flow.dispatch.tensor.load %di, offsets = [0, 0], sizes = [2048, 512], strides = [1, 1]
: !flow.dispatch.tensor<readonly:2048x512xf16> -> tensor<2048x512xf16>
%init = linalg.init_tensor [2048, 512] : tensor<2048x512xf16>
%f = linalg.fill ins(%cst : f16) outs(%init : tensor<2048x512xf16>) -> tensor<2048x512xf16>
%m = linalg.matmul ins(%3, %4 : tensor<2048x1024xf16>, tensor<1024x512xf16>) outs(%f : tensor<2048x512xf16>) -> tensor<2048x512xf16>
%init2 = linalg.init_tensor [2048, 512] : tensor<2048x512xf16>
%a = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%m, %d : tensor<2048x512xf16>, tensor<2048x512xf16>) outs(%m : tensor<2048x512xf16>) {
^bb0(%arg3: f16, %arg4: f16, %arg5: f16): // no predecessors
%19 = arith.addf %arg3, %arg4 : f16
linalg.yield %19 : f16
} -> (tensor<2048x512xf16>)
flow.dispatch.tensor.store %a, %2, offsets = [0, 0], sizes = [2048, 512], strides = [1, 1]
: tensor<2048x512xf16> -> !flow.dispatch.tensor<writeonly:2048x512xf16>
return
}
}
}
}

// CHECK-LABEL: hal.executable public @mma_fused_fp16
// CHECK: hal.executable.variant public @cuda
// CHECK-NOT: llvm.store
// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 8
// CHECK: nvvm.cp.async.commit.group
// CHECK: llvm.br
// CHECK: nvvm.cp.async.wait.group 0
// CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
// CHECK-COUNT-1: nvvm.wmma.mma
// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 8
// CHECK: nvvm.cp.async.commit.group
// CHECK: llvm.br
// CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
// CHECK-COUNT-1: nvvm.wmma.mma
// CHECK-COUNT-4: llvm.fadd
// CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<f16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>

// case with larger pipeline depth
// CHECKP-LABEL: hal.executable public @mma_fused_fp16
// CHECKP: hal.executable.variant public @cuda
// CHECKP-NOT: llvm.store
// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 8
// CHECKP: nvvm.cp.async.commit.group
// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 8
// CHECKP: nvvm.cp.async.commit.group
// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 8
// CHECKP: nvvm.cp.async.commit.group
// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 8
// CHECKP: nvvm.cp.async.commit.group
// CHECKP: llvm.br
// CHECKP: nvvm.cp.async.wait.group 3
// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
// CHECKP-COUNT-1: nvvm.wmma.mma
// CHECKP-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 8
// CHECKP: nvvm.cp.async.commit.group
// CHECKP: llvm.br
// CHECKP: nvvm.cp.async.wait.group 3
// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
// CHECKP-COUNT-1: nvvm.wmma.mma
// CHECKP: nvvm.cp.async.wait.group 2
// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
// CHECKP-COUNT-1: nvvm.wmma.mma
// CHECKP: nvvm.cp.async.wait.group 1
// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
// CHECKP-COUNT-1: nvvm.wmma.mma
// CHECKP: nvvm.cp.async.wait.group 0
// CHECKP-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
// CHECKP-COUNT-1: nvvm.wmma.mma
// CHECKP-COUNT-4: llvm.fadd
// CHECKP-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<f16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>

// -----

#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
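The CHECKP run covers the same dispatch compiled with a deeper software pipeline. The shape of the checks implies a depth of four: the prologue commits four copy groups before the loop, the steady state waits with three groups still in flight (nvvm.cp.async.wait.group 3), and the epilogue drains the remaining tiles by waiting on 2, 1, and then 0. One steady-state iteration then looks roughly like this (an assumed sketch, not literal test output):

  nvvm.cp.async.shared.global %dst, %src, 8  // prefetch a tile several iterations ahead (8 bytes)
  nvvm.cp.async.commit.group                 // close this iteration's copy group
  nvvm.cp.async.wait.group 3                 // the oldest in-flight tile is now in shared memory
  // nvvm.wmma.load and nvvm.wmma.mma then consume that tile.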
