Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][CUDA] Add native FP8 support to codegen #16548

Merged
merged 17 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,17 +398,19 @@ TVM_DLL Pass ForceNarrowIndexToInt32();
/*!
 * \brief Legalize bf16 compute Ops. Add a cast to fp32
 *  before Ops, then add a cast back to bf16.
 * \return The pass.
 */
TVM_DLL Pass BF16ComputeLegalize();

/*!
* \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
* before Ops, then add a cast back to fp8.
* \param target The target used for checking native fp8 support
* \param promote_dtype_str The data type used for type promotion, defaults to float16
* \return The pass.
*/
TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16");
TVM_DLL Pass FP8ComputeLegalize(Target target, String promote_dtype_str = "float16");

/*!
* \brief Legalize bf16 storage types to u16.
Expand All @@ -420,7 +422,7 @@ TVM_DLL Pass BF16StorageLegalize();
* \brief Legalize fp8 storage types to u8.
* \return The pass.
*/
TVM_DLL Pass FP8StorageLegalize();
TVM_DLL Pass FP8StorageLegalize(Target target);

/*!
* \brief Inline calls to private functions
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def callback_libdevice_path(arch):
return ""


@tvm._ffi.register_func("tvm.contrib.nvcc.get_compute_version")
def get_target_compute_version(target=None):
"""Utility function to get compute capability of compilation target.
Expand Down Expand Up @@ -406,6 +407,7 @@ def have_cudagraph():
return False


@tvm._ffi.register_func("tvm.contrib.nvcc.supports_bf16")
def have_bf16(compute_version):
"""Either bf16 support is provided in the compute capability or not
Expand All @@ -421,6 +423,7 @@ def have_bf16(compute_version):
return False


@tvm._ffi.register_func("tvm.contrib.nvcc.supports_fp8")
def have_fp8(compute_version):
"""Whether fp8 support is provided in the specified compute capability or not
Expand Down
5 changes: 3 additions & 2 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::TransformMmaBufferLayout());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::FP8ComputeLegalize());
pass_list.push_back(tir::transform::BF16ComputeLegalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
Expand Down Expand Up @@ -570,6 +569,8 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

Array<Pass> mixed_pass_list;

mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize(target));

// VerifyVTCMLimit must occur before LowerVtcmAlloc
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
Expand Down Expand Up @@ -619,7 +620,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
} else {
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize(target));
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());
Expand Down
3 changes: 3 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,8 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
default:
LOG(FATAL) << "do not support " << dtype;
}
} else if (dtype.code() == DataType::kE4M3Float || dtype.code() == DataType::kE5M2Float) {
etype = llvm::Type::getInt8Ty(*ctx);
}
if (dtype.lanes() != 1) {
#if TVM_LLVM_VERSION >= 110
Expand All @@ -594,6 +596,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
return llvm::VectorType::get(etype, dtype.lanes());
#endif
} else {
ICHECK(etype != nullptr) << "No suitable llvm type found for dtype: " << dtype;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this breaks CI

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @vinx13. I was trying to leave the code better than I found it, but I guess some downstream handling is relying on this behavior. I have removed the check.

return etype;
}
} // namespace codegen
Expand Down
130 changes: 94 additions & 36 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,31 @@
namespace tvm {
namespace codegen {

/*!
 * \brief Build the name of the CUDA-side type alias used for an FP8 dtype.
 *
 * The name has the form "fp8_e4<suffix>_t" for e4m3 and "fp8_e5<suffix>_t"
 * for e5m2, where <suffix> is empty for scalars and "_2", "_4" or "_8" for
 * vectors of that many lanes.
 *
 * \param type The FP8 data type (scalar or vector).
 * \return The type name as a string.
 * \note Fails with LOG(FATAL) on any other lane count or on a non-FP8 type code.
 */
std::string GetFP8Type(DataType type) {
  const int32_t lanes = type.lanes();

  // Lane suffix: scalars carry none, vectors are limited to 2/4/8 lanes.
  std::string lane_suffix;
  if (!type.is_scalar()) {
    switch (lanes) {
      case 2:
        lane_suffix = "_2";
        break;
      case 4:
        lane_suffix = "_4";
        break;
      case 8:
        lane_suffix = "_8";
        break;
      default: {
        LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for FP8";
      }
    }
  }

  // Element-format prefix, chosen from the dtype code.
  std::string prefix;
  const auto code = type.code();
  if (code == DataType::kE4M3Float) {
    prefix = "fp8_e4";
  } else if (code == DataType::kE5M2Float) {
    prefix = "fp8_e5";
  } else {
    LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
  }

  return prefix + lane_suffix + "_t";
}

