Skip to content

Commit

Permalink
Add best effort validation for known dot algorithms (#2511)
Browse files Browse the repository at this point in the history
This leaks some ecosystem implementation details into StableHLO, but
that seems safer than allowing any arbitrary combination and requiring
the user to get some lucky combination correct.
  • Loading branch information
GleasonK authored Aug 27, 2024
1 parent 8602e09 commit 3bb40f7
Show file tree
Hide file tree
Showing 7 changed files with 409 additions and 26 deletions.
85 changes: 85 additions & 0 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ limitations under the License.
#include <cstdint>
#include <functional>
#include <optional>
#include <tuple>
#include <utility>

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
Expand All @@ -42,10 +45,14 @@ limitations under the License.
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/TypeID.h"

// Include order matters
#include "stablehlo/dialect/BaseAttrInterfaces.cpp.inc"

#define DEBUG_TYPE "stablehlo-base"

namespace mlir {
namespace hlo {

Expand Down Expand Up @@ -624,6 +631,84 @@ bool isSplatArray(ArrayRef<int64_t> arr, int64_t val) {
[val](int64_t x) { return x == val; });
}

namespace detail {
// Returns true when the three types are instances of `LHS`, `RHS` and `Accum`
// respectively and `numPrimitiveOperations` equals the compile-time constant
// `N`. The checks run in that order and stop at the first mismatch.
template <typename LHS, typename RHS, typename Accum, int64_t N>
bool match(Type lhsPrecisionType, Type rhsPrecisionType, Type accumulationType,
           int64_t numPrimitiveOperations) {
  if (!isa<LHS>(lhsPrecisionType)) return false;
  if (!isa<RHS>(rhsPrecisionType)) return false;
  if (!isa<Accum>(accumulationType)) return false;
  return numPrimitiveOperations == N;
}

// Maps a dot algorithm description onto one of the known combinations.
//
// Returns failure() when:
//   - either component count differs from 1 (only single-component
//     algorithms are recognized for now);
//   - imprecise accumulation is requested for anything other than the
//     any-F8 x any-F8 -> F32 combination with one primitive operation;
//   - the (lhs, rhs, accumulation, numPrimitiveOperations) combination is not
//     listed in the table below.
FailureOr<KnownDotAlgorithm> getKnownDotAlgorithm(
    Type lhsPrecisionType, Type rhsPrecisionType, Type accumulationType,
    int64_t lhsComponentCount, int64_t rhsComponentCount,
    int64_t numPrimitiveOperations, bool allowImpreciseAccumulation) {
  // Only support single component for now.
  if (lhsComponentCount != 1 || rhsComponentCount != 1) return failure();

  // The F8 case is handled separately because the LHS and RHS may each be any
  // of the F8 flavours, and it is the only case that accepts imprecise
  // accumulation.
  auto isAnyF8 = [](Type t) {
    return llvm::isa<Float8E4M3FNType, Float8E5M2Type, Float8E4M3FNUZType,
                     Float8E4M3B11FNUZType, Float8E5M2FNUZType>(t);
  };
  if (isAnyF8(lhsPrecisionType) && isAnyF8(rhsPrecisionType) &&
      accumulationType.isF32() && numPrimitiveOperations == 1) {
    if (allowImpreciseAccumulation)
      return KnownDotAlgorithm::ANY_F8_ANY_F8_F32_FAST_ACCUM;
    return KnownDotAlgorithm::ANY_F8_ANY_F8_F32;
  }
  if (allowImpreciseAccumulation) return failure();

  // One row per known non-F8 combination. Types are identified by their
  // registered name.
  struct KnownEntry {
    StringRef lhs;
    StringRef rhs;
    StringRef accum;
    int64_t numPrimitiveOperations;
    KnownDotAlgorithm algorithm;
  };

  const StringRef bf16 = BFloat16Type::name;
  const StringRef f16 = Float16Type::name;
  const StringRef f32 = Float32Type::name;
  const StringRef f64 = Float64Type::name;
  const StringRef tf32 = FloatTF32Type::name;

  // A plain array scanned linearly: the table is tiny, and this avoids
  // allocating an ordered map on every verification call.
  const KnownEntry knownDotAlgorithms[] = {
      {f16, f16, f16, 1, KnownDotAlgorithm::F16_F16_F16},
      {f16, f16, f32, 1, KnownDotAlgorithm::F16_F16_F32},
      {bf16, bf16, bf16, 1, KnownDotAlgorithm::BF16_BF16_BF16},
      {bf16, bf16, f32, 1, KnownDotAlgorithm::BF16_BF16_F32},
      {bf16, bf16, f32, 3, KnownDotAlgorithm::BF16_BF16_F32_X3},
      {bf16, bf16, f32, 6, KnownDotAlgorithm::BF16_BF16_F32_X6},
      {tf32, tf32, f32, 1, KnownDotAlgorithm::TF32_TF32_F32},
      {tf32, tf32, f32, 3, KnownDotAlgorithm::TF32_TF32_F32_X3},
      {f32, f32, f32, 1, KnownDotAlgorithm::F32_F32_F32},
      {f64, f64, f64, 1, KnownDotAlgorithm::F64_F64_F64},
  };

  const StringRef lhsName = lhsPrecisionType.getAbstractType().getName();
  const StringRef rhsName = rhsPrecisionType.getAbstractType().getName();
  const StringRef accumName = accumulationType.getAbstractType().getName();

  for (const KnownEntry& entry : knownDotAlgorithms) {
    if (entry.numPrimitiveOperations != numPrimitiveOperations ||
        entry.lhs != lhsName || entry.rhs != rhsName ||
        entry.accum != accumName)
      continue;
    LLVM_DEBUG(llvm::dbgs()
               << "Found known dot algorithm: "
               << static_cast<int64_t>(entry.algorithm) << " " << lhsName
               << ", " << rhsName << ", " << accumName << ", "
               << numPrimitiveOperations << "\n");
    return entry.algorithm;
  }
  return failure();
}
} // namespace detail

// Reports whether the given dot algorithm fields form a known combination,
// i.e. whether `detail::getKnownDotAlgorithm` succeeds for them.
bool isKnownDotAlgorithm(Type lhsPrecisionType, Type rhsPrecisionType,
                         Type accumulationType, int64_t lhsComponentCount,
                         int64_t rhsComponentCount,
                         int64_t numPrimitiveOperations,
                         bool allowImpreciseAccumulation) {
  auto knownAlgorithm = detail::getKnownDotAlgorithm(
      lhsPrecisionType, rhsPrecisionType, accumulationType, lhsComponentCount,
      rhsComponentCount, numPrimitiveOperations, allowImpreciseAccumulation);
  return succeeded(knownAlgorithm);
}

mlir::Speculation::Speculatability getShapedSpeculatability(
Operation* op, int64_t shapeCount) {
// If all inputs are static and the shape-related operands are constant
Expand Down
37 changes: 37 additions & 0 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,43 @@ class HloDialectInterface : public DialectInterface::Base<HloDialectInterface> {
virtual Attribute createTypeExtensions(ArrayRef<int64_t> bounds) const = 0;
};

namespace detail {

// An enum which tracks known supported dot algorithm pairs.
// Note this implementation is a detail for now and the APIs are likely to
// change once HLO broadens support for LHS/RHS components and num primitive
// operations.
//
// It is best to not rely on these values until the API solidifies.
// Instead use `isKnownDotAlgorithm`.
//
// Naming scheme, as used by `getKnownDotAlgorithm`:
//   <lhs precision>_<rhs precision>_<accumulation>[_X<n>][_FAST_ACCUM]
// where `_X<n>` marks entries matched with `numPrimitiveOperations == n`
// (entries without it are matched with 1), and `_FAST_ACCUM` marks the entry
// returned when `allowImpreciseAccumulation` is set.
enum class KnownDotAlgorithm {
// Any F8 type on each side, F32 accumulation.
ANY_F8_ANY_F8_F32 = 1,
// Same as above, with imprecise accumulation allowed.
ANY_F8_ANY_F8_F32_FAST_ACCUM = 2,
// F16 operands.
F16_F16_F16 = 3,
F16_F16_F32 = 4,
// BF16 operands.
BF16_BF16_BF16 = 5,
BF16_BF16_F32 = 6,
BF16_BF16_F32_X3 = 7,
BF16_BF16_F32_X6 = 8,
// TF32 operands.
TF32_TF32_F32 = 9,
TF32_TF32_F32_X3 = 10,
// Full-width F32 / F64.
F32_F32_F32 = 11,
F64_F64_F64 = 12,
};

// Returns the `KnownDotAlgorithm` matching the given dot algorithm fields, or
// failure if the combination is not a known one. Only component counts of 1
// are recognized, and `allowImpreciseAccumulation` is only accepted for the
// any-F8 x any-F8 -> F32 combination with one primitive operation.
FailureOr<KnownDotAlgorithm> getKnownDotAlgorithm(
Type lhsPrecisionType, Type rhsPrecisionType, Type accumulationType,
int64_t lhsComponentCount, int64_t rhsComponentCount,
int64_t numPrimitiveOperations, bool allowImpreciseAccumulation);
} // namespace detail

// Check if the combination of a dot algorithm struct is known.
// Returns true exactly when `detail::getKnownDotAlgorithm` succeeds for the
// same arguments. This is the preferred entry point, since the
// `detail::KnownDotAlgorithm` values are documented as likely to change.
bool isKnownDotAlgorithm(Type lhsPrecisionType, Type rhsPrecisionType,
Type accumulationType, int64_t lhsComponentCount,
int64_t rhsComponentCount,
int64_t numPrimitiveOperations,
bool allowImpreciseAccumulation);

namespace bytecode {
// Helper methods for bytecode
// Enum reader and writer. Many attrs have a single enum type to serialize.
Expand Down
35 changes: 17 additions & 18 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ limitations under the License.
#include <iterator>
#include <numeric>
#include <optional>
#include <set>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -4057,24 +4059,6 @@ LogicalResult verifyDotAlgorithmAttr(
Type lhsPrecisionType, Type rhsPrecisionType, Type accumulationType,
int64_t lhsComponentCount, int64_t rhsComponentCount,
int64_t numPrimitiveOperations, bool allowImpreciseAccumulation) {
auto isValidType = [](Type t) {
// Only support float types for now
// This can be extended as needed, as the RFC was for general support, but
// only FP hardware support exists in the ecosystem today.
return llvm::isa<FloatTF32Type, Float8E4M3FNType, Float8E5M2Type,
Float8E4M3FNUZType, Float8E4M3B11FNUZType,
Float8E5M2FNUZType, BFloat16Type, Float16Type, Float32Type,
Float64Type>(t);
};
// dot_general_i8
if (!isValidType(lhsPrecisionType))
return emitError() << "lhs precision type must be float";
// dot_general_i9
if (!isValidType(rhsPrecisionType))
return emitError() << "rhs precision type must be float";
// dot_general_i10
if (!isValidType(accumulationType))
return emitError() << "accumulation type must be float";
// dot_general_c22
if (lhsComponentCount < 1)
return emitError() << "lhs component count must be positive";
Expand All @@ -4084,6 +4068,21 @@ LogicalResult verifyDotAlgorithmAttr(
// dot_general_c24
if (numPrimitiveOperations < 1)
return emitError() << "num primitive operations must be positive";

// Best effort algorithm verification, support algorithm combinations
// known to be supported on some hardware, not necessarily the target hardware
// dot_general_i8, dot_general_i9, dot_general_i10
if (!isKnownDotAlgorithm(lhsPrecisionType, rhsPrecisionType, accumulationType,
lhsComponentCount, rhsComponentCount,
numPrimitiveOperations, allowImpreciseAccumulation))
return emitError()
<< "dot algorithm not known to be supported on any hardware: "
<< "{lhs:" << lhsPrecisionType << ", rhs:" << rhsPrecisionType
<< ", accum:" << accumulationType
<< ", lhs_components:" << lhsComponentCount
<< ", rhs_components:" << rhsComponentCount
<< ", primitive_ops:" << numPrimitiveOperations
<< ", imprecise:" << allowImpreciseAccumulation << "}";
return success();
}

Expand Down
2 changes: 1 addition & 1 deletion stablehlo/tests/interpret/dot_general.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func.func @dot_general_op_test_algorithm() {
algorithm = <
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
Expand Down
Loading

0 comments on commit 3bb40f7

Please sign in to comment.