Skip to content

Commit

Permalink
Add MX floating point types (f4E2M1FN, f6E2M3FN, f6E3M2FN, f8E8M0FNU)
Browse files Browse the repository at this point in the history
  • Loading branch information
sergey-kozub committed Oct 7, 2024
1 parent 8c7d87b commit 8e9c008
Show file tree
Hide file tree
Showing 18 changed files with 3,251 additions and 116 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "00128a20eec27246719d73ba427bf821883b00b4"
LLVM_COMMIT = "3f9cabae0029bcbe88835aaa4c417ce41e584fb1"

LLVM_SHA256 = "9fff2ccb6c262f3d5e2f98c281a0b99a585daee83742e1599709ff61cfc222af"
LLVM_SHA256 = "626e1c8e491cd70ef540c403fdd43b9ff9ee50aafe181c32b14e67f715ad015e"

http_archive(
name = "llvm-raw",
Expand Down
8 changes: 6 additions & 2 deletions docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,9 @@ BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E3M4' | 'f8E4M3' | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ'
| 'f8E5M2' | 'f8E5M2FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Expand Down Expand Up @@ -284,6 +285,9 @@ values of type `tensor<T>`).
[the IEEE 754 standard](https://ieeexplore.ieee.org/document/8766229).
* `tf32` type corresponds to the [TensorFloat32 format](https://blogs.nvidia.com/blog/tensorfloat-32-precision-format/)
and has limited support in StableHLO.
* `f4E2M1FN`, `f6E2M3FN`, `f6E3M2FN` and `f8E8M0FNU` MX (microscaling) types
described in
[OCP Microscaling Formats Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
* **Complex types** represent complex values that have a **real part**
and an **imaginary part** of the same **element type**. Supported complex
types are `complex<f32>` (both parts are of type `f32`) and `complex<f64>`
Expand Down
5 changes: 3 additions & 2 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ def HLO_SInt : SignlessIntOfWidths<[2, 4, 8, 16, 32, 64]>;
def HLO_UInt : UnsignedIntOfWidths<[2, 4, 8, 16, 32, 64]>;
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;

def HLO_Float : AnyTypeOf<[F8E3M4, F8E4M3, F8E4M3FN, F8E4M3FNUZ, F8E4M3B11FNUZ,
F8E5M2, F8E5M2FNUZ, F16, F32, F64, BF16]>;
def HLO_Float : AnyTypeOf<[F4E2M1FN, F6E2M3FN, F6E3M2FN, F8E3M4, F8E4M3,
F8E4M3FN, F8E4M3FNUZ, F8E4M3B11FNUZ, F8E5M2,
F8E5M2FNUZ, F8E8M0FNU, F16, F32, F64, BF16]>;
def HLO_Float32Or64 : AnyTypeOf<[F32, F64]>;

def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/Version.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Version {
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current VHLO dialect version.
static Version getCurrentVersion() { return Version(1, 7, 8); }
static Version getCurrentVersion() { return Version(1, 8, 0); }

/// Return a Version representing the minimum supported VHLO dialect version.
static Version getMinimumVersion() { return Version(0, 9, 0); }
Expand Down
50 changes: 49 additions & 1 deletion stablehlo/dialect/VhloBytecode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ enum AttributeCode {
/// location is updated.
enum TypeCode {
// TO ADD TYPE: Add an enum value with doc string for new type.
// Next available code: 37
// Next available code: 40

/// BooleanV1Type {
/// }
Expand Down Expand Up @@ -336,6 +336,18 @@ enum TypeCode {
/// }
kWitnessV1Type = 26,

/// FloatF4E2M1FNV1Type {
/// }
kFloatF4E2M1FNV1Type = 37,

/// FloatF6E2M3FNV1Type {
/// }
kFloatF6E2M3FNV1Type = 38,

/// FloatF6E3M2FNV1Type {
/// }
kFloatF6E3M2FNV1Type = 39,

/// FloatF8E4M3FNUZV1Type {
/// }
kFloatF8E4M3FNUZV1Type = 27,
Expand All @@ -348,6 +360,10 @@ enum TypeCode {
/// }
kFloatF8E4M3B11FNUZV1Type = 29,

/// FloatF8E8M0FNUV1Type {
/// }
kFloatF8E8M0FNUV1Type = 40,

/// UniformQuantizedPerAxisV1Type {
/// flags: varint
/// storageType: Type
Expand Down Expand Up @@ -705,14 +721,18 @@ const llvm::fltSemantics &getFloatSemantics(Type type) {
if (isa<FloatBF16V1Type>(type)) return APFloat::BFloat();
if (isa<FloatF16V1Type>(type)) return APFloat::IEEEhalf();
if (isa<FloatF32V1Type>(type)) return APFloat::IEEEsingle();
if (isa<FloatF4E2M1FNV1Type>(type)) return APFloat::Float4E2M1FN();
if (isa<FloatF64V1Type>(type)) return APFloat::IEEEdouble();
if (isa<FloatF6E2M3FNV1Type>(type)) return APFloat::Float6E2M3FN();
if (isa<FloatF6E3M2FNV1Type>(type)) return APFloat::Float6E3M2FN();
if (isa<FloatF8E3M4V1Type>(type)) return APFloat::Float8E3M4();
if (isa<FloatF8E4M3FNUZV1Type>(type)) return APFloat::Float8E4M3FNUZ();
if (isa<FloatF8E4M3B11FNUZV1Type>(type)) return APFloat::Float8E4M3B11FNUZ();
if (isa<FloatF8E4M3FNV1Type>(type)) return APFloat::Float8E4M3FN();
if (isa<FloatF8E4M3V1Type>(type)) return APFloat::Float8E4M3();
if (isa<FloatF8E5M2FNUZV1Type>(type)) return APFloat::Float8E5M2FNUZ();
if (isa<FloatF8E5M2V1Type>(type)) return APFloat::Float8E5M2();
if (isa<FloatF8E8M0FNUV1Type>(type)) return APFloat::Float8E8M0FNU();
if (isa<FloatTF32V1Type>(type)) return APFloat::FloatTF32();
llvm::report_fatal_error("unsupported floating-point type");
}
Expand Down Expand Up @@ -974,8 +994,14 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
return FloatF16V1Type::get(getContext());
case vhlo_encoding::kFloatF32V1Type:
return FloatF32V1Type::get(getContext());
case vhlo_encoding::kFloatF4E2M1FNV1Type:
return FloatF4E2M1FNV1Type::get(getContext());
case vhlo_encoding::kFloatF64V1Type:
return FloatF64V1Type::get(getContext());
case vhlo_encoding::kFloatF6E2M3FNV1Type:
return FloatF6E2M3FNV1Type::get(getContext());
case vhlo_encoding::kFloatF6E3M2FNV1Type:
return FloatF6E3M2FNV1Type::get(getContext());
case vhlo_encoding::kFloatF8E5M2V1Type:
return FloatF8E5M2V1Type::get(getContext());
case vhlo_encoding::kFloatF8E4M3V1Type:
Expand All @@ -990,6 +1016,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
return FloatF8E4M3B11FNUZV1Type::get(getContext());
case vhlo_encoding::kFloatF8E3M4V1Type:
return FloatF8E3M4V1Type::get(getContext());
case vhlo_encoding::kFloatF8E8M0FNUV1Type:
return FloatF8E8M0FNUV1Type::get(getContext());
case vhlo_encoding::kFloatTF32V1Type:
return FloatTF32V1Type::get(getContext());
case vhlo_encoding::kFunctionV1Type:
Expand Down Expand Up @@ -1070,10 +1098,25 @@ LogicalResult VhloBytecodeInterface::writeType(
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF32V1Type), success();
})
.Case([&](FloatF4E2M1FNV1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF4E2M1FNV1Type),
success();
})
.Case([&](FloatF64V1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF64V1Type), success();
})
.Case([&](FloatF6E2M3FNV1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF6E2M3FNV1Type),
success();
})
.Case([&](FloatF6E3M2FNV1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF6E3M2FNV1Type),
success();
})
.Case([&](FloatF8E3M4V1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF8E3M4V1Type), success();
Expand Down Expand Up @@ -1106,6 +1149,11 @@ LogicalResult VhloBytecodeInterface::writeType(
return writer.writeVarInt(vhlo_encoding::kFloatF8E5M2FNUZV1Type),
success();
})
.Case([&](FloatF8E8M0FNUV1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF8E8M0FNUV1Type),
success();
})
.Case([&](FloatTF32V1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatTF32V1Type), success();
Expand Down
1 change: 1 addition & 0 deletions stablehlo/dialect/VhloDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def VHLO_Dialect : Dialect {
1.5.0: Make collective ops (`all_reduce`, `all_gather`, `all_to_all`) variadic.
1.6.0: Add DotAlgorithm specificaiton to `dot_general`.
1.7.0: Introduce `f8E4M3` and `f8E3M4` types.
1.8.0: Introduce `f4E2M1FN`, `f6E2M3FN`, `f6E3M2FN` and `f8E8M0FNU` types.
}];

let useDefaultAttributePrinterParser = 0;
Expand Down
24 changes: 24 additions & 0 deletions stablehlo/dialect/VhloTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ void VhloTypeConverter::addBuiltinToVhloConversions() {
[&](Float32Type type) { return FloatF32V1Type::get(type.getContext()); });
addConversion(
[&](Float64Type type) { return FloatF64V1Type::get(type.getContext()); });
addConversion([&](Float4E2M1FNType type) {
return FloatF4E2M1FNV1Type::get(type.getContext());
});
addConversion([&](Float6E2M3FNType type) {
return FloatF6E2M3FNV1Type::get(type.getContext());
});
addConversion([&](Float6E3M2FNType type) {
return FloatF6E3M2FNV1Type::get(type.getContext());
});
addConversion([&](Float8E3M4Type type) {
return FloatF8E3M4V1Type::get(type.getContext());
});
Expand All @@ -105,6 +114,9 @@ void VhloTypeConverter::addBuiltinToVhloConversions() {
addConversion([&](Float8E5M2FNUZType type) {
return FloatF8E5M2FNUZV1Type::get(type.getContext());
});
addConversion([&](Float8E8M0FNUType type) {
return FloatF8E8M0FNUV1Type::get(type.getContext());
});
addConversion([&](FloatTF32Type type) {
return FloatTF32V1Type::get(type.getContext());
});
Expand Down Expand Up @@ -182,6 +194,15 @@ void VhloTypeConverter::addVhloToBuiltinConversions() {
[&](FloatF32V1Type type) { return Float32Type::get(type.getContext()); });
addConversion(
[&](FloatF64V1Type type) { return Float64Type::get(type.getContext()); });
addConversion([&](FloatF4E2M1FNV1Type type) {
return Float4E2M1FNType::get(type.getContext());
});
addConversion([&](FloatF6E2M3FNV1Type type) {
return Float6E2M3FNType::get(type.getContext());
});
addConversion([&](FloatF6E3M2FNV1Type type) {
return Float6E3M2FNType::get(type.getContext());
});
addConversion([&](FloatF8E3M4V1Type type) {
return Float8E3M4Type::get(type.getContext());
});
Expand All @@ -203,6 +224,9 @@ void VhloTypeConverter::addVhloToBuiltinConversions() {
addConversion([&](FloatF8E5M2FNUZV1Type type) {
return Float8E5M2FNUZType::get(type.getContext());
});
addConversion([&](FloatF8E8M0FNUV1Type type) {
return Float8E8M0FNUType::get(type.getContext());
});
addConversion([&](FloatTF32V1Type type) {
return FloatTF32Type::get(type.getContext());
});
Expand Down
12 changes: 12 additions & 0 deletions stablehlo/dialect/VhloTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ def VHLO_FloatF32V1 : VHLO_TypeDef<"FloatF32V1", "f32_v1", "0.9.0", "current">;
// Corresponds to the 'f64' FloatType from the StableHLO spec.
def VHLO_FloatF64V1 : VHLO_TypeDef<"FloatF64V1","f64_v1", "0.9.0", "current">;

// Corresponds to the 'f4E2M1FN' FloatType from the StableHLO spec.
def VHLO_FloatF4E2M1FNV1 : VHLO_TypeDef<"FloatF4E2M1FNV1", "f4E2M1FN_v1", "1.8.0", "current">;

// Corresponds to the 'f6E2M3FN' FloatType from the StableHLO spec.
def VHLO_FloatF6E2M3FNV1 : VHLO_TypeDef<"FloatF6E2M3FNV1", "f6E2M3FN_v1", "1.8.0", "current">;

// Corresponds to the 'f6E3M2FN' FloatType from the StableHLO spec.
def VHLO_FloatF6E3M2FNV1 : VHLO_TypeDef<"FloatF6E3M2FNV1", "f6E3M2FN_v1", "1.8.0", "current">;

// Corresponds to the 'f8E3M4' FloatType from the StableHLO spec.
def VHLO_FloatF8E3M4V1 : VHLO_TypeDef<"FloatF8E3M4V1", "f8E3M4_v1", "1.7.0", "current">;

Expand All @@ -100,6 +109,9 @@ def VHLO_FloatF8E4M3B11FNUZV1 : VHLO_TypeDef<"FloatF8E4M3B11FNUZV1", "f8E4M3B11F
// Corresponds to the 'f8E5M2FNUZ' FloatType from the StableHLO spec.
def VHLO_FloatF8E5M2FNUZV1 : VHLO_TypeDef<"FloatF8E5M2FNUZV1", "f8E5M2FNUZ_v1", "0.10.0", "current">;

// Corresponds to the 'f8E8M0FNU' FloatType from the StableHLO spec.
def VHLO_FloatF8E8M0FNUV1 : VHLO_TypeDef<"FloatF8E8M0FNUV1", "f8E8M0FNU_v1", "1.8.0", "current">;

// Corresponds to the 'tf32' FloatType from the StableHLO spec.
def VHLO_FloatTF32V1 : VHLO_TypeDef<"FloatTF32V1", "tf31_v1", "1.6.0", "current">;

Expand Down
58 changes: 15 additions & 43 deletions stablehlo/reference/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,40 +118,13 @@ Element Tensor::get(const Index &index) const {
getSizeInBytes(elementType) * flattenIndex(getShape(), index);

// Handle floating-point types.
if (elementType.isFloat8E3M4()) {
if (isSupportedFloatType(elementType) &&
cast<FloatType>(elementType).getWidth() <= 8) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E3M4(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E4M3B11FNUZ()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3B11FNUZ(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E4M3()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E4M3FN()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3FN(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E4M3FNUZ()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3FNUZ(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E5M2()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E5M2(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E5M2FNUZ()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E5M2FNUZ(),
APInt(8, *elementData)));
auto floatTy = cast<FloatType>(elementType);
return Element(elementType,
APFloat(floatTy.getFloatSemantics(),
APInt(floatTy.getWidth(), *elementData)));
}
if (elementType.isF16()) {
auto elementData = reinterpret_cast<const uint16_t *>(elementPtr);
Expand Down Expand Up @@ -262,10 +235,8 @@ void Tensor::set(const Index &index, const Element &element) {
getSizeInBytes(elementType) * flattenIndex(getShape(), index);

// Handle floating-point types.
if (elementType.isFloat8E3M4() || elementType.isFloat8E4M3B11FNUZ() ||
elementType.isFloat8E4M3() || elementType.isFloat8E4M3FN() ||
elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2() ||
elementType.isFloat8E5M2FNUZ()) {
if (isSupportedFloatType(elementType) &&
cast<FloatType>(elementType).getWidth() <= 8) {
auto elementData = reinterpret_cast<uint8_t *>(elementPtr);
auto value = element.getFloatValue();
*elementData = (uint8_t)value.bitcastToAPInt().getZExtValue();
Expand Down Expand Up @@ -457,18 +428,19 @@ Tensor makeTensor(DenseElementsAttr attr) {
auto elementType = type.getElementType();

// Handle floating-point types.
if (elementType.isFloat8E3M4() || elementType.isFloat8E4M3B11FNUZ() ||
elementType.isFloat8E4M3() || elementType.isFloat8E4M3FN() ||
elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2() ||
elementType.isFloat8E5M2FNUZ()) {
if (isSupportedFloatType(elementType) &&
cast<FloatType>(elementType).getWidth() <= 8) {
auto floatValues = llvm::map_to_vector(
attr.getValues<APFloat>(), [&](APFloat value) -> uint8_t {
return value.bitcastToAPInt().getZExtValue();
});

// For f8E3M4, f8E4M3, f8E4M3FN, f8E4M3FNUZ, f8E4M3B11FNUZ, f8E5M2, and
// f8E5M2FNUZ floating-point types, we use uint8_t as their storage type
// because there are no builtin types for those.
// f8E5M2FNUZ, f8E8M0FNU floating-point types, we use uint8_t as their
// storage type because there are no builtin types for those.
// For f4E2M1FN, f6E2M3FN, and f6E3M2FN floating-point types, we still use
// uint8_t, even though the underlying types require less bits (similar
// to how ui2/ui4 types are handled).
return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign<uint8_t>(
floatValues));
}
Expand Down
10 changes: 6 additions & 4 deletions stablehlo/reference/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ bool isSupportedIntegerType(Type type) {
}

bool isSupportedFloatType(Type type) {
return type.isFloat8E3M4() || type.isFloat8E4M3B11FNUZ() ||
type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
type.isFloat8E5M2FNUZ() || type.isF16() || type.isBF16() ||
return type.isFloat4E2M1FN() || type.isFloat6E2M3FN() ||
type.isFloat6E3M2FN() || type.isFloat8E3M4() ||
type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3() ||
type.isFloat8E4M3FN() || type.isFloat8E4M3FNUZ() ||
type.isFloat8E5M2() || type.isFloat8E5M2FNUZ() ||
type.isFloat8E8M0FNU() || type.isF16() || type.isBF16() ||
type.isF32() || type.isF64();
}

Expand Down
Loading

0 comments on commit 8e9c008

Please sign in to comment.