// Default constructor: sets the restrict keyword used by this code generator to "__restrict__".
CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }

void CodeGenCUDA::Init(bool output_ssa) {
Expand Down Expand Up @@ -121,8 +146,15 @@ std::string CodeGenCUDA::Finish() {
if (enable_fp8_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n";
decl_stream << "#include <cuda_fp8.h>\n";
decl_stream << "using fp8_e4_t = __nv_fp8_e4m3;\n";
decl_stream << "using fp8_e4_2_t = __nv_fp8x2_e4m3;\n";
decl_stream << "using fp8_e4_4_t = __nv_fp8x4_e4m3;\n";
decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n";
decl_stream << "using fp8_e5_2_t = __nv_fp8x2_e5m2;\n";
decl_stream << "using fp8_e5_4_t = __nv_fp8x4_e5m2;\n";
decl_stream << "#endif\n\n";
}
declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);

if (enable_warp_shuffle_) {
decl_stream << _cuda_warp_intrinsic_util;
Expand Down Expand Up @@ -214,17 +246,23 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
if (t.is_scalar()) {
os << "half";
} else if (lanes <= 8) {
// Emit CUDA code to access fp16 vector elements.
//
// half4 is stored as uint2
//
// h4.x is emitted as *(half2*)(&(u2.x)).x
// h4.y is emitted as *(half2*)(&(u2.x)).y
// h4.z is emitted as *(half2*)(&(u2.y)).x
// h4.w is emitted as *(half2*)(&(u2.y)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
os << "uint" << lanes / 2;
ICHECK_EQ(lanes % 2, 0) << "Only support an even number of lanes for half type";
// Use native vector types when working with fp8
if (lanes <= 4) {
os << "half" << lanes;
} else {
// Emit CUDA code to access fp16 vector elements.
//
// half4 is stored as uint2
//
// h4.x is emitted as *(half2*)(&(u2.x)).x
// h4.y is emitted as *(half2*)(&(u2.x)).y
// h4.z is emitted as *(half2*)(&(u2.y)).x
// h4.w is emitted as *(half2*)(&(u2.y)).y
//

os << "uint" << lanes / 2;
}
} else {
fail = true;
}
Expand Down Expand Up @@ -271,16 +309,9 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
}
if (!fail) return;
} else if (t.is_float8()) {
if (t.is_scalar()) {
os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
} else if (lanes == 2) {
os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of unsigned short
} else if (lanes == 4) {
os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
} else {
fail = true;
}
if (!fail) return;
enable_fp8_ = true;
os << GetFP8Type(t);
return;
} else if (t == DataType::Bool()) {
os << "bool";
return;
Expand Down Expand Up @@ -446,7 +477,7 @@ void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) {

void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) { // NOLINT(*)
// Delcare the result.
// Declare the result.
std::string sret = name_supply_->FreshName("_");
this->PrintIndent();
this->PrintType(t, stream);
Expand Down Expand Up @@ -497,7 +528,14 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
}
} else if (t.is_float16()) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
if (t.lanes() == 2) {
// TODO(csullivan): Consider conditionally supporting value casting in the fp8 / vector case
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else if (t.lanes() == 4) {
os << vec << "." << access[i];
} else {
LOG(FATAL) << "Unimplemented: Only support codegen for vector half loads with lanes = {2, 4}";
}
} else if (t.is_bfloat16()) {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else if (t.lanes() > 4 && t.lanes() <= 8) {
Expand Down Expand Up @@ -543,8 +581,16 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
stream << "(" << value << " << " << i % 4 * 8 << ");\n";
}
} else if (t.is_float16()) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
<< value << ";\n";
if (t.lanes() == 2) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
<< value << ";\n";
} else if (t.lanes() == 4) {
stream << vec << "." << access[i] << " = " << value << ";\n";
} else {
LOG(FATAL)
<< "Unimplemented: Only support codegen for vector half stores with lanes = {2, 4}";
}

} else if (t.is_bfloat16()) {
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
<< " = " << value << ";\n";
Expand Down Expand Up @@ -648,6 +694,16 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
// Emit simple C-style type conversion.
if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);

