Handle vectorized float16 types using native half struct types.
csullivan committed Feb 12, 2024
1 parent 650dc59 commit 38693aa
Showing 3 changed files with 140 additions and 16 deletions.
28 changes: 15 additions & 13 deletions src/target/source/codegen_cuda.cc
@@ -152,9 +152,9 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n";
decl_stream << "using fp8_e5_2_t = __nv_fp8x2_e5m2;\n";
decl_stream << "using fp8_e5_4_t = __nv_fp8x4_e5m2;\n";
- decl_stream << _cuda_vector_type_extensions;
decl_stream << "#endif\n\n";
}
+ declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);

if (enable_warp_shuffle_) {
decl_stream << _cuda_warp_intrinsic_util;
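
For reference, roughly the prelude this call now writes into the generated CUDA source when only fp16 is enabled (abridged from the literal in cuda_half_t.h below; the fp8 conversion members are appended only when enable_fp8 is set):

#include <cuda_fp16.h>

// Emitted prelude, fp16-only case: a four-lane half struct plus its helper.
struct __align__(4) half4 {
  __half x, y, z, w;
  __host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}
  __host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}
};
__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {
  return half4(x, y, z, w);
}
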
@@ -248,7 +248,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
} else if (lanes <= 8) {
ICHECK_EQ(lanes % 2, 0) << "Only support an even number of lanes for half type";
// Use native vector types when working with fp8
- if (enable_fp8_ && lanes <= 4) {
+ if (lanes <= 4) {
os << "half" << lanes;
} else {
// Emit CUDA code to access fp16 vector elements.
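
Dropping the enable_fp8_ guard means two- and four-lane float16 vectors are always given the native types; a hand-written compatibility sketch (half4 here is the struct emitted above, and the sizes match the packed 32-bit words the old path fell back to):

#include <cuda_fp16.h>

struct __align__(4) half4 { __half x, y, z, w; };

// The natively typed vectors occupy the same bytes as the packed words, so
// buffer layout does not change with the new printing rule.
static_assert(sizeof(half2) == sizeof(unsigned int), "float16x2 is 4 bytes either way");
static_assert(sizeof(half4) == sizeof(uint2), "float16x4 is 8 bytes either way");

__global__ void typed_copy(const half2* a2, half2* b2, const half4* a4, half4* b4, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    b2[i] = a2[i];  // float16x2 handled as a native half2
    b4[i] = a4[i];  // float16x4 handled as the emitted half4 struct
  }
}
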
@@ -1248,9 +1248,16 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
std::string v = PrintExpr(op->value);
PrintVecConstructor(op->dtype, os);
os << '(';
- for (int i = 0; i < op->lanes / 2; ++i) {
- if (i != 0) os << ", ";
- os << "__pack_half2(" << v << ", " << v << ")";
+ if (op->dtype.lanes() <= 4) {
+ for (int i = 0; i < op->lanes / 2; ++i) {
+ if (i != 0) os << ", ";
+ os << v << ", " << v;
+ }
+ } else {
+ for (int i = 0; i < op->lanes / 2; ++i) {
+ if (i != 0) os << ", ";
+ os << "__pack_half2(" << v << ", " << v << ")";
+ }
}
os << ')';
return;
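
A hedged sketch of the two broadcast forms this branch now distinguishes (hand-written; pack_half2 is a self-contained stand-in for the __pack_half2 helper the generated code defines, and the eight-lane case assumes the unchanged packed-uint fallback):

#include <cuda_fp16.h>

struct __align__(4) half4 { __half x, y, z, w; };
__device__ half4 make_half4(__half x, __half y, __half z, __half w) { return half4{x, y, z, w}; }

// Stand-in for the generated __pack_half2 helper: two halves in one 32-bit word.
__device__ unsigned pack_half2(__half lo, __half hi) {
  return (static_cast<unsigned>(__half_as_ushort(hi)) << 16) | __half_as_ushort(lo);
}

__global__ void broadcast_forms(__half v, half4* out4, uint4* out8) {
  // lanes <= 4: a plain element list feeds the half4 constructor.
  *out4 = make_half4(v, v, v, v);
  // lanes == 8: pairs are still packed into a uint4.
  *out8 = make_uint4(pack_half2(v, v), pack_half2(v, v), pack_half2(v, v), pack_half2(v, v));
}
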
@@ -1502,15 +1509,10 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
PrintVecConstructor(t, os);
os << '(';
}
- if (i % 2 == 0) {
- os << "__pack_half2(" << value;
+ if (i == t.lanes() - 1) {
+ os << value << ")";
} else {
- os << "," << value << ")";
- if (i != t.lanes() - 1) {
- os << ",";
- } else {
- os << ")";
- }
+ os << value << ",";
}
return;
}
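
A small host-side illustration of the simplified per-lane printing above (valid C++ under nvcc; print_vec_elem_load and the make_half4 name stand in for the real PrintVecElemLoadExpr/PrintVecConstructor pair):

#include <sstream>
#include <string>

// Called once per lane i: open the constructor at lane 0, append one scalar
// expression per lane, and close the call at the last lane.
void print_vec_elem_load(std::ostringstream& os, int i, int lanes, const std::string& value) {
  if (i == 0) os << "make_half" << lanes << "(";
  if (i == lanes - 1) {
    os << value << ")";
  } else {
    os << value << ",";
  }
}
// Lanes = 4 with values A[0]..A[3] builds "make_half4(A[0],A[1],A[2],A[3])",
// where the old logic paired lanes through __pack_half2(...).
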
18 changes: 16 additions & 2 deletions src/target/source/literal/cuda_half_t.h
@@ -24,6 +24,8 @@
#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
#define TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_

#include <string>

static constexpr const char* _cuda_half_t_def = R"(
typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
@@ -379,11 +381,16 @@ static constexpr const char* _cuda_warp_intrinsic_util = R"(
)";

- static constexpr const char* _cuda_vector_type_extensions = R"(
+ void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16, bool enable_fp8) {
+ if (enable_fp16 || enable_fp8) {
+ stream << R"(
struct __align__(4) half4 {
__half x, y, z, w;
__host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}
__host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}
+ )";
+ if (enable_fp8) {
+ stream << R"(
__host__ __device__ explicit half4(const __nv_fp8x4_e4m3& fp8x4) {
__nv_fp8x2_e4m3 lo_part, hi_part;
lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
@@ -403,8 +410,15 @@ struct __align__(4) half4 {
result.__x =
(static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
return result;
- }
+ })";
+ }
+ stream << R"(
};
__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {
return half4(x, y, z, w);
}
)";
+ }
+ }

#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
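
The enable_fp8 branch above emits a constructor that widens a packed __nv_fp8x4_e4m3 through two fp8x2 halves; a self-contained kernel doing the same conversion outside the struct might look like this (needs cuda_fp8.h and an sm_89/sm_90 part, matching the compute-version guards in the tests below):

#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Split one packed fp8x4 word into low and high fp8x2 halves, then use the
// library's fp8x2 -> __half2 conversion, mirroring the emitted constructor.
__global__ void fp8x4_to_halves(const __nv_fp8x4_e4m3* __restrict__ in,
                                __half* __restrict__ out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  __nv_fp8x2_e4m3 lo_part, hi_part;
  lo_part.__x = static_cast<__nv_fp8x2_storage_t>(in[i].__x & 0xFFFF);
  hi_part.__x = static_cast<__nv_fp8x2_storage_t>(in[i].__x >> 16);
  __half2 lo = static_cast<__half2>(lo_part);
  __half2 hi = static_cast<__half2>(hi_part);
  out[4 * i + 0] = lo.x;
  out[4 * i + 1] = lo.y;
  out[4 * i + 2] = hi.x;
  out[4 * i + 3] = hi.y;
}
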
110 changes: 109 additions & 1 deletion tests/python/tir-base/test_native_fp8.py
@@ -69,6 +69,37 @@ def add(
)


@tvm.testing.requires_cuda_compute_version(9)
def test_e4m3_packing():
native_dtype, packed_dtype = ("e4m3_float8x2", "uint32")
vector_length = 64

@T.prim_func
def add(
A: T.Buffer((vector_length,), native_dtype),
B: T.Buffer((vector_length,), packed_dtype),
):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i in range(vector_length):
with T.block("C"):
v_i = T.axis.spatial(vector_length, i)
T.reads(A[v_i])
T.writes(B[v_i])
B[v_i] = T.reinterpret(packed_dtype, A[v_i])

sch = tvm.tir.Schedule(add)
block = sch.get_block("C")
b = sch.get_loops(block)
bx, tx = sch.split(b[0], factors=[None, 32])
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")

target = "cuda"
tvm.build(sch.mod, target=target)
# TODO(csullivan): numerical check


native_dtype, promoted_dtype = tvm.testing.parameters(
("e4m3_float8", "float32"),
("e4m3_float8", "float16"),
@@ -81,7 +112,7 @@


@tvm.testing.requires_cuda_compute_version(9)
- def test_e4m3x2_conversions(native_dtype, promoted_dtype):
+ def test_e4m3_vector_conversions(native_dtype, promoted_dtype):
vector_length = 64

@T.prim_func
@@ -139,5 +170,82 @@ def add(
)


bcast_length = tvm.testing.parameter(2, 4, 6, 8)


@tvm.testing.requires_cuda_compute_version(8)
def test_half_broadcast(bcast_length):
dtype = "float16"

@T.prim_func
def vector_broadcast(a: T.Buffer[(), dtype], vec: T.Buffer[(bcast_length,), dtype]):
for t in range(1):
with T.block("broadcast"):
vec[0:bcast_length] = T.broadcast(a[()], bcast_length)

sch = tvm.tir.Schedule(vector_broadcast)
block = sch.get_block("broadcast")
b = sch.get_loops(block)
bx, tx = sch.split(b[0], factors=[None, 1])
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")

target = "cuda"
tvm.build(sch.mod, target=target)
# TODO(csullivan): numerical check


vector_length = tvm.testing.parameter(2, 4)


@tvm.testing.requires_cuda_compute_version(8)
def test_half_misaligned_vector_load(vector_length):
dtype = "float16"
vec_dtype = dtype + "x" + str(vector_length)

@T.prim_func
def vector_load(A: T.Buffer[(128,), dtype], B: T.Buffer[(32,), vec_dtype]):
for b in T.thread_binding(1, thread="blockIdx.x"):
for i in T.thread_binding(32, thread="threadIdx.x"):
vec_index = T.ramp((i + 1) * vector_length - 1, -1, vector_length)
B[i] = A[vec_index]

target = "cuda"
tvm.build(vector_load, target=target)
# TODO(csullivan): numerical check
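
For reference, a hand-written CUDA equivalent of what this test asks the codegen to emit for vector_length = 4 (the descending T.ramp exercises the element-load path changed above, so each lane is loaded separately and packed with make_half4; names are illustrative):

#include <cuda_fp16.h>

struct __align__(4) half4 { __half x, y, z, w; };
__device__ half4 make_half4(__half x, __half y, __half z, __half w) { return half4{x, y, z, w}; }

__global__ void reversed_vector_load(const __half* __restrict__ A, half4* __restrict__ B) {
  int i = threadIdx.x;  // 32 threads, matching the test's thread binding
  // T.ramp((i + 1) * 4 - 1, -1, 4): no single contiguous vector load fits,
  // so the four lanes are gathered and packed into one half4 value.
  B[i] = make_half4(A[(i + 1) * 4 - 1], A[(i + 1) * 4 - 2],
                    A[(i + 1) * 4 - 3], A[(i + 1) * 4 - 4]);
}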


@tvm.testing.requires_cuda_compute_version(8)
def test_half_vector_add():
dtype = "float16x4"
vector_length = 64

@T.prim_func
def add(
A: T.Buffer((vector_length,), dtype),
B: T.Buffer((vector_length,), dtype),
C: T.Buffer((vector_length,), dtype),
):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i in range(vector_length):
with T.block("C"):
v_i = T.axis.spatial(vector_length, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = A[v_i] + B[v_i]

sch = tvm.tir.Schedule(add)
block = sch.get_block("C")
b = sch.get_loops(block)
bx, tx = sch.split(b[0], factors=[None, 32])
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")

target = "cuda"
fadd = tvm.build(sch.mod, target=target)
# TODO(csullivan): numerical check
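
A hand-written kernel equivalent to the computation this test compiles (the generated code may differ in structure; __hadd needs sm_53+, well below the sm_80 the test already requires):

#include <cuda_fp16.h>

struct __align__(4) half4 { __half x, y, z, w; };
__device__ half4 make_half4(__half x, __half y, __half z, __half w) { return half4{x, y, z, w}; }

__global__ void add_f16x4(const half4* __restrict__ A, const half4* __restrict__ B,
                          half4* __restrict__ C, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  half4 a = A[i], b = B[i];
  // Lane-wise half addition on the natively typed four-lane value.
  C[i] = make_half4(__hadd(a.x, b.x), __hadd(a.y, b.y), __hadd(a.z, b.z), __hadd(a.w, b.w));
}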


if __name__ == "__main__":
tvm.testing.main()
