Skip to content

Commit

Permalink
[Transform] unsupported_dtype_legalize.cc - Only check cuda compute v…
Browse files Browse the repository at this point in the history
…ersion for fp8 support on cuda target
  • Loading branch information
csullivan committed Feb 21, 2024
1 parent 4593359 commit 54fa150
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/tir/transforms/unsupported_dtype_legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -695,11 +695,13 @@ namespace transform {

bool CheckDataTypeSupport(const Target& target, const std::string& support_func_name) {
bool has_native_support = false;
if (const PackedFunc* get_cv =
tvm::runtime::Registry::Get("tvm.contrib.nvcc.get_compute_version")) {
std::string compute_version = (*get_cv)(target);
if (const PackedFunc* check_support = tvm::runtime::Registry::Get(support_func_name)) {
has_native_support = (*check_support)(compute_version);
if (target->kind->name == "cuda") {
if (const PackedFunc* get_cv =
tvm::runtime::Registry::Get("tvm.contrib.nvcc.get_compute_version")) {
std::string compute_version = (*get_cv)(target);
if (const PackedFunc* check_support = tvm::runtime::Registry::Get(support_func_name)) {
has_native_support = (*check_support)(compute_version);
}
}
}
return has_native_support;
Expand Down

0 comments on commit 54fa150

Please sign in to comment.