if (target_ty.code() == DataType::kE4M3Float || target_ty.code() == DataType::kE5M2Float ||
from_ty.code() == DataType::kE4M3Float || from_ty.code() == DataType::kE5M2Float) {
std::ostringstream val;
val << "(";
PrintType(target_ty, val);
val << ")(" << PrintExpr(op->value) << ")";
os << val.str();
return;
}

// We could emit make_float4 like calls, but the emitted code looks
// too compact to read. Emit this as vectorized unary ops.
std::string sret = name_supply_->FreshName("_");
Expand Down Expand Up @@ -1194,9 +1250,16 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
std::string v = PrintExpr(op->value);
PrintVecConstructor(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
os << "__pack_half2(" << v << ", " << v << ")";
if (lanes <= 4) {
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
os << v << ", " << v;
}
} else {
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
os << "__pack_half2(" << v << ", " << v << ")";
}
}
os << ')';
return;
Expand Down Expand Up @@ -1448,15 +1511,10 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
PrintVecConstructor(t, os);
os << '(';
}
if (i % 2 == 0) {
os << "__pack_half2(" << value;
if (i == t.lanes() - 1) {
os << value << ")";
} else {
os << "," << value << ")";
if (i != t.lanes() - 1) {
os << ",";
} else {
os << ")";
}
os << value << ",";
}
return;
}
Expand Down
42 changes: 42 additions & 0 deletions src/target/source/literal/cuda_half_t.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
#define TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_

#include <sstream>
#include <string>

static constexpr const char* _cuda_half_t_def = R"(
typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
Expand Down Expand Up @@ -379,4 +381,44 @@ static constexpr const char* _cuda_warp_intrinsic_util = R"(
)";

/*!
 * \brief Append CUDA source text declaring the `half4` vector type to a stream.
 *
 * Nothing is written unless at least one of the flags is set. Otherwise the
 * emitted text defines an 8-byte aligned `half4` struct with a default and a
 * four-element constructor, followed by a `make_half4` helper. When
 * \p enable_fp8 is set, the struct additionally gets explicit conversions
 * from and to `__nv_fp8x4_e4m3`.
 *
 * \param stream The stream that receives the generated declarations.
 * \param enable_fp16 Whether fp16 support is enabled in the generated code.
 * \param enable_fp8 Whether fp8 support is enabled in the generated code.
 *
 * \note Declared `inline` because this function is defined in a header; a
 *       non-inline definition would produce duplicate symbols as soon as the
 *       header is included from more than one translation unit.
 */
inline void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16,
                                           bool enable_fp8) {
  if (!enable_fp16 && !enable_fp8) {
    return;
  }

  // Struct header plus the constructors that do not depend on fp8.
  stream << R"(
struct __align__(8) half4 {
__half x, y, z, w;
__host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}
__host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}
)";

  if (enable_fp8) {
    // fp8x4 <-> half4 conversions, done as two fp8x2 <-> half2 halves.
    // The packed result is assembled in __nv_fp8x4_storage_t (the type of
    // __nv_fp8x4_e4m3::__x) rather than the glibc-internal __uint32_t, which
    // is not available on every host toolchain.
    stream << R"(
__host__ __device__ explicit half4(const __nv_fp8x4_e4m3& fp8x4) {
__nv_fp8x2_e4m3 lo_part, hi_part;
lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
__half2 lo_half2 = static_cast<__half2>(lo_part);
__half2 hi_half2 = static_cast<__half2>(hi_part);
x = reinterpret_cast<__half*>(&lo_half2)[0];
y = reinterpret_cast<__half*>(&lo_half2)[1];
z = reinterpret_cast<__half*>(&hi_half2)[0];
w = reinterpret_cast<__half*>(&hi_half2)[1];
}
__host__ __device__ explicit operator __nv_fp8x4_e4m3() const {
__nv_fp8x4_e4m3 result;
__half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
__half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
__nv_fp8x2_e4m3 lo_part(lo_half2), hi_part(hi_half2);
result.__x =
(static_cast<__nv_fp8x4_storage_t>(lo_part.__x) | (static_cast<__nv_fp8x4_storage_t>(hi_part.__x) << 16));
return result;
})";
  }

  // Close the struct and emit the make_half4 factory.
  stream << R"(
};
__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {
return half4(x, y, z, w);
}
)";
}

#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
Loading
Loading