From 8e9c008c3f696a790db2174cf16b8a2599d8ec6f Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Fri, 4 Oct 2024 12:05:11 +0200 Subject: [PATCH] Add MX floating point types (f4E2M1FN, f6E2M3FN, f6E3M2FN, f8E8M0FNU) --- WORKSPACE.bazel | 4 +- docs/spec.md | 8 +- stablehlo/dialect/Base.td | 5 +- stablehlo/dialect/Version.h | 2 +- stablehlo/dialect/VhloBytecode.cpp | 50 +- stablehlo/dialect/VhloDialect.td | 1 + stablehlo/dialect/VhloTypes.cpp | 24 + stablehlo/dialect/VhloTypes.td | 12 + stablehlo/reference/Tensor.cpp | 58 +- stablehlo/reference/Types.cpp | 10 +- stablehlo/tests/interpret/constant.mlir | 32 + stablehlo/tests/ops_stablehlo.mlir | 48 +- stablehlo/tests/ops_stablehlo_quantized.mlir | 106 +- stablehlo/tests/ops_stablehlo_roundtrip.mlir | 4 + .../stablehlo_legalize_to_vhlo.1_8_0.mlir | 2936 +++++++++++++++++ .../stablehlo_legalize_to_vhlo.1_8_0.mlir.bc | Bin 0 -> 19438 bytes .../vhlo/stablehlo_legalize_to_vhlo.mlir | 32 + ...lo_to_version_downgrade_invalid.1_7_0.mlir | 35 + 18 files changed, 3251 insertions(+), 116 deletions(-) create mode 100644 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_8_0.mlir create mode 100644 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_8_0.mlir.bc create mode 100644 stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_7_0.mlir diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index 29ef986ad0f..751a92ff7ce 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -17,9 +17,9 @@ workspace(name = "stablehlo") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "00128a20eec27246719d73ba427bf821883b00b4" +LLVM_COMMIT = "3f9cabae0029bcbe88835aaa4c417ce41e584fb1" -LLVM_SHA256 = "9fff2ccb6c262f3d5e2f98c281a0b99a585daee83742e1599709ff61cfc222af" +LLVM_SHA256 = "626e1c8e491cd70ef540c403fdd43b9ff9ee50aafe181c32b14e67f715ad015e" http_archive( name = "llvm-raw", diff --git a/docs/spec.md b/docs/spec.md index e6f94f3e9a8..48760dbd962 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -245,8 +245,9 @@ BooleanType ::= 'i1' IntegerType ::= SignedIntegerType | UnsignedIntegerType SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64' UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64' -FloatType ::= 'f8E3M4' | 'f8E4M3' | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' - | 'f8E5M2' | 'f8E5M2FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64' +FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3' + | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2' + | 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64' TensorFloat32 ::= 'tf32' ComplexType ::= 'complex' '<' ComplexElementType '>' ComplexElementType ::= 'f32' | 'f64' @@ -284,6 +285,9 @@ values of type `tensor`). [the IEEE 754 standard](https://ieeexplore.ieee.org/document/8766229). * `tf32` type corresponds to the [TensorFloat32 format](https://blogs.nvidia.com/blog/tensorfloat-32-precision-format/) and has limited support in StableHLO. + * `f4E2M1FN`, `f6E2M3FN`, `f6E3M2FN` and `f8E8M0FNU` MX (microscaling) types + described in + [OCP Microscaling Formats Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). * **Complex types** represent complex values that have a **real part** and an **imaginary part** of the same **element type**. Supported complex types are `complex` (both parts are of type `f32`) and `complex` diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td index 296b118c4ad..b995fcda31a 100644 --- a/stablehlo/dialect/Base.td +++ b/stablehlo/dialect/Base.td @@ -42,8 +42,9 @@ def HLO_SInt : SignlessIntOfWidths<[2, 4, 8, 16, 32, 64]>; def HLO_UInt : UnsignedIntOfWidths<[2, 4, 8, 16, 32, 64]>; def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>; -def HLO_Float : AnyTypeOf<[F8E3M4, F8E4M3, F8E4M3FN, F8E4M3FNUZ, F8E4M3B11FNUZ, - F8E5M2, F8E5M2FNUZ, F16, F32, F64, BF16]>; +def HLO_Float : AnyTypeOf<[F4E2M1FN, F6E2M3FN, F6E3M2FN, F8E3M4, F8E4M3, + F8E4M3FN, F8E4M3FNUZ, F8E4M3B11FNUZ, F8E5M2, + F8E5M2FNUZ, F8E8M0FNU, F16, F32, F64, BF16]>; def HLO_Float32Or64 : AnyTypeOf<[F32, F64]>; def HLO_Complex : Complex>; diff --git a/stablehlo/dialect/Version.h b/stablehlo/dialect/Version.h index 4cab6b9fed6..ca44d158a8a 100644 --- a/stablehlo/dialect/Version.h +++ b/stablehlo/dialect/Version.h @@ -38,7 +38,7 @@ class Version { static FailureOr fromString(llvm::StringRef versionRef); /// Return a Version representing the current VHLO dialect version. - static Version getCurrentVersion() { return Version(1, 7, 8); } + static Version getCurrentVersion() { return Version(1, 8, 0); } /// Return a Version representing the minimum supported VHLO dialect version. static Version getMinimumVersion() { return Version(0, 9, 0); } diff --git a/stablehlo/dialect/VhloBytecode.cpp b/stablehlo/dialect/VhloBytecode.cpp index d49615f5514..601463753db 100644 --- a/stablehlo/dialect/VhloBytecode.cpp +++ b/stablehlo/dialect/VhloBytecode.cpp @@ -189,7 +189,7 @@ enum AttributeCode { /// location is updated. enum TypeCode { // TO ADD TYPE: Add an enum value with doc string for new type. - // Next available code: 37 + // Next available code: 40 /// BooleanV1Type { /// } @@ -336,6 +336,18 @@ enum TypeCode { /// } kWitnessV1Type = 26, + /// FloatF4E2M1FNV1Type { + /// } + kFloatF4E2M1FNV1Type = 37, + + /// FloatF6E2M3FNV1Type { + /// } + kFloatF6E2M3FNV1Type = 38, + + /// FloatF6E3M2FNV1Type { + /// } + kFloatF6E3M2FNV1Type = 39, + /// FloatF8E4M3FNUZV1Type { /// } kFloatF8E4M3FNUZV1Type = 27, @@ -348,6 +360,10 @@ enum TypeCode { /// } kFloatF8E4M3B11FNUZV1Type = 29, + /// FloatF8E8M0FNUV1Type { + /// } + kFloatF8E8M0FNUV1Type = 40, + /// UniformQuantizedPerAxisV1Type { /// flags: varint /// storageType: Type @@ -705,7 +721,10 @@ const llvm::fltSemantics &getFloatSemantics(Type type) { if (isa(type)) return APFloat::BFloat(); if (isa(type)) return APFloat::IEEEhalf(); if (isa(type)) return APFloat::IEEEsingle(); + if (isa(type)) return APFloat::Float4E2M1FN(); if (isa(type)) return APFloat::IEEEdouble(); + if (isa(type)) return APFloat::Float6E2M3FN(); + if (isa(type)) return APFloat::Float6E3M2FN(); if (isa(type)) return APFloat::Float8E3M4(); if (isa(type)) return APFloat::Float8E4M3FNUZ(); if (isa(type)) return APFloat::Float8E4M3B11FNUZ(); @@ -713,6 +732,7 @@ const llvm::fltSemantics &getFloatSemantics(Type type) { if (isa(type)) return APFloat::Float8E4M3(); if (isa(type)) return APFloat::Float8E5M2FNUZ(); if (isa(type)) return APFloat::Float8E5M2(); + if (isa(type)) return APFloat::Float8E8M0FNU(); if (isa(type)) return APFloat::FloatTF32(); llvm::report_fatal_error("unsupported floating-point type"); } @@ -974,8 +994,14 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const { return FloatF16V1Type::get(getContext()); case vhlo_encoding::kFloatF32V1Type: return FloatF32V1Type::get(getContext()); + case vhlo_encoding::kFloatF4E2M1FNV1Type: + return FloatF4E2M1FNV1Type::get(getContext()); case vhlo_encoding::kFloatF64V1Type: return FloatF64V1Type::get(getContext()); + case vhlo_encoding::kFloatF6E2M3FNV1Type: + return FloatF6E2M3FNV1Type::get(getContext()); + case vhlo_encoding::kFloatF6E3M2FNV1Type: + return FloatF6E3M2FNV1Type::get(getContext()); case vhlo_encoding::kFloatF8E5M2V1Type: return FloatF8E5M2V1Type::get(getContext()); case vhlo_encoding::kFloatF8E4M3V1Type: @@ -990,6 +1016,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const { return FloatF8E4M3B11FNUZV1Type::get(getContext()); case vhlo_encoding::kFloatF8E3M4V1Type: return FloatF8E3M4V1Type::get(getContext()); + case vhlo_encoding::kFloatF8E8M0FNUV1Type: + return FloatF8E8M0FNUV1Type::get(getContext()); case vhlo_encoding::kFloatTF32V1Type: return FloatTF32V1Type::get(getContext()); case vhlo_encoding::kFunctionV1Type: @@ -1070,10 +1098,25 @@ LogicalResult VhloBytecodeInterface::writeType( LOG_WRITE_CALL; return writer.writeVarInt(vhlo_encoding::kFloatF32V1Type), success(); }) + .Case([&](FloatF4E2M1FNV1Type) { + LOG_WRITE_CALL; + return writer.writeVarInt(vhlo_encoding::kFloatF4E2M1FNV1Type), + success(); + }) .Case([&](FloatF64V1Type) { LOG_WRITE_CALL; return writer.writeVarInt(vhlo_encoding::kFloatF64V1Type), success(); }) + .Case([&](FloatF6E2M3FNV1Type) { + LOG_WRITE_CALL; + return writer.writeVarInt(vhlo_encoding::kFloatF6E2M3FNV1Type), + success(); + }) + .Case([&](FloatF6E3M2FNV1Type) { + LOG_WRITE_CALL; + return writer.writeVarInt(vhlo_encoding::kFloatF6E3M2FNV1Type), + success(); + }) .Case([&](FloatF8E3M4V1Type) { LOG_WRITE_CALL; return writer.writeVarInt(vhlo_encoding::kFloatF8E3M4V1Type), success(); @@ -1106,6 +1149,11 @@ LogicalResult VhloBytecodeInterface::writeType( return writer.writeVarInt(vhlo_encoding::kFloatF8E5M2FNUZV1Type), success(); }) + .Case([&](FloatF8E8M0FNUV1Type) { + LOG_WRITE_CALL; + return writer.writeVarInt(vhlo_encoding::kFloatF8E8M0FNUV1Type), + success(); + }) .Case([&](FloatTF32V1Type) { LOG_WRITE_CALL; return writer.writeVarInt(vhlo_encoding::kFloatTF32V1Type), success(); diff --git a/stablehlo/dialect/VhloDialect.td b/stablehlo/dialect/VhloDialect.td index ab295e11b19..c7a05727908 100644 --- a/stablehlo/dialect/VhloDialect.td +++ b/stablehlo/dialect/VhloDialect.td @@ -46,6 +46,7 @@ def VHLO_Dialect : Dialect { 1.5.0: Make collective ops (`all_reduce`, `all_gather`, `all_to_all`) variadic. 1.6.0: Add DotAlgorithm specificaiton to `dot_general`. 1.7.0: Introduce `f8E4M3` and `f8E3M4` types. + 1.8.0: Introduce `f4E2M1FN`, `f6E2M3FN`, `f6E3M2FN` and `f8E8M0FNU` types. }]; let useDefaultAttributePrinterParser = 0; diff --git a/stablehlo/dialect/VhloTypes.cpp b/stablehlo/dialect/VhloTypes.cpp index 251c163cf64..b2926528c1c 100644 --- a/stablehlo/dialect/VhloTypes.cpp +++ b/stablehlo/dialect/VhloTypes.cpp @@ -84,6 +84,15 @@ void VhloTypeConverter::addBuiltinToVhloConversions() { [&](Float32Type type) { return FloatF32V1Type::get(type.getContext()); }); addConversion( [&](Float64Type type) { return FloatF64V1Type::get(type.getContext()); }); + addConversion([&](Float4E2M1FNType type) { + return FloatF4E2M1FNV1Type::get(type.getContext()); + }); + addConversion([&](Float6E2M3FNType type) { + return FloatF6E2M3FNV1Type::get(type.getContext()); + }); + addConversion([&](Float6E3M2FNType type) { + return FloatF6E3M2FNV1Type::get(type.getContext()); + }); addConversion([&](Float8E3M4Type type) { return FloatF8E3M4V1Type::get(type.getContext()); }); @@ -105,6 +114,9 @@ void VhloTypeConverter::addBuiltinToVhloConversions() { addConversion([&](Float8E5M2FNUZType type) { return FloatF8E5M2FNUZV1Type::get(type.getContext()); }); + addConversion([&](Float8E8M0FNUType type) { + return FloatF8E8M0FNUV1Type::get(type.getContext()); + }); addConversion([&](FloatTF32Type type) { return FloatTF32V1Type::get(type.getContext()); }); @@ -182,6 +194,15 @@ void VhloTypeConverter::addVhloToBuiltinConversions() { [&](FloatF32V1Type type) { return Float32Type::get(type.getContext()); }); addConversion( [&](FloatF64V1Type type) { return Float64Type::get(type.getContext()); }); + addConversion([&](FloatF4E2M1FNV1Type type) { + return Float4E2M1FNType::get(type.getContext()); + }); + addConversion([&](FloatF6E2M3FNV1Type type) { + return Float6E2M3FNType::get(type.getContext()); + }); + addConversion([&](FloatF6E3M2FNV1Type type) { + return Float6E3M2FNType::get(type.getContext()); + }); addConversion([&](FloatF8E3M4V1Type type) { return Float8E3M4Type::get(type.getContext()); }); @@ -203,6 +224,9 @@ void VhloTypeConverter::addVhloToBuiltinConversions() { addConversion([&](FloatF8E5M2FNUZV1Type type) { return Float8E5M2FNUZType::get(type.getContext()); }); + addConversion([&](FloatF8E8M0FNUV1Type type) { + return Float8E8M0FNUType::get(type.getContext()); + }); addConversion([&](FloatTF32V1Type type) { return FloatTF32Type::get(type.getContext()); }); diff --git a/stablehlo/dialect/VhloTypes.td b/stablehlo/dialect/VhloTypes.td index b25c86f7d92..8414f0a78f8 100644 --- a/stablehlo/dialect/VhloTypes.td +++ b/stablehlo/dialect/VhloTypes.td @@ -79,6 +79,15 @@ def VHLO_FloatF32V1 : VHLO_TypeDef<"FloatF32V1", "f32_v1", "0.9.0", "current">; // Corresponds to the 'f64' FloatType from the StableHLO spec. def VHLO_FloatF64V1 : VHLO_TypeDef<"FloatF64V1","f64_v1", "0.9.0", "current">; +// Corresponds to the 'f4E2M1FN' FloatType from the StableHLO spec. +def VHLO_FloatF4E2M1FNV1 : VHLO_TypeDef<"FloatF4E2M1FNV1", "f4E2M1FN_v1", "1.8.0", "current">; + +// Corresponds to the 'f6E2M3FN' FloatType from the StableHLO spec. +def VHLO_FloatF6E2M3FNV1 : VHLO_TypeDef<"FloatF6E2M3FNV1", "f6E2M3FN_v1", "1.8.0", "current">; + +// Corresponds to the 'f6E3M2FN' FloatType from the StableHLO spec. +def VHLO_FloatF6E3M2FNV1 : VHLO_TypeDef<"FloatF6E3M2FNV1", "f6E3M2FN_v1", "1.8.0", "current">; + // Corresponds to the 'f8E3M4' FloatType from the StableHLO spec. def VHLO_FloatF8E3M4V1 : VHLO_TypeDef<"FloatF8E3M4V1", "f8E3M4_v1", "1.7.0", "current">; @@ -100,6 +109,9 @@ def VHLO_FloatF8E4M3B11FNUZV1 : VHLO_TypeDef<"FloatF8E4M3B11FNUZV1", "f8E4M3B11F // Corresponds to the 'f8E5M2FNUZ' FloatType from the StableHLO spec. def VHLO_FloatF8E5M2FNUZV1 : VHLO_TypeDef<"FloatF8E5M2FNUZV1", "f8E5M2FNUZ_v1", "0.10.0", "current">; +// Corresponds to the 'f8E8M0FNU' FloatType from the StableHLO spec. +def VHLO_FloatF8E8M0FNUV1 : VHLO_TypeDef<"FloatF8E8M0FNUV1", "f8E8M0FNU_v1", "1.8.0", "current">; + // Corresponds to the 'tf32' FloatType from the StableHLO spec. def VHLO_FloatTF32V1 : VHLO_TypeDef<"FloatTF32V1", "tf31_v1", "1.6.0", "current">; diff --git a/stablehlo/reference/Tensor.cpp b/stablehlo/reference/Tensor.cpp index 97ac92e3a36..600caa6d8b8 100644 --- a/stablehlo/reference/Tensor.cpp +++ b/stablehlo/reference/Tensor.cpp @@ -118,40 +118,13 @@ Element Tensor::get(const Index &index) const { getSizeInBytes(elementType) * flattenIndex(getShape(), index); // Handle floating-point types. - if (elementType.isFloat8E3M4()) { + if (isSupportedFloatType(elementType) && + cast(elementType).getWidth() <= 8) { auto elementData = reinterpret_cast(elementPtr); - return Element(elementType, APFloat(llvm::APFloatBase::Float8E3M4(), - APInt(8, *elementData))); - } - if (elementType.isFloat8E4M3B11FNUZ()) { - auto elementData = reinterpret_cast(elementPtr); - return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3B11FNUZ(), - APInt(8, *elementData))); - } - if (elementType.isFloat8E4M3()) { - auto elementData = reinterpret_cast(elementPtr); - return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3(), - APInt(8, *elementData))); - } - if (elementType.isFloat8E4M3FN()) { - auto elementData = reinterpret_cast(elementPtr); - return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3FN(), - APInt(8, *elementData))); - } - if (elementType.isFloat8E4M3FNUZ()) { - auto elementData = reinterpret_cast(elementPtr); - return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3FNUZ(), - APInt(8, *elementData))); - } - if (elementType.isFloat8E5M2()) { - auto elementData = reinterpret_cast(elementPtr); - return Element(elementType, APFloat(llvm::APFloatBase::Float8E5M2(), - APInt(8, *elementData))); - } - if (elementType.isFloat8E5M2FNUZ()) { - auto elementData = reinterpret_cast(elementPtr); - return Element(elementType, APFloat(llvm::APFloatBase::Float8E5M2FNUZ(), - APInt(8, *elementData))); + auto floatTy = cast(elementType); + return Element(elementType, + APFloat(floatTy.getFloatSemantics(), + APInt(floatTy.getWidth(), *elementData))); } if (elementType.isF16()) { auto elementData = reinterpret_cast(elementPtr); @@ -262,10 +235,8 @@ void Tensor::set(const Index &index, const Element &element) { getSizeInBytes(elementType) * flattenIndex(getShape(), index); // Handle floating-point types. - if (elementType.isFloat8E3M4() || elementType.isFloat8E4M3B11FNUZ() || - elementType.isFloat8E4M3() || elementType.isFloat8E4M3FN() || - elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2() || - elementType.isFloat8E5M2FNUZ()) { + if (isSupportedFloatType(elementType) && + cast(elementType).getWidth() <= 8) { auto elementData = reinterpret_cast(elementPtr); auto value = element.getFloatValue(); *elementData = (uint8_t)value.bitcastToAPInt().getZExtValue(); @@ -457,18 +428,19 @@ Tensor makeTensor(DenseElementsAttr attr) { auto elementType = type.getElementType(); // Handle floating-point types. - if (elementType.isFloat8E3M4() || elementType.isFloat8E4M3B11FNUZ() || - elementType.isFloat8E4M3() || elementType.isFloat8E4M3FN() || - elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2() || - elementType.isFloat8E5M2FNUZ()) { + if (isSupportedFloatType(elementType) && + cast(elementType).getWidth() <= 8) { auto floatValues = llvm::map_to_vector( attr.getValues(), [&](APFloat value) -> uint8_t { return value.bitcastToAPInt().getZExtValue(); }); // For f8E3M4, f8E4M3, f8E4M3FN, f8E4M3FNUZ, f8E4M3B11FNUZ, f8E5M2, and - // f8E5M2FNUZ floating-point types, we use uint8_t as their storage type - // because there are no builtin types for those. + // f8E5M2FNUZ, f8E8M0FNU floating-point types, we use uint8_t as their + // storage type because there are no builtin types for those. + // For f4E2M1FN, f6E2M3FN, and f6E3M2FN floating-point types, we still use + // uint8_t, even though the underlying types require less bits (similar + // to how ui2/ui4 types are handled). return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign( floatValues)); } diff --git a/stablehlo/reference/Types.cpp b/stablehlo/reference/Types.cpp index 9944ca07c1a..6f06a12de42 100644 --- a/stablehlo/reference/Types.cpp +++ b/stablehlo/reference/Types.cpp @@ -48,10 +48,12 @@ bool isSupportedIntegerType(Type type) { } bool isSupportedFloatType(Type type) { - return type.isFloat8E3M4() || type.isFloat8E4M3B11FNUZ() || - type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() || - type.isFloat8E5M2FNUZ() || type.isF16() || type.isBF16() || + return type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || + type.isFloat6E3M2FN() || type.isFloat8E3M4() || + type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3() || + type.isFloat8E4M3FN() || type.isFloat8E4M3FNUZ() || + type.isFloat8E5M2() || type.isFloat8E5M2FNUZ() || + type.isFloat8E8M0FNU() || type.isF16() || type.isBF16() || type.isF32() || type.isF64(); } diff --git a/stablehlo/tests/interpret/constant.mlir b/stablehlo/tests/interpret/constant.mlir index 2e24ba02f81..82f042dace3 100644 --- a/stablehlo/tests/interpret/constant.mlir +++ b/stablehlo/tests/interpret/constant.mlir @@ -96,6 +96,30 @@ func.func @constant_op_test_ui64() { // ----- +func.func @constant_op_test_f4_e2m1fn() { + %0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x07, 0x0F, 0x01, 0x09]> : tensor<10xf4E2M1FN> + check.expect_almost_eq_const %0, dense<[0.0, -0.0, 1.0, 0.0, 0.0, 3.0, 6.0, -6.0, 0.5, -0.5]> : tensor<10xf4E2M1FN> + func.return +} + +// ----- + +func.func @constant_op_test_f6_e2m3fn() { + %0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x1F, 0x3F, 0x01, 0x21]> : tensor<10xf6E2M3FN> + check.expect_almost_eq_const %0, dense<[0.0, -0.0, 1.0, 0.125, 0.125, 3.25, 7.5, -7.5, 0.125, -0.125]> : tensor<10xf6E2M3FN> + func.return +} + +// ----- + +func.func @constant_op_test_f6_e3m2fn() { + %0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x1F, 0x3F, 0x01, 0x21]> : tensor<10xf6E3M2FN> + check.expect_almost_eq_const %0, dense<[0.0, -0.0, 1.0, 0.125, 0.125, 3.0, 28.0, -28.0, 0.0625, -0.0625]> : tensor<10xf6E3M2FN> + func.return +} + +// ----- + func.func @constant_op_test_f8_e3m4() { %0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E3M4> check.expect_almost_eq_const %0, dense<[0.0, -0.0, 1.0, 0.125, 0.09375, 3.125, 0x7F, 0xFF, 0.015625, -0.015625]> : tensor<10xf8E3M4> @@ -159,6 +183,14 @@ func.func @constant_op_test_f8_e5m2_fnuz() { // ----- +func.func @constant_op_test_f8_e8m0fnu() { + %0 = stablehlo.constant dense<[0.0, 1.0, 0.125, 0.1, 3.1415, 0x00, 0x80, 0xFF]> : tensor<8xf8E8M0FNU> + check.expect_almost_eq_const %0, dense<[0.0, 1.0, 0.125, 0.125, 4.0, 1.175490e-38, 2.0, 0xFF]> : tensor<8xf8E8M0FNU> + func.return +} + +// ----- + func.func @constant_op_test_bf16() { %0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.140630, 0x7F80, 0xFF80, 0x7FFF, 0x0001, 0x8001]> : tensor<11xbf16> check.expect_almost_eq_const %0, dense<[0.000000e+00, -0.000000e+00, 1.000000e+00, 1.250000e-01, 1.000980e-01, 3.140630e+00, 0x7F80, 0xFF80, 0x7FFF, 9.183550e-41, -9.183550e-41]> : tensor<11xbf16> diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index 736ff25da15..a5aa2f359dc 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -2190,7 +2190,7 @@ func.func @rng_normal_invalid_shape(%arg0: tensor, %arg1: tensor) { func.func @rng_normal_invalid_mu_rank(%mu: tensor<1xf32>, %sigma: tensor) -> tensor<2x3x5xf32> { %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} + // expected-error@+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} %0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo}: (tensor<1xf32>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2199,7 +2199,7 @@ func.func @rng_normal_invalid_mu_rank(%mu: tensor<1xf32>, %sigma: tensor) - func.func @rng_normal_invalid_sigma_rank(%mu: tensor, %sigma: tensor<1xf32>) -> tensor<2x3x5xf32> { %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{#1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} + // expected-error@+1 {{#1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} %0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo}: (tensor, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2217,7 +2217,7 @@ func.func @rng_normal_invalid_shape_rank(%mu: tensor, %sigma: tensor) func.func @rng_normal_invalid_type(%arg0: tensor>, %arg1: tensor) { %cst = stablehlo.constant dense<7> : tensor<1xi64> - // expected-error @+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} + // expected-error @+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} %0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo}: (tensor>, tensor, tensor<1xi64>) -> tensor<7xf32> func.return } @@ -2252,7 +2252,7 @@ func.func @rng_uniform_invalid_shape(%arg0: tensor, %arg1: tensor, %ar func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor) -> tensor<2x3x5xf32> { %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} + // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo}: (tensor<1xf32>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2262,7 +2262,7 @@ func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor) -> ten func.func @rng_uniform_invalid_b_rank(%a: tensor, %b: tensor<1xf32>) -> tensor<2x3x5xf32> { %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{operand #1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} + // expected-error@+1 {{operand #1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo}: (tensor, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2280,7 +2280,7 @@ func.func @rng_uniform_invalid_shape_rank(%a: tensor, %b: tensor) -> t func.func @rng_uniform_invalid_type(%a: tensor>, %b: tensor) -> tensor<2x3x5xf32> { %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} + // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo}: (tensor>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2828,7 +2828,7 @@ func.func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te // ----- func.func @floor_invalid_i32_type(%arg0: tensor<4xi32>) -> tensor<4xi32> { - // expected-error@+1 {{must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4xi32>'}} + // expected-error@+1 {{must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4xi32>'}} %0 = "stablehlo.floor"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> func.return %0 : tensor<4xi32> } @@ -6137,7 +6137,7 @@ func.func @is_finite(%arg0: tensor<3xf32>) -> tensor<3xi1> { // ----- func.func @is_finite_int_input(%arg0: tensor<3xi32>) -> tensor<3xi1> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3xi32>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3xi32>'}} %0 = "stablehlo.is_finite"(%arg0) {} : (tensor<3xi32>) -> tensor<3xi1> func.return %0 : tensor<3xi1> } @@ -6185,6 +6185,30 @@ func.func @convert(%arg0: tensor) -> tensor { // ----- +// CHECK-LABEL: func @convert_f4e2m1fn +func.func @convert_f4e2m1fn(%arg0: tensor) -> tensor { + %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @convert_f6e2m3fn +func.func @convert_f6e2m3fn(%arg0: tensor) -> tensor { + %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @convert_f6e3m2fn +func.func @convert_f6e3m2fn(%arg0: tensor) -> tensor { + %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + // CHECK-LABEL: func @convert_f8e3m4 func.func @convert_f8e3m4(%arg0: tensor) -> tensor { %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor @@ -6233,6 +6257,14 @@ func.func @f8e5m2fnuz(%arg0: tensor) -> tensor { // ----- +// CHECK-LABEL: func @f8e8m0fnu +func.func @f8e8m0fnu(%arg0: tensor) -> tensor { + %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + func.func @dynamic_iota_static() -> tensor<4xf32> { %0 = stablehlo.constant dense<[4]> : tensor<1xi64> %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32> diff --git a/stablehlo/tests/ops_stablehlo_quantized.mlir b/stablehlo/tests/ops_stablehlo_quantized.mlir index 482eec58756..22366e45e0d 100644 --- a/stablehlo/tests/ops_stablehlo_quantized.mlir +++ b/stablehlo/tests/ops_stablehlo_quantized.mlir @@ -380,7 +380,7 @@ func.func @while_per_tensor_quantization(%arg0: tensor<4x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %abs_neg = "stablehlo.abs"(%arg0) : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>) -> tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>> func.return } @@ -388,7 +388,7 @@ func.func @negative_abs_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x4x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x4x!quant.uniform>'}} %all_gather = "stablehlo.all_gather"(%arg0) { all_gather_dim = 1 : i64, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> } : (tensor<2x4x!quant.uniform>) -> tensor<2x4x!quant.uniform> func.return } @@ -396,7 +396,7 @@ func.func @negative_all_gather_quantization(%arg0: tensor<2x4x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x4x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x4x!quant.uniform>'}} %all_to_all = "stablehlo.all_to_all"(%arg0) { split_dimension = 1 : i64, concat_dimension = 1 : i64, split_count = 2 : i64, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, channel_handle = #stablehlo.channel_handle} : (tensor<2x4x!quant.uniform>) -> tensor<2x4x!quant.uniform> func.return } @@ -404,7 +404,7 @@ func.func @negative_all_to_all_quantization(%arg0: tensor<2x4x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %atan2 = "stablehlo.atan2"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -412,7 +412,7 @@ func.func @negative_atan_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %cbrt = "stablehlo.cbrt"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -420,7 +420,7 @@ func.func @negative_bitcast_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %ceil = "stablehlo.ceil"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -428,7 +428,7 @@ func.func @negative_ceil_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %cholesky = "stablehlo.cholesky"(%arg0) { lower = true } : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -437,7 +437,7 @@ func.func @negative_cholesky_quantization(%arg0: tensor<1x2x2x!quant.uniform>) -> tensor<1x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x!quant.uniform>'}} %0 = "stablehlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<1x!quant.uniform> func.return %0: tensor<1x!quant.uniform> } @@ -445,7 +445,7 @@ func.func @negative_clamp_quantization(%arg0: tensor<1x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %collective_permute = "stablehlo.collective_permute"(%arg0) { source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, channel_handle = #stablehlo.channel_handle} : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -453,7 +453,7 @@ func.func @negative_collective_permute_quantization(%arg0: tensor<1x2x2x!quant.u // ----- func.func @negative_compare_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %compare = "stablehlo.compare"(%arg0, %arg1) { comparison_direction = #stablehlo, compare_type = #stablehlo } : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2xi1> func.return } @@ -461,7 +461,7 @@ func.func @negative_compare_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %concatenate = "stablehlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<2x2x2x!quant.uniform> func.return } @@ -469,7 +469,7 @@ func.func @negative_concatenate_quantization(%arg0: tensor<1x2x2x!quant.uniform< // ----- func.func @negative_cosine_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %cosine = "stablehlo.cosine"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -477,7 +477,7 @@ func.func @negative_cosine_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %divide = "stablehlo.divide"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -485,7 +485,7 @@ func.func @negative_divide_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor, %arg2: tensor) -> tensor<1x4x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x4x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x4x!quant.uniform>'}} %0 = "stablehlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = array} : (tensor<3x4x!quant.uniform>, tensor, tensor) -> tensor<1x4x!quant.uniform> func.return %0 : tensor<1x4x!quant.uniform> } @@ -493,7 +493,7 @@ func.func @negative_dynamic_slice_quantization(%arg0: tensor<3x4x!quant.uniform< // ----- func.func @negative_exponential_minus_one_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %exponential_minus_one = "stablehlo.exponential_minus_one"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -501,7 +501,7 @@ func.func @negative_exponential_minus_one_quantization(%arg0: tensor<1x2x2x!quan // ----- func.func @negative_exponential_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %exponential_minus_one = "stablehlo.exponential"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -509,7 +509,7 @@ func.func @negative_exponential_quantization(%arg0: tensor<1x2x2x!quant.uniform< // ----- func.func @negative_floor_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %floor = "stablehlo.floor"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -517,7 +517,7 @@ func.func @negative_floor_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %is_finite = "stablehlo.is_finite"(%arg0) {} : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2xi1> func.return } @@ -525,7 +525,7 @@ func.func @negative_floor_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %log_plus_one = "stablehlo.log_plus_one"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -533,7 +533,7 @@ func.func @negative_log_plus_one_quantization(%arg0: tensor<1x2x2x!quant.uniform // ----- func.func @negative_logistic_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %logistic = "stablehlo.logistic"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -541,7 +541,7 @@ func.func @negative_logistic_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %log = "stablehlo.log"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -549,7 +549,7 @@ func.func @negative_log_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> { - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4x!quant.uniform>'}} %map = "stablehlo.map"(%arg0, %arg1) ({ ^bb0(%arg2: tensor>, %arg3: tensor>): "stablehlo.return"(%arg2) : (tensor>) -> () @@ -560,7 +560,7 @@ func.func @negative_map_quantization(%arg0: tensor<4x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %maximum = "stablehlo.maximum"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -568,7 +568,7 @@ func.func @negative_maximum_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %minimum = "stablehlo.minimum"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -576,7 +576,7 @@ func.func @negative_minimum_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %multiply = "stablehlo.multiply"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -584,7 +584,7 @@ func.func @negative_multiply_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %negate = "stablehlo.negate"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -592,14 +592,14 @@ func.func @negative_negate_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values or token, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values or token, but got 'tensor<1x2x2x!quant.uniform>'}} %optimization_barrier = "stablehlo.optimization_barrier"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } // ----- func.func @negative_pad_quantization(%arg0: tensor<1x2x3x!quant.uniform>, %arg1: tensor>) -> tensor<2x4x7x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x3x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x3x!quant.uniform>'}} %pad = "stablehlo.pad"(%arg0, %arg1) { edge_padding_low = array, edge_padding_high = array, @@ -611,7 +611,7 @@ func.func @negative_pad_quantization(%arg0: tensor<1x2x3x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %power = "stablehlo.power"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -619,7 +619,7 @@ func.func @negative_power_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor>) -> tensor> { - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<16x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<16x!quant.uniform>'}} %reduce = "stablehlo.reduce"(%arg0, %arg1) ({ ^bb0(%arg2: tensor>, %arg3: tensor>): %1 = "stablehlo.add"(%arg2, %arg3) : (tensor>, tensor>) -> tensor> @@ -633,7 +633,7 @@ func.func @reduce_quantization(%arg0: tensor<16x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %remainder = "stablehlo.remainder"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -641,7 +641,7 @@ func.func @negative_remainder_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %rsqrt = "stablehlo.rsqrt"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -649,7 +649,7 @@ func.func @negative_rsqrt_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %sine = "stablehlo.sine"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -657,7 +657,7 @@ func.func @negative_sine_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %sqrt = "stablehlo.sqrt"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -665,7 +665,7 @@ func.func @negative_sqrt_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %subtract = "stablehlo.subtract"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -673,7 +673,7 @@ func.func @negative_subtract_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %tanh = "stablehlo.tanh"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -681,7 +681,7 @@ func.func @negative_tanh_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %scale: tensor<2x!quant.uniform>, %mean: tensor<2x!quant.uniform>, %variance: tensor<2x!quant.uniform>, %grad_output: tensor<2x2x2x2x!quant.uniform>) -> tensor<2x2x2x2x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x2x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x2x2x2x!quant.uniform>'}} %0:3 = "stablehlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x2x2x2x!quant.uniform>) -> (tensor<2x2x2x2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>) @@ -691,7 +691,7 @@ func.func @negative_batch_norm_grad_quantization(%input: tensor<2x2x2x2x!quant.u // ----- func.func @negative_batch_norm_inference_quantization(%input: tensor<4x256x!quant.uniform>, %scale: tensor<256x!quant.uniform>, %offset: tensor<256x!quant.uniform>, %mean: tensor<256x!quant.uniform>, %variance: tensor<256x!quant.uniform>) -> (tensor<4x256x!quant.uniform>) { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4x256x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4x256x!quant.uniform>'}} %0 = "stablehlo.batch_norm_inference" (%input, %scale, %offset, %mean, %variance) { epsilon = 1.001000e-05 : f32, feature_index = 1 : i64 @@ -701,7 +701,7 @@ func.func @negative_batch_norm_inference_quantization(%input: tensor<4x256x!quan // ----- func.func @negative_batch_norm_training_quantization(%input: tensor<2x2x2x2x!quant.uniform>, %scale: tensor<2x!quant.uniform>, %offset: tensor<2x!quant.uniform>) -> tensor<2x2x2x2x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x2x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x2x2x2x!quant.uniform>'}} %0:3 = "stablehlo.batch_norm_training" (%input, %scale, %offset) { epsilon = 0.001 : f32, feature_index = 1 : i64 @@ -712,7 +712,7 @@ func.func @negative_batch_norm_training_quantization(%input: tensor<2x2x2x2x!qua // ----- func.func @negative_dot_general_quantization(%arg0: tensor<2x3x4x!quant.uniform>, %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x3x4x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x3x4x!quant.uniform>'}} %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0], @@ -727,7 +727,7 @@ func.func @negative_dot_general_quantization(%arg0: tensor<2x3x4x!quant.uniform< // ----- func.func @negative_dynamic_update_slice_pertensor_quantization(%operand: tensor<3x4x!quant.uniform>, %update: tensor<1x4x!quant.uniform>, %start_indices0: tensor, %start_indices1: tensor) -> tensor<3x4x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x4x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x4x!quant.uniform>'}} %0 = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1) : (tensor<3x4x!quant.uniform>, tensor<1x4x!quant.uniform>, tensor, tensor) -> tensor<3x4x!quant.uniform> func.return %0 : tensor<3x4x!quant.uniform> } @@ -735,7 +735,7 @@ func.func @negative_dynamic_update_slice_pertensor_quantization(%operand: tensor // ----- func.func @negative_gather_quantization(%operand : tensor<*x!quant.uniform>, %start_indices : tensor<1x5x2xi32>) -> tensor<8x?x7x1x6x1x?x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<*x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<*x!quant.uniform>'}} %res = "stablehlo.gather"(%operand, %start_indices) { dimension_numbers = #stablehlo.gather< offset_dims = [0, 2, 3, 4, 5], @@ -752,7 +752,7 @@ func.func @negative_gather_quantization(%operand : tensor<*x!quant.uniform>) -> tensor<6x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<6x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<6x!quant.uniform>'}} %output = "stablehlo.reduce_precision"(%arg0) { exponent_bits = 5 : i32, mantissa_bits = 10 : i32 @@ -763,7 +763,7 @@ func.func @negative_reduce_precision_quantization(%arg0: tensor<6x!quant.uniform // ----- func.func @negative_reduce_scatter_quantization(%data: tensor<4x16x!quant.uniform>) -> tensor<4x4x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4x16x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4x16x!quant.uniform>'}} %0 = "stablehlo.reduce_scatter"(%data) ({ ^bb0(%arg2: tensor>, %arg3: tensor>): %1 = stablehlo.add %arg2, %arg3 : tensor> @@ -778,7 +778,7 @@ func.func @negative_reduce_scatter_quantization(%data: tensor<4x16x!quant.unifor // ----- func.func @negative_reduce_window_quantization(%arg0: tensor<2x17x31x7x!quant.uniform>, %arg1: tensor>) -> tensor<2x9x16x7x!quant.uniform> { - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x17x31x7x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x17x31x7x!quant.uniform>'}} %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ ^bb0(%arg2: tensor>, %arg3: tensor>): %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor>, tensor>) -> tensor> @@ -796,7 +796,7 @@ func.func @negative_reduce_window_quantization(%arg0: tensor<2x17x31x7x!quant.un // ----- func.func @negative_reverse_quantization(%operand: tensor<3x2x!quant.uniform>) -> tensor<3x2x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x2x!quant.uniform>'}} %result = "stablehlo.reverse"(%operand) { dimensions = array } : (tensor<3x2x!quant.uniform>) -> tensor<3x2x!quant.uniform> @@ -806,7 +806,7 @@ func.func @negative_reverse_quantization(%operand: tensor<3x2x!quant.uniform>) -> tensor<2x!quant.uniform> { - // expected-error@+1 {{ operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x!quant.uniform>'}} + // expected-error@+1 {{ operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x!quant.uniform>'}} %0 = "stablehlo.round_nearest_afz"(%arg0) {} : (tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> func.return %0 : tensor<2x!quant.uniform> } @@ -814,7 +814,7 @@ func.func @negative_round_afz(%arg0: tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> { - // expected-error@+1 {{ operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x!quant.uniform>'}} + // expected-error@+1 {{ operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x!quant.uniform>'}} %0 = "stablehlo.round_nearest_even"(%arg0) {} : (tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> func.return %0 : tensor<2x!quant.uniform> } @@ -822,7 +822,7 @@ func.func @negative_round_even(%arg0: tensor<2x!quant.uniform>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300x!quant.uniform>) -> tensor<200x100x300x!quant.uniform> { - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<200x100x300x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<200x100x300x!quant.uniform>'}} %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor>, %arg4: tensor>): %1 = "stablehlo.add"(%arg3, %arg4) : (tensor>, tensor>) -> tensor> @@ -841,7 +841,7 @@ func.func @negative_scatter_quantization(%arg0: tensor<200x100x300x!quant.unifor // ----- func.func @negative_select_quantization(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3x!quant.uniform>, %arg2: tensor<2x3x!quant.uniform>) -> tensor<2x3x!quant.uniform> { - // expected-error@+1 {{operand #1 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x3x!quant.uniform>'}} + // expected-error@+1 {{operand #1 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x3x!quant.uniform>'}} %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3x!quant.uniform>, tensor<2x3x!quant.uniform>) -> tensor<2x3x!quant.uniform> func.return %0 : tensor<2x3x!quant.uniform> } @@ -849,7 +849,7 @@ func.func @negative_select_quantization(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3 // ----- func.func @negative_slice_quantization(%arg0: tensor<3x4x!quant.uniform>) -> tensor<1x2x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x4x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x4x!quant.uniform>'}} %0 = "stablehlo.slice"(%arg0) {start_indices = array, limit_indices = array, strides = array} : (tensor<3x4x!quant.uniform>) -> tensor<1x2x!quant.uniform> func.return %0 : tensor<1x2x!quant.uniform> } @@ -858,7 +858,7 @@ func.func @negative_slice_quantization(%arg0: tensor<3x4x!quant.uniform>, %input1: tensor<16x16x!quant.uniform>) { - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<16x16x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<16x16x!quant.uniform>'}} %0:2 = "stablehlo.sort"(%input0, %input1) ({ ^bb0(%arg0: tensor>, %arg1: tensor>, %arg2: tensor>, %arg3: tensor>): %7 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo} : (tensor>, tensor>) -> tensor @@ -870,7 +870,7 @@ func.func @negative_sort_quantization(%input0: tensor<16x16x!quant.uniform>, %arg1: tensor<10x23x23x64x!quant.uniform>, %arg2: tensor>) -> tensor<10x24x24x64x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<10x24x24x64x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<10x24x24x64x!quant.uniform>'}} %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor>, %arg4: tensor>): %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor>, tensor>) -> tensor diff --git a/stablehlo/tests/ops_stablehlo_roundtrip.mlir b/stablehlo/tests/ops_stablehlo_roundtrip.mlir index ab11c086f1d..7be3c843938 100644 --- a/stablehlo/tests/ops_stablehlo_roundtrip.mlir +++ b/stablehlo/tests/ops_stablehlo_roundtrip.mlir @@ -183,6 +183,9 @@ func.func @test_constants() { %cst_4 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> %cst_5 = arith.constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32> %cst_6 = arith.constant dense<[[1, 2], [4, 8]]> : tensor<2x2xui32> + %cst_18 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf4E2M1FN> + %cst_19 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf6E2M3FN> + %cst_20 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf6E3M2FN> %cst_17 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf8E3M4> %cst_7 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf8E4M3B11FNUZ> %cst_16 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf8E4M3> @@ -190,6 +193,7 @@ func.func @test_constants() { %cst_9 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf8E4M3FNUZ> %cst_10 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf8E5M2> %cst_11 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf8E5M2FNUZ> + %cst_21 = arith.constant dense<[1.0, 2.0, 4.0, 8.0]> : tensor<4xf8E8M0FNU> %cst_12 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16> %cst_13 = arith.constant dense<[1.0e+00, -4.0e+00, -65504.0e+00, 1.5625e-02]> : tensor<4xf16> %cst_14 = arith.constant dense<(1.000000e+00,0.000000e+00)> : tensor> diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_8_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_8_0.mlir new file mode 100644 index 00000000000..72df0a64130 --- /dev/null +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_8_0.mlir @@ -0,0 +1,2936 @@ +// RUN: stablehlo-opt --mlir-print-op-generic %s.bc | FileCheck %s +// RUN: stablehlo-translate --deserialize %s.bc | stablehlo-translate --serialize --target=1.8.0 | stablehlo-opt --mlir-print-op-generic | FileCheck %s +// RUN: stablehlo-translate --deserialize %s.bc | stablehlo-opt > %t.0 +// RUN: stablehlo-opt --strip-debuginfo %s > %t.1 +// RUN: diff %t.0 %t.1 +// RUN: stablehlo-translate --serialize --target=1.8.0 --strip-debuginfo %s > %t.2 +// RUN: diff %s.bc %t.2 +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo -emit-bytecode -debug-only=vhlo-bytecode %s 2>&1 | FileCheck --check-prefix=CHECK-WARN %s +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo -emit-bytecode %s | stablehlo-opt -debug-only=vhlo-bytecode 2>&1 | FileCheck --check-prefix=CHECK-WARN %s + +// CHECK-WARN-NOT: Not Implemented + +// ============ ATTRIBUTES ============ + +// CHECK-LABEL: "attr_comparison_direction_eq" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_eq(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_ne" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_ne(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_ge" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_ge(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_gt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_gt(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_le" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_le(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_lt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_lt(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_notype" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_notype(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + // CHECK: compare_type = #vhlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_float" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_float(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_totalorder" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_totalorder(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_signed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_signed(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_unsigned" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_unsigned(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ConvDimensionNumbers aka #stablehlo.conv is covered below. + +// CHECK-LABEL: "attr_custom_call_api_version_unspecified" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_custom_call_api_version_unspecified(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 0 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_original" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_custom_call_api_version_original(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 1 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_status_returning" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_custom_call_api_version_status_returning(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 2 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_status_returning_unified" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_custom_call_api_version_status_returning_unified(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 3 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_dict" +// CHECK: #vhlo.dict_v1<{#vhlo.string_v1<"attr1"> = #vhlo.integer_v1<1 : i32>, #vhlo.string_v1<"attr2"> = #vhlo.integer_v1<2 : i32>} +func.func @attr_dict() attributes {stablehlo.attr = {attr1 = 1 : i32, attr2 = 2 : i32}} { + return +} + +// CHECK-LABEL: "attr_custom_call_api_version_typed_ffi" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +// CHECK: api_version = #vhlo +// CHECK-SAME: backend_config = #vhlo.dict_v1<{#vhlo.string_v1<"bar"> = #vhlo.integer_v1<42 : i32>}> +func.func @attr_custom_call_api_version_typed_ffi(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + backend_config= {bar = 42 : i32}, + api_version = 4 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + + +// CHECK-LABEL: "attr_custom_call_api_version_typed_ffi_no_backend_config" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +// CHECK: api_version = #vhlo +// CHECK-SAME: backend_config = #vhlo.dict_v1<{}> +func.func @attr_custom_call_api_version_typed_ffi_no_backend_config(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + api_version = 4 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// DotDimensionNumbers aka #stablehlo.dot is covered below. + +// CHECK-LABEL: "attr_fft_type_fft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_fft_type_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_ifft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_fft_type_ifft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_rfft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_fft_type_rfft(%arg0: tensor<16xf32>) -> tensor<9xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xf32>) -> tensor<9xcomplex> + func.return %0 : tensor<9xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_irfft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_fft_type_irfft(%arg0: tensor<9xcomplex>) -> tensor<16xf32> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<9xcomplex>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// GatherDimensionNumbers aka #stablehlo.gather is covered below. + +// CHECK-LABEL: "attr_precision_config_default" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_precision_config_default(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_precision_config_high" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_precision_config_high(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_precision_config_highest" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_precision_config_highest(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_rng_algorithm_default" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_rng_algorithm_default(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "attr_rng_algorithm_three_fry" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_rng_algorithm_three_fry(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "attr_rng_algorithm_philox" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "attr_rng_distribution_uniform" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + // CHECK: rng_distribution = #vhlo + rng_distribution = #stablehlo + } : (tensor, tensor, tensor<0xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_rng_distribution_normal" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + // CHECK: rng_distribution = #vhlo + rng_distribution = #stablehlo + } : (tensor, tensor, tensor<0xindex>) -> tensor + func.return %0 : tensor +} + +// ScatterDimensionNumbers aka #stablehlo.scatter is covered below. + +// CHECK-LABEL: "attr_transpose_no_transpose" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_transpose_no_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "attr_transpose_transpose" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_transpose_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "attr_transpose_adjoint" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_transpose_adjoint(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// TypeExtensionsAttr aka #stablehlo.type_extensions is covered below. + +// CHECK-LABEL: "attr_type_extensions_bounds" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_type_extensions_bounds(%arg0: tensor>) -> tensor> { + // CHECK: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> () + func.return %arg0 : tensor> +} + + +// ============ DEFAULTS ============ + +// CHECK-LABEL: "default_all_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.all_gather_v2"(%[[ARG0]]) <{ + // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16x8xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "default_all_gather_variadic" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_all_gather_variadic(%arg0: tensor<16x8xf32>, %arg1: tensor<16x8xf32>) -> (tensor<16x16xf32>, tensor<16x16xf32>) { + %0:2 = "stablehlo.all_gather"(%arg0, %arg1) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16x8xf32>, tensor<16x8xf32>) -> (tensor<16x16xf32>, tensor<16x16xf32>) + func.return %0#0, %0#1 : tensor<16x16xf32>, tensor<16x16xf32> +} + +// CHECK-LABEL: "default_all_reduce" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_all_reduce(%arg0: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v2"(%[[ARG0]]) + // CHECK-SAME: <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + + %0 = "stablehlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_all_to_all" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { + // CHECK: "vhlo.all_to_all_v2"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, + // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> + // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> + %0 = "stablehlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// CHECK-LABEL: "default_all_to_all_variadic" +func.func @default_all_to_all_variadic(%arg0: tensor<4x16xf32>, %arg1: tensor<5x16xf32>) -> (tensor<16x4xf32>, tensor<20x4xf32>) { + %0:2 = "stablehlo.all_to_all"(%arg0, %arg1) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<4x16xf32>, tensor<5x16xf32>) -> (tensor<16x4xf32>, tensor<20x4xf32>) + func.return %0#0, %0#1 : tensor<16x4xf32>, tensor<20x4xf32> +} + +// CHECK-LABEL: "default_cholesky" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { + // CHECK: "vhlo.cholesky_v1"(%[[ARG0]]) <{ + // CHECK-SAME: lower = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> + %0 = "stablehlo.cholesky"(%arg0) : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> + func.return %0 : tensor<1x16x16xf32> +} + +// CHECK-LABEL: "default_collective_permute" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_permute_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "default_collective_broadcast" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_collective_broadcast(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_broadcast_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_broadcast"(%arg0) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "default_compare" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_compare(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.compare_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: compare_type = #vhlo, + // CHECK-SAME: comparison_direction = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_composite" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_composite(%arg0: tensor) -> tensor { + // CHECK: "vhlo.composite_v1"(%[[ARG0]]) <{ + // CHECK-SAME: composite_attributes = #vhlo.dict_v1<{}> + // CHECK-SAME: decomposition = #vhlo.string_v1<"composite_target"> + // CHECK-SAME: name = #vhlo.string_v1<"stablehlo.composite_target"> + // CHECK-SAME: version = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.composite"(%arg0) { + name = "stablehlo.composite_target", + decomposition = @composite_target + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_convolution" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> { + // CHECK: "vhlo.convolution_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x6x6x16x!vhlo.f32_v1> + %0 = "stablehlo.convolution"(%arg0, %arg1) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> + func.return %0 : tensor<1x6x6x16xf32> +} + +// CHECK-LABEL: "default_custom_call" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_custom_call(%arg0: tensor) -> tensor { + // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ + // CHECK-SAME: api_version = #vhlo, + // CHECK-SAME: backend_config = #vhlo.string_v1<"">, + // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">, + // CHECK-SAME: called_computations = #vhlo.array_v1<[]>, + // CHECK-SAME: has_side_effect = #vhlo.bool_v1, + // CHECK-SAME: operand_layouts = #vhlo.array_v1<[]>, + // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[]> + // CHECK-SAME: result_layouts = #vhlo.array_v1<[]> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo" + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_dot_general" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "vhlo.dot_general_v2"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: accumulation_type = #vhlo.type_v1, + // CHECK-SAME: allow_imprecise_accumulation = #vhlo.type_v1, + // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_component_count = #vhlo.type_v1, + // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_precision_type = #vhlo.type_v1, + // CHECK-SAME: num_primitive_operations = #vhlo.type_v1, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_component_count = #vhlo.type_v1, + // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_precision_type = #vhlo.type_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + > + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + +// CHECK-LABEL: "dot_general_algorithm" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @dot_general_algorithm(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { +// CHECK: "vhlo.dot_general_v2"(%[[ARG0]], %[[ARG1]]) <{ +// CHECK-SAME: accumulation_type = #vhlo.type_v1, +// CHECK-SAME: allow_imprecise_accumulation = #vhlo.bool_v1, +// CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, +// CHECK-SAME: lhs_component_count = #vhlo.integer_v1<1 : i64>, +// CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, +// CHECK-SAME: lhs_precision_type = #vhlo.type_v1, +// CHECK-SAME: num_primitive_operations = #vhlo.integer_v1<1 : i64>, +// CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, +// CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, +// CHECK-SAME: rhs_component_count = #vhlo.integer_v1<1 : i64>, +// CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, +// CHECK-SAME: rhs_precision_type = #vhlo.type_v1 +// CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + >, + algorithm = #stablehlo.dot_algorithm< + lhs_precision_type = tf32, + rhs_precision_type = tf32, + accumulation_type = f32, + lhs_component_count = 1, + rhs_component_count = 1, + num_primitive_operations = 1, + allow_imprecise_accumulation = false + > + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + +// CHECK-LABEL: "default_dynamic_broadcast_in_dim" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = array + } : (tensor, tensor<2xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_dynamic_conv" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @default_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<2x2xi64>) -> tensor<1x?x?x16xf32> { + // CHECK: "vhlo.dynamic_conv_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x2x!vhlo.i64_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x?x?x16xf32> + func.return %0 : tensor<1x?x?x16xf32> +} + +// CHECK-LABEL: "default_dynamic_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @default_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + > + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} + +func.func @default_func(%arg0: tensor) -> tensor { + // CHECK: "vhlo.func_v1"() <{ + // CHECK-SAME: arg_attrs = #vhlo.array_v1<[]>, + // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, + // CHECK-SAME: res_attrs = #vhlo.array_v1<[]>, + // CHECK-SAME: sym_name = #vhlo.string_v1<"default_func">, + // CHECK-SAME: sym_visibility = #vhlo.string_v1<""> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG0:.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : () -> () + func.return %arg0 : tensor +} + +// CHECK-LABEL: "default_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + slice_sizes = array + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} + +// CHECK-LABEL: "default_infeed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.infeed_v1"(%[[ARG0]]) <{ + // CHECK-SAME: infeed_config = #vhlo.string_v1<"">, + // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[]> + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.infeed"(%arg0) : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "default_outfeed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.outfeed_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: outfeed_config = #vhlo.string_v1<""> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.outfeed"(%arg0, %arg1) : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "default_recv" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.recv_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.recv"(%arg0) { + channel_handle = #stablehlo.channel_handle + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "default_send" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.send_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.send"(%arg0, %arg1) { + channel_handle = #stablehlo.channel_handle + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "default_reduce_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reduce_scatter"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension = 0 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "default_reduce_window" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x16x30x7xf32> { + // CHECK: "vhlo.reduce_window_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x16x30x7x!vhlo.f32_v1> + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array + } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x16x30x7xf32> + func.return %0 : tensor<2x16x30x7xf32> +} + +// CHECK-LABEL: "default_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @default_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { + // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: unique_indices = #vhlo.bool_v1, + // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + > + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> + func.return %0 : tensor<200x100x300xf32> +} + +// CHECK-LABEL: "default_select_and_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @default_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x23x23x64xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<10x23x23x64x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array + } : (tensor<10x24x24x64xf32>, tensor<10x23x23x64xf32>, tensor) -> tensor<10x24x24x64xf32> + func.return %0 : tensor<10x24x24x64xf32> +} + +// CHECK-LABEL: "default_sort" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.sort_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<-1 : i64> + // CHECK-SAME: is_stable = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// ============ OPS ============ + +// CHECK-LABEL: "op_abs" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_abs(%arg0: tensor) -> tensor { + // CHECK: "vhlo.abs_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_add" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_add(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_after_all" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_after_all(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.after_all_v1"(%[[ARG0]]) : (!vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.after_all"(%arg0) : (!stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_all_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.all_gather_v2"(%[[ARG0]]) <{ + // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor<16x8xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_all_reduce" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_all_reduce(%arg0: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v2"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_all_reduce_with_promotable_types" +func.func @op_all_reduce_with_promotable_types(%operand: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v2"(%[[ARG0:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %result = "stablehlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor) -> tensor + + func.return %result : tensor +} + +// CHECK-LABEL: "default_all_reduce_variadic" +func.func @default_all_reduce_variadic(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.all_reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> (tensor) + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor, tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "op_all_to_all" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { + // CHECK: "vhlo.all_to_all_v2"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, + // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> + // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> + %0 = "stablehlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// CHECK-LABEL: "op_and" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_and(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.and_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_atan2" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_atan2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.atan2_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.atan2"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_batch_norm_grad" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) +func.func @op_batch_norm_grad(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { + // CHECK: "vhlo.batch_norm_grad_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) + %0:3 = "stablehlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> +} + +// CHECK-LABEL: "op_batch_norm_inference" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) +func.func @op_batch_norm_inference(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>) -> tensor<16x16x16x16xf32> { + // CHECK: "vhlo.batch_norm_inference_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1> + %0 = "stablehlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<16x16x16x16xf32> + func.return %0 : tensor<16x16x16x16xf32> +} + +// CHECK-LABEL: "op_batch_norm_training" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_batch_norm_training(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { + // CHECK: "vhlo.batch_norm_training_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) + %0:3 = "stablehlo.batch_norm_training"(%arg0, %arg1, %arg2) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> +} + +// CHECK-LABEL: "op_bitcast_convert" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_bitcast_convert(%arg0: tensor) -> tensor { + // CHECK: "vhlo.bitcast_convert_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.bitcast_convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_broadcast_in_dim" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_broadcast_in_dim(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.broadcast_in_dim_v1"(%[[ARG0]]) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = array + } : (tensor<16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_broadcast" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_broadcast(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.broadcast_v1"(%[[ARG0]]) <{ + // CHECK-SAME: broadcast_sizes = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.broadcast"(%arg0) { + broadcast_sizes = array + } : (tensor<16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_case" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_case(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.case_v1"(%[[ARG0]]) ({ + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.case"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_cbrt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_cbrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cbrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cbrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_ceil" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_ceil(%arg0: tensor) -> tensor { + // CHECK: "vhlo.ceil_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.ceil"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_cholesky" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { + // CHECK: "vhlo.cholesky_v1"(%[[ARG0]]) <{ + // CHECK-SAME: lower = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> + %0 = "stablehlo.cholesky"(%arg0) { + lower = true + } : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> + func.return %0 : tensor<1x16x16xf32> +} + +// CHECK-LABEL: "op_clamp" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_clamp(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.clamp_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_count_leading_zeros" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_count_leading_zeros(%arg0: tensor) -> tensor { + // CHECK: "vhlo.count_leading_zeros_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.count_leading_zeros"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_collective_permute" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_permute_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "op_compare" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_compare(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.compare_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: compare_type = #vhlo, + // CHECK-SAME: comparison_direction = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_complex" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> { + // CHECK: "vhlo.complex_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1> + %0 = "stablehlo.complex"(%arg0, %arg1) : (tensor, tensor) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "op_composite" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_composite(%arg0: tensor) -> tensor { + // CHECK: "vhlo.composite_v1"(%[[ARG0]]) <{ + // CHECK-SAME: composite_attributes = #vhlo.dict_v1<{#vhlo.string_v1<"my_int"> = #vhlo.integer_v1<1 : i64>, #vhlo.string_v1<"my_string"> = #vhlo.string_v1<"foo">}> + // CHECK-SAME: decomposition = #vhlo.string_v1<"composite_target"> + // CHECK-SAME: name = #vhlo.string_v1<"stablehlo.composite_target"> + // CHECK-SAME: version = #vhlo.integer_v1<1 : i32> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.composite"(%arg0) { + name = "stablehlo.composite_target", + decomposition = @composite_target, + version = 1 : i32, + composite_attributes = { + my_string = "foo", + my_int = 1 : i64 + } + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_concatenate" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.concatenate_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1<8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.concatenate"(%arg0, %arg1) { + dimension = 0 : i64 + } : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_constant" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_constant(%arg0: tensor) -> tensor { + // CHECK: "vhlo.constant_v1"() <{ + // CHECK-SAME: value = #vhlo.tensor_v1 : tensor> + // CHECK-SAME: }> : () -> !vhlo.tensor_v1 + %0 = "stablehlo.constant"() { + value = dense<0.0> : tensor + } : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_convert" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_convert(%arg0: tensor) -> tensor { + // CHECK: "vhlo.convert_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_convolution" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> { + // CHECK: "vhlo.convolution_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x7x7x16x!vhlo.f32_v1> + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_strides = array, + padding = dense<1> : tensor<2x2xi64>, + lhs_dilation = array, + rhs_dilation = array, + window_reversal = array, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> + func.return %0 : tensor<1x7x7x16xf32> +} + +// CHECK-LABEL: "op_cosine" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_cosine(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cosine_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cosine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_create_token" +func.func @op_create_token() -> !stablehlo.token { + // CHECK: "vhlo.create_token_v1"() : () -> !vhlo.token_v1 + %0 = "stablehlo.create_token"() : () -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_cross_replica_sum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cross-replica-sum_v1"(%[[ARG0]]) <{ + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cross-replica-sum"(%arg0) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_custom_call" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_custom_call(%arg0: tensor) -> tensor { + // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ + // CHECK-SAME: api_version = #vhlo, + // CHECK-SAME: backend_config = #vhlo.string_v1<"\08\03\1A\02">, + // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">, + // CHECK-SAME: called_computations = #vhlo.array_v1<[#vhlo.string_v1<"foo">]>, + // CHECK-SAME: has_side_effect = #vhlo.bool_v1, + // CHECK-SAME: operand_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]>, + // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[ + // CHECK-SAME: #vhlo.output_operand_alias_v1< + // CHECK-SAME: outputTupleIndices = [], + // CHECK-SAME: operandIndex = 0, + // CHECK-SAME: operandTupleIndices = []>]> + // CHECK-SAME: result_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + has_side_effect = true, + backend_config = "\08\03\1A\02", + api_version = 2 : i32, + called_computations = [@foo], + operand_layouts = [dense<> : tensor<0xindex>], + output_operand_aliases = [ + #stablehlo.output_operand_alias], + result_layouts = [dense<> : tensor<0xindex>] + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_custom_call_empty_result_layout" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func public @op_custom_call_empty_result_layout(%arg0: tensor) -> tensor { + // %0 = "vhlo.custom_call_v1"(%arg0) <{>}> : (!vhlo.tensor_v1) -> !vhlo.tuple_v1<> + // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ + // CHECK-SAME: api_version = #vhlo, + // CHECK-SAME: backend_config = #vhlo.string_v1<"">, + // CHECK-SAME: call_target_name = #vhlo.string_v1<"empty_output">, + // CHECK-SAME: called_computations = #vhlo.array_v1<[]>, + // CHECK-SAME: has_side_effect = #vhlo.bool_v1, + // CHECK-SAME: operand_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]>, + // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[]>, + // CHECK-SAME: result_layouts = #vhlo.array_v1<[]> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tuple_v1<> + %0 = "stablehlo.custom_call"(%arg0) <{ + api_version = 2 : i32, + call_target_name = "empty_output", + has_side_effect = true, + operand_layouts = [dense<> : tensor<0xindex>], + result_layouts = [] + }> : (tensor) -> tuple<> + return %arg0 : tensor +} + +// CHECK-LABEL: "op_divide" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_divide(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.divide_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dot_general" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "vhlo.dot_general_v2"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: accumulation_type = #vhlo.type_v1, + // CHECK-SAME: allow_imprecise_accumulation = #vhlo.type_v1, + // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_component_count = #vhlo.type_v1, + // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_precision_type = #vhlo.type_v1, + // CHECK-SAME: num_primitive_operations = #vhlo.type_v1, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_component_count = #vhlo.type_v1, + // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_precision_type = #vhlo.type_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + >, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + +// CHECK-LABEL: "op_dot" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.dot_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot"(%arg0, %arg1) { + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "op_dynamic_broadcast_in_dim" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = array, + known_expanding_dimensions = array, + known_nonexpanding_dimensions = array + } : (tensor, tensor<2xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_conv" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<2x2xi64>) -> tensor<1x?x?x16xf32> { + // CHECK: "vhlo.dynamic_conv_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x2x!vhlo.i64_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { + window_strides = array, + lhs_dilation = array, + rhs_dilation = array, + window_reversal = array, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x?x?x16xf32> + func.return %0 : tensor<1x?x?x16xf32> +} + +// CHECK-LABEL: "op_dynamic_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + indices_are_sorted = true + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} + +// CHECK-LABEL: "op_dynamic_gather_with_batching_dims" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_dynamic_gather_with_batching_dims(%arg0 : tensor<5x2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<4xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<5x2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<4x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [1, 2], + operand_batching_dims = [0], + start_indices_batching_dims = [1], + start_index_map = [1, 2], + index_vector_dim = 2 + >, + indices_are_sorted = true + } : (tensor<5x2x4x9xf32>, tensor<1x5x2xi32>, tensor<4xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} + +// CHECK-LABEL: "op_dynamic_iota" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_dynamic_iota(%arg0: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.dynamic_iota_v1"(%[[ARG0]]) <{ + // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_iota"(%arg0) { + iota_dimension = 0 : i64 + } : (tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_pad" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) +func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>, %arg4: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.dynamic_pad_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor, tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_reshape" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_slice" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dynamic_slice(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor<4xf32> { + // CHECK: "vhlo.dynamic_slice_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_slice"(%arg0, %arg1) { + slice_sizes = array + } : (tensor<16xf32>, tensor) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK-LABEL: "op_dynamic_update_slice" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_dynamic_update_slice(%arg0: tensor<16xf32>, %arg1: tensor<4xf32>, %arg2: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.dynamic_update_slice_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<4x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<16xf32>, tensor<4xf32>, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_einsum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_einsum(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.einsum_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab,bc->ac"> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> + %0 = "stablehlo.einsum"(%arg0, %arg1) { + einsum_config = "ab,bc->ac" + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "op_exponential_minus_one" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_exponential_minus_one(%arg0: tensor) -> tensor { + // CHECK: "vhlo.exponential_minus_one_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.exponential_minus_one"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_exponential" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_exponential(%arg0: tensor) -> tensor { + // CHECK: "vhlo.exponential_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.exponential"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_fft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + // CHECK: "vhlo.fft_v1"(%[[ARG0]]) <{ + // CHECK-SAME: fft_length = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: fft_type = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.complex_v1>) -> !vhlo.tensor_v1<16x!vhlo.complex_v1> + %0 = "stablehlo.fft"(%arg0) { + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "op_floor" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_floor(%arg0: tensor) -> tensor { + // CHECK: "vhlo.floor_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.floor"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +func.func private @op_func(%arg0: tensor {stablehlo.arg = "0"}) -> (tensor {stablehlo.result = "0"}) { + // CHECK: "vhlo.func_v1"() <{ + // CHECK-SAME: arg_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.arg"> = #vhlo.string_v1<"0">}>]>, + // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, + // CHECK-SAME: res_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.result"> = #vhlo.string_v1<"0">}>]>, + // CHECK-SAME: sym_name = #vhlo.string_v1<"op_func">, + // CHECK-SAME: sym_visibility = #vhlo.string_v1<"private"> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG0:.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : () -> () + + func.return %arg0 : tensor +} + +// CHECK-LABEL: "op_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + slice_sizes = array, + indices_are_sorted = true + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} + +// CHECK-LABEL: "op_gather_with_batching_dims" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_gather_with_batching_dims(%arg0 : tensor<5x2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<5x2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [1, 2], + operand_batching_dims = [0], + start_indices_batching_dims = [1], + start_index_map = [1, 2], + index_vector_dim = 2 + >, + slice_sizes = array, + indices_are_sorted = true + } : (tensor<5x2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} + +// CHECK-LABEL: "op_get_dimension_size" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_get_dimension_size(%arg0: tensor) -> tensor { + // CHECK: "vhlo.get_dimension_size_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.get_dimension_size"(%arg0) { + dimension = 0 : i64 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_get_tuple_element" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_get_tuple_element(%arg0: tuple, tensor>) -> tensor { + // CHECK: "vhlo.get_tuple_element_v1"(%[[ARG0]]) <{ + // CHECK-SAME: index = #vhlo.integer_v1<0 : i32> + // CHECK-SAME: }> : (!vhlo.tuple_v1, !vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.get_tuple_element"(%arg0) { + index = 0 : i32 + } : (tuple, tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_if" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.if_v1"(%[[ARG0]]) ({ + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG2]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.if"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }, { + "stablehlo.return"(%arg2) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_imag" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_imag(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.imag_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.imag"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_infeed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.infeed_v1"(%[[ARG0]]) <{ + // CHECK-SAME: infeed_config = #vhlo.string_v1<"foo">, + // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[#vhlo.array_v1<[]>]> + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.infeed"(%arg0) { + infeed_config = "foo", + layout = [[]] + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "op_iota" +func.func @op_iota() -> tensor<16xf32> { + // CHECK: "vhlo.iota_v1"() <{ + // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : () -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.iota"() { + iota_dimension = 0 : i64 + } : () -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_is_finite" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_is_finite(%arg0: tensor) -> tensor { + // CHECK: "vhlo.is_finite_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.is_finite"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_log" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_log(%arg0: tensor) -> tensor { + // CHECK: "vhlo.log_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.log"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_log_plus_one" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_log_plus_one(%arg0: tensor) -> tensor { + // CHECK: "vhlo.log_plus_one_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.log_plus_one"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_logistic" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_logistic(%arg0: tensor) -> tensor { + // CHECK: "vhlo.logistic_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.logistic"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_map" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_map(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.map_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.abs_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.map"(%arg0) ({ + ^bb0(%arg1: tensor): + %1 = "stablehlo.abs"(%arg1) : (tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimensions = array + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_maximum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_maximum(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.maximum_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_minimum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_minimum(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.minimum_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.minimum"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_multiply" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_multiply(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.multiply_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_negate" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_negate(%arg0: tensor) -> tensor { + // CHECK: "vhlo.negate_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.negate"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_not" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_not(%arg0: tensor) -> tensor { + // CHECK: "vhlo.not_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.not"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_optimization_barrier" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_optimization_barrier(%arg0: tensor) -> tensor { + // CHECK: "vhlo.optimization_barrier_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.optimization_barrier"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_or" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_or(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.or_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.or"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_outfeed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.outfeed_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: outfeed_config = #vhlo.string_v1<"foo"> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.outfeed"(%arg0, %arg1) { + outfeed_config = "foo" + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_pad" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_pad(%arg0: tensor<8xf32>, %arg1: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.pad_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: edge_padding_high = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: edge_padding_low = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: interior_padding = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.pad"(%arg0, %arg1) { + edge_padding_high = array, + edge_padding_low = array, + interior_padding = array + } : (tensor<8xf32>, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_popcnt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_popcnt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.popcnt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.popcnt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_power" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_power(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.power_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.power"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_real_dynamic_slice" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}) +func.func @op_real_dynamic_slice(%arg0: tensor, %arg1: tensor<1xindex>, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.real_dynamic_slice_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.real_dynamic_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_real" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_real(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.real_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.real"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_recv" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.recv_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.recv"(%arg0) { + channel_handle = #stablehlo.channel_handle, + is_host_transfer = true + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "op_reduce" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_reduce(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { + // CHECK: "vhlo.reduce_v1"(%[[ARG0]], %[[ARG1]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimensions = array + } : (tensor<16xf32>, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reduce_precision" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_reduce_precision(%arg0: tensor) -> tensor { + // CHECK: "vhlo.reduce_precision_v1"(%[[ARG0]]) <{ + // CHECK-SAME: exponent_bits = #vhlo.integer_v1<8 : i32> + // CHECK-SAME: mantissa_bits = #vhlo.integer_v1<10 : i32> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.reduce_precision"(%arg0) { + exponent_bits = 8 : i32, + mantissa_bits = 10 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK_lABEL: "op_reduce_with_promotable_types" +func.func @op_reduce_with_promotable_types(%arg0: tensor<4x4xf32>, %arg1 : tensor) + -> (tensor<4xf64>) { + // CHECK: "vhlo.reduce_v1"(%[[ARG0:.*]], %[[ARG1:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x4x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<4x!vhlo.f64_v1> + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor ): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + + }) {dimensions = array} : (tensor<4x4xf32>, tensor) -> tensor<4xf64> + + func.return %0: tensor<4xf64> +} + +// CHECK-LABEL: "op_reduce_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reduce_scatter"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension = 0 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK_lABEL: "op_reduce_scatter_with_promotable_types" +func.func @op_reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tensor<4x4xf64> { + // CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f64_v1> + %0 = "stablehlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = stablehlo.add %arg2, %arg3 : tensor + "stablehlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf64> + func.return %0 : tensor<4x4xf64> +} + + +// CHECK-LABEL: "op_reduce_window" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x9x16x7xf32> { + // CHECK: "vhlo.reduce_window_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x9x16x7x!vhlo.f32_v1> + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array, + window_strides = array, + base_dilations = array, + window_dilations = array, + padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64> + } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x9x16x7xf32> + func.return %0 : tensor<2x9x16x7xf32> +} + +// CHECK-LABEL: "op_reduce_window_with_promotable_types" +func.func @op_reduce_window_with_promotable_types(%arg0: tensor<4x2xf32>, + %arg1: tensor<4x2xf32>, %init0: tensor, %init1: tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) { + // CHECK: "vhlo.reduce_window_v1"(%[[ARG0:.*]], %[[ARG1:.*]], %[[ARG2:.*]], %[[ARG3:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]], %[[VAL2:.*]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x2x!vhlo.f32_v1>, !vhlo.tensor_v1<4x2x!vhlo.f32_v1>, !vhlo.tensor_v1, !vhlo.tensor_v1) -> (!vhlo.tensor_v1<2x2x!vhlo.f64_v1>, !vhlo.tensor_v1<2x2x!vhlo.f32_v1>) + %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ + ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, + %b1: tensor): + %2 = stablehlo.add %a0, %b0 : tensor + %3 = stablehlo.add %a1, %b1 : tensor + "stablehlo.return"(%2,%3) : (tensor, tensor) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = array, + window_strides = array } + : (tensor<4x2xf32>, tensor<4x2xf32>, tensor, tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) + func.return %0#0, %0#1 : tensor<2x2xf64>, tensor<2x2xf32> +} + +// CHECK-LABEL: "op_remainder" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_remainder(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.remainder_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.remainder"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_replica_id" +func.func @op_replica_id() -> tensor { + // CHECK: "vhlo.replica_id_v1"() : () -> !vhlo.tensor_v1 + %0 = "stablehlo.replica_id"() : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_partition_id" +func.func @op_partition_id() -> tensor { + // CHECK: "vhlo.partition_id_v1"() : () -> !vhlo.tensor_v1 + %0 = "stablehlo.partition_id"() : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reshape" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_reshape(%arg0: tensor<16xf32>) -> tensor<4x4xf32> { + // CHECK: "vhlo.reshape_v1"(%[[ARG0]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f32_v1> + %0 = "stablehlo.reshape"(%arg0) : (tensor<16xf32>) -> tensor<4x4xf32> + func.return %0 : tensor<4x4xf32> +} + +// CHECK-LABEL: "op_return" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_return(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.case_v1"(%[[ARG0]]) ({ + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.case"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reverse" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_reverse(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reverse_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reverse"(%arg0) { + dimensions = array + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_rng_bit_generator" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor) { + // CHECK: "vhlo.rng_bit_generator_v1"(%[[ARG0]]) <{ + // CHECK-SAME: rng_algorithm = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> (!vhlo.tensor_v1, !vhlo.tensor_v1) + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "op_rng" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { + // CHECK: "vhlo.rng_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: rng_distribution = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + rng_distribution = #stablehlo + } : (tensor, tensor, tensor<0xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_round_nearest_afz" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_round_nearest_afz(%arg0: tensor) -> tensor { + // CHECK: "vhlo.round_nearest_afz_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.round_nearest_afz"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_round_nearest_even" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_round_nearest_even(%arg0: tensor) -> tensor { + // CHECK: "vhlo.round_nearest_even_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.round_nearest_even"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_rsqrt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_rsqrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.rsqrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.rsqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { + // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: unique_indices = #vhlo.bool_v1, + // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> + func.return %0 : tensor<200x100x300xf32> +} + +// CHECK-LABEL: "op_scatter_with_batching_dims" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_scatter_with_batching_dims(%arg0: tensor<10x200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<10x200x100x300xf32> { + // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: unique_indices = #vhlo.bool_v1, + // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<10x200x100x300x!vhlo.f32_v1> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [1, 2], + input_batching_dims = [0], + scatter_dims_to_operand_dims = [1, 2], + scatter_indices_batching_dims = [0], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<10x200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<10x200x100x300xf32> + func.return %0 : tensor<10x200x100x300xf32> +} + +// CHECK_lABEL: "op_scatter_with_promotable_types" +func.func @op_scatter_with_promotable_types(%input_tensor: tensor<200x100x300xf32>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> + tensor<200x100x300xf64> { + // CHECK: "vhlo.scatter_v2"(%[[ARG0:.*]], %[[ARG1:.*]], %[[ARG2:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f64_v1> + %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %add = stablehlo.add %lhs, %rhs : tensor + "stablehlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> + tensor<200x100x300xf64> + func.return %0 : tensor<200x100x300xf64> +} + +// CHECK-LABEL: "op_select_and_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array, + window_strides = array, + padding = dense<1> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor) -> tensor<10x24x24x64xf32> + func.return %0 : tensor<10x24x24x64xf32> +} + +// CHECK-LABEL: "op_select_and_scatter_with_promotable_types" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_select_and_scatter_with_promotable_types(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf64> { + // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: %[[VAL:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK: "vhlo.return_v1"(%[[VAL]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f64_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array, + window_strides = array, + padding = dense<1> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor) -> tensor<10x24x24x64xf64> + func.return %0 : tensor<10x24x24x64xf64> +} + +// CHECK-LABEL: "op_select" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_select(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.select_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_send" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.send_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.send"(%arg0, %arg1) { + channel_handle = #stablehlo.channel_handle, + is_host_transfer = true + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_set_dimension_size" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_set_dimension_size(%arg0: tensor, %arg1: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.set_dimension_size_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.set_dimension_size"(%arg0, %arg1) { + dimension = 0 : i64 + } : (tensor, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_shift_left" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_shift_left(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_left_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_left"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_shift_right_arithmetic" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_shift_right_arithmetic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_right_arithmetic_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_right_arithmetic"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_shift_right_logical" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_shift_right_logical(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_right_logical_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_right_logical"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_sign" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_sign(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sign_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sign"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_sine" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_sine(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sine_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_slice" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_slice(%arg0: tensor<16xf32>) -> tensor<4xf32> { + // CHECK: "vhlo.slice_v1"(%[[ARG0]]) <{ + // CHECK-SAME: limit_indices = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: start_indices = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: strides = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> + %0 = "stablehlo.slice"(%arg0) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<16xf32>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK-LABEL: "op_sort" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.sort_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: is_stable = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimension = 0 : i64, + is_stable = true + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_sqrt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_sqrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sqrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_subtract" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_subtract(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.subtract_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.subtract"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_tan" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_tan(%arg0: tensor) -> tensor { + // CHECK: "vhlo.tan_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.tan"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_tanh" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_tanh(%arg0: tensor) -> tensor { + // CHECK: "vhlo.tanh_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.tanh"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_torch_index_select" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xf32> { + // CHECK: "vhlo.torch_index_select_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: batch_dims = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: dim = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<5x1x5x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<2x1x5x!vhlo.f32_v1> + %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xf32>, tensor<2xi32>) -> tensor<2x1x5xf32> + func.return %0 : tensor<2x1x5xf32> +} + +// CHECK-LABEL: "op_transpose" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> { + // CHECK: "vhlo.transpose_v1"(%[[ARG0]]) <{ + // CHECK-SAME: permutation = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x16x!vhlo.f32_v1> + %0 = "stablehlo.transpose"(%arg0) { + permutation = array + } : (tensor<16x8xf32>) -> tensor<8x16xf32> + func.return %0 : tensor<8x16xf32> +} + +// CHECK-LABEL: "op_triangular_solve" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_triangular_solve(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.triangular_solve_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: left_side = #vhlo.bool_v1, + // CHECK-SAME: lower = #vhlo.bool_v1, + // CHECK-SAME: transpose_a = #vhlo, + // CHECK-SAME: unit_diagonal = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_tuple" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_tuple(%arg0: tensor) -> tuple> { + // CHECK: "vhlo.tuple_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tuple_v1> + %0 = "stablehlo.tuple"(%arg0) : (tensor) -> tuple> + func.return %0 : tuple> +} + +// CHECK-LABEL: "op_unary_einsum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_unary_einsum(%arg0: tensor<8x16xf32>) -> tensor<8xf32> { + // CHECK: "vhlo.unary_einsum_v1"(%[[ARG0]]) <{ + // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab->a"> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x!vhlo.f32_v1> + %0 = "stablehlo.unary_einsum"(%arg0) { + einsum_config = "ab->a" + } : (tensor<8x16xf32>) -> tensor<8xf32> + func.return %0 : tensor<8xf32> +} + +// CHECK-LABEL: "op_uniform_dequantize" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_uniform_dequantize(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.uniform_dequantize_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.uniform_dequantize"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_uniform_quantize" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_uniform_quantize(%arg0: tensor) -> tensor> { + // CHECK: "vhlo.uniform_quantize_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1> + %0 = "stablehlo.uniform_quantize"(%arg0) : (tensor) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "op_while" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_while(%arg0: tensor) -> tensor { + // CHECK: "vhlo.while_v1"(%[[ARG0]]) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1) + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.while"(%arg0) ({ + ^bb0(%arg1: tensor): + "stablehlo.return"(%arg1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0: tensor +} + +// CHECK-LABEL: "op_xor" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_xor(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.xor_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.xor"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ============ TYPES ============ + +// CHECK-LABEL: "type_i1" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i1(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.and_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i2" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i4" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i8" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i8(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i16" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i32" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i64" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui2" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui4" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui8" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui8(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui16" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui32" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui64" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f4E2M1FN" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f4E2M1FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f6E2M3FN" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f6E2M3FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f6E3M2FN" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f6E3M2FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E3M4" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E3M4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3FN" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E5M2" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E5M2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3FNUZ" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E4M3FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3B11FNUZ" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E4M3B11FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E5M2FNUZ" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E5M2FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E8M0FNU" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E8M0FNU(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_bf16" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_bf16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f16" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f32" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f64" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_complex_f32" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_complex_f32(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "type_complex_f64" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_complex_f64(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "type_tf32" +// CHECK: #vhlo.type_v1 +func.func @type_tf32() attributes {stablehlo.attr = tf32 } { + return +} + +// CHECK-LABEL: "type_none" +// CHECK: #vhlo.type_v1 +func.func @type_none() attributes {stablehlo.attr = none } { + return +} + +// CHECK-LABEL: "type_dynamism_ranked" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_dynamism_ranked(%arg0: tensor) -> tensor { + // CHECK: "vhlo.abs_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_per_tensor_quantization" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_per_tensor_quantization(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "type_per_axis_quantization" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_per_axis_quantization(%arg0: tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG0]]) : (!vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1>, !vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1>) -> !vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1> + %0 = stablehlo.add %arg0, %arg0 : tensor<2x!quant.uniform> + func.return %0 : tensor<2x!quant.uniform> +} + +// CHECK: function_type = #vhlo.type_v1 !vhlo.token_v1>> +// CHECK-LABEL: "type_token_callee" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_token_callee(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.token_v1) -> () + return %arg0 : !stablehlo.token +} + +// CHECK: function_type = #vhlo.type_v1 !vhlo.token_v1>> +// CHECK-LABEL: "type_token_caller" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_token_caller(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.call_v1"(%[[ARG0]]) <{callee = #vhlo.string_v1<"type_token_callee">} + // CHECK-SAME: (!vhlo.token_v1) -> !vhlo.token_v1 + %0 = func.call @type_token_callee(%arg0) : (!stablehlo.token) -> !stablehlo.token + return %0 : !stablehlo.token +} + +// CHECK-LABEL: "type_tuple" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_tuple(%arg0: tuple>) -> tuple { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo" + // CHECK: (!vhlo.tuple_v1>) -> !vhlo.tuple_v1 + } : (tuple>) -> tuple + return %0 : tuple +} + +// ============ DEPENDENCIES ============ + +func.func @composite_target(%arg0: tensor) -> tensor { + return %arg0: tensor +} diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_8_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_8_0.mlir.bc new file mode 100644 index 0000000000000000000000000000000000000000..5a6474d7e8bc72c9990848b6e986130ebdbf4dbc GIT binary patch literal 19438 zcmch94~$er)^FdcTV3C!)AYM^n$9pO`ZYJhq?tF&&hP;X1k6BaKi~rd;utZ|uVBaAElX})v3lJdcjbCEZ7mONzvsRO9(v@lC!Tudxffn~ z7xt_wfV#5I@3?@e}+M=cULTd7T?lM&f*$j6@8FgF4vx)#J2R)eNBwlz0{nj(M%n|u%*n-;IqCDIt1a5-Q`7WRmr@r^KH*=U;F`{l)n_9ybJ!$AvIbx1>Yf zkKvW62W+2|`idLQSq>&2WFF-FR{b>|kAl}n;&IEF#!|0z`UIcVFMWFIGEe&*RA~>` zK3nRRs!C~ChBJwwlmFt|c1xAd`GX}=uT*)tc==^6-gv|L+WFR*z?_N9>0r(!5;i?y zDN^6?&aqwl{c;a*e&>AQ{L%TF^LJ-DbIx$}w+Od5@`B*}SlAZ->zV)gKPJZGc3cVP zd*{#2U!8wAKRJv!VdhL>&SK_dnX{ZZYnl4d*~gsy%sIrIQ%wCm%~LN|?QIkdBJ~5W zwqr(8q(0+6seh!d^88pMeLwle`lWAl>e>x_`f_z1lLl0WqDGQOEcKIOOof@6z|=&h zCNVXcsVPiNWojBz)0vvV)Jn3}_);`iIX{DYd$)FKvFmRihImZ{}Ttzv2|bAGS7 znDeDt&(ubyx|!-@Y70{(rYcO0GPRSbT}yv-=KQEmFm;ltQ%s#^>I_q7nL5YRd8RHfb&;t{Ont@FH%wh- z&Sa*3VCo8UrZRPvscTGKXNqwZ=4t|06S*@3BjwIa43?`YTutR_8h7Skyxf_OA#-OD zM$J_RS97?U&($KX7IT&5YB_gSVZ2t8-YqXLyvK<&UfLsGmpu0_qpIy2#ZfuD;^VZg_(`d*Bf6?1e+P zy28~}uC8(C035=dgYb!|!ls&Fs)?qWWU9%gnqsP{rkZA|>86@ts+p$hFrCAgeN)Xh z)gse5!qj3@Wlgo*bdJImrdn${$C&Ce)q2x84sV#w3CzFgoPJ$5eYwwa--hO?AL@&SK(C=NwaqOb2s%*mN$yC#E`Ts$-^e5#BJ>2~(Xk z)hSb*Hq{x^!E~QB)j3n0H`N7GT{P7t)A{5;Y>wf2xl5%LpU?Inl79Uu4V{lK36k^ zgLvo=4nks%a1amkh0}#N5za=0if|AKS>corFv1z-YL#$yA!dYw0O%490${yx4sf+m zIEN5C!a2fKpKy*LeuOFs=LA<3;b787g>#ClokHyrYPWFCakWP{=MhE1xqv_tYQJzU zA(n(XC<1xePZ+7CoqlfcOCRz}NC~g*rxlR=Oj0Gie7`T?mp4;*GkRbz@Kdl~`n|+e zI<7LD%eXSsA)yY7)DA-(L8?u2ekztR!B*r!ixQ4DS?H#&|Vj zsMd_JFe5rU1-{Y(23Av-$XK2D0ke26W5gl%Y(}KfmLg%PjFrjf850@f8p{}`8B_wU z3Rf#&UmTvsyDjKRqZ1nnERjy{45N%jcZShMP2meXkqo!u=mnh8HOfL;T6Pn}ZFCmq z8DknDr}h z&ucbc!j}*TWcFPCx-A}hm_K5R)_ZxYEoaYN%3rYMt@tM0kYy zMz@a|{0&>S&0E0RY}xc{gEwhA4X_fA!K2hBtAvy=Kazx4QsUdKqG3}@Q*h#eASYWb z8YTup1<01xnzBWEJM{`<-?VYvvw0(LwB=*YTQCu;FJVU#MneNt1gAj=Ff6krU)D%s zkLfWappCyx>Y<<1r|@L}hqdr8St;77Z6$nF#3mj$_&emiclD^>v#lqdz{uYJ0J51T zGW&5`J_$!W>H6bC_~R+^hb}&{u~Nwfiv|PcWw!YEH~bUwewuoI8YOSHc{4ZK`AdGP z-2ES&*ZDPVTC*)ytl%qdvBk=jd?g0BdNp5di#2Qbng&}mk|V^rb$nf?EpES^-wt2h zaRPA`$}{tx6Dn?KDzMHWE%5jMUoTK3U+ z8x~!Za+0v(=HY%}+-4Zk+||#GDN@{q$unCx4$OuTva+~oNP$T-k&l7NIt&Nv02Zmk zqIFoT4nuGRetJb#m!C%&Ehg>agZIEK`TUMS(H2O0$`Z z2sc8Jtbx>s|KKqWkF^JpSwUoW5Sbf9<^_=jHDrw;7D_XU?8HL`bl(!>>#XtJW-JLJ zOM}R*9&&pqnl+Xo`&BJn5kyu7k<}g|NM;T4WL0LJG_jdP_#{Tq?O2ah=nl_ao#{{_ zYut(LxfVuu)nOazuv{IMufuv=3~r8RjRGQ6cPDgcZVF<>dTeVD8>q+1(rn{a1Uc-_ zA(b_@;VBYgou-c#!%#gsQjc!;Y*i{0%VM*L1;%(!9d>UWc3&NKe;xLKG#ik|`EC{> zBWRu<^!%R=wPuZnc&z5+hh6hq)dlW+XXo>Xxlv?Oc1!e94C2T#0OmBRY0 zTZO+O4@^J?XhiTh?tfmD4f%{yrAh){c%@tr%!z_$2(K#Ld8f0AaC zyV@FPK()a9^G&QqFMsi55}|J)<4&C3HvU>&99Z&y3oiK|gNyp_)Xz=S`|IK#L4lue zoAL!38baj7s2>u(0AsH2=>$ND17+QOjdi`7Y{HPzd!SApPX?361VB3PM5CcL^cdVZ zC;AvV`y=_yw+7ys1gsMW+(@esqBZn&)|kxbfE%Gmuwbr?L}TEa0={`X>hg)uABc}9 zt6T`)+l;9mkA$w!m%b)%(MgHacAGH`JbDW4(C^8V+7wO)qBlAb`XX!0V8IEQS;Kyp zH9G3BIl5PbK5oIYCu^00wr>g<`l3jm@AqniKCka}Q4RTRJ+j#EQzzUk3ufHqz*^lM z%?N!)a>3NH%GYU!J_UxGfk=!SOPeIV*NH$v}cjs10|AE;s2P!85%hk&(HBAESO z;a$Lo8NcEkj7aYhE$h1LJ+dQU;!!ON%-@cVF&^BHjsw^8;GU}LPe&&JYJa->(MeXf zA)V5yksQ4p=`^d~lFk61MPcLbNoQI8rgRPrvx#9oTY6dVTIa`bTmXmmSbOMQJ)eo% zq+Qg~Sfd%Ey}dp34hEj6U!j+@G{!D|4809FxSM_jOb<=W4!xB%zF~Y>z-yN^3d_Xf zvqNt}lm{dJ2VWLG`*qg1!s?EcSAlwA_6AKx&Ejh{vTALdgw1YaM9rU1PK)4&B2eZbZx@bQPYy)jFi*Y1$HNI2Rc%vf_X*>cc@A!u^ z?ols(`~Vqe1YPux`62~679Z#Q-G6$}I^o;q9gI(M{$9;Cc)czX{n4I+)crW^Sj*XR z1}2=SPjm792ipJAyoE<&R)f)C_=h{;nu{pIrNE-?4c{>M4D|RTdLZalU}u5RhIl_V z#0+&z8`n9Hcwdls9z-4FB6Kg5JYJjV3qa4_aGL(UmdpLv&(?bCsqbmI?uoo*gz^LhM!c#fm8+LpTSGR~BV9n+X_a>H z`8;%2)>v=yPXanO>i#gYbm-2k(QWc{ZDf7G^f+$x{A)|C%RA8J(>1Y@ud*|Ad)BC! zyuF5v`cf-H>*@{N=_5vH4Up!V%q}20Sw};wah$kfK?A9h=)Ywc{ zLR)Y%-4npZ-d6Ye*ppA{^tvM8seL~7^wSz!8Nl}Y*t5@SY;^!T;A77}udy`&?4XY! zzA0A(;mAA9jdjjan{M||w%mo;{K06XequfD3WI|A4-AA9XJjolf* zj{De}H5$tVuoFJ^#v2;DD}bH!u{Yn;*oFXh%E#V*TVuHZcG}0@eOF`o0CvX5-hW?X zJpt^jkA3)|#_$ZenlH}z*vB7hY*PR`?_-~QqOoECyWnG=eWtOk0qmlWeg3(|u!~ju zyX0eEe4(*&0Q<_vzWh>S+XC1(KKAw38XF2=mwoKpZ#6a&z<%&CWRALh2f6Z!kBv>? zz^?k(!i8Ecc)h>oW8Z(TvF+9V))+J8b)f!lE7lktl40lb+t{zr(s?!$YCvcQ?urxq z9dhMLt+lzT^)_Q7ur|u3TNquhX&}$g(ZE~SlK|^mR-*dobh6e4i9~1xwigk6%r@1> zx*7fEW)}anmR6?us=TMLVam0kjsN|GuoAIIV>HvZenX& z99`IrndY|k=6FoaTc#A=rEHv=Zf$F9OyiBjww5KWnM|fByeN45?MkZllpR#O&6_c)B$ckJ$~a zjj4I*L=pppo0Lj5C7O~>tTm?KluRs@jM?zB-91w5A1;jD5pY3|ETz(x#FAd?hzj?jO~h zLp}N7;X)-h+7YUhHs!jD!?{g`QlXL`E`!(}>Md(gQdBAQj`TpNBb4tRB5bJ8U+5Xm z?eXsd?`Zo4NxqzCJdfGp@4CC>85a9^RIg^DG==9|z{E`g~d z7#8ORUP$jHG%o*z-q4fc-=<;w1aP#Y}8pElgQN!w z(e)86(t;b%s170C_1IBI_A8XUIiR)Mj|Bo(^&sjH*M)p<5lL}p0o`eGe||tuNKc`t z!@Z~5)mqC85Uw?9X##mwCysz6-T7gx$r3GdEIZ5{dLro>nKDwPP;wVwfWE2X##kN$ z=^&YWAEj_F+4t14PR$Z9Vc7VK7i=Umk?|{`P+z&Mw+uI34d*Keh0rfT!#l8UTpv<) zDii{Owe04xlct=z?-Ylykb8?*J3814+d?-7ONE~SH~lO){4n~T3Ic_XERNPc5^xha7LDmbPfxS9ESj+B|6nEqj`SNHUQSo^=@hBY-y5OA1x?7cMmy~Q5iMK|f0T-lrJ>nr|z zmsK{biClNS2dlUju~q6Tdbajq>u{%~&+D;DDWc|TuaYX`NyQqimgC(Xa=nGVJaVwB zWE{E~``ph!*qU6Fj!^mjO_&kveDyl(q2bMyLLt{z*>NMeK^z0hch`_!B>zgW+dX7> zM}wQ7lurCesGR`d&}zdKtTt~BB>)JXa1}N1NUX4}@ z^(u`{%HA;^Khq_k$}_x++x}i-9+@p<8bpu?}t0y2Aa}P&e zoz6^t)w|=kO*&WUTHwUxws9=O7UQ0L_qSAIs$yWhH<`Bz_R1uw=fs_$WL zl^!vFTbjbhD?5zirg2H(7-L)H&myLuvivwjo#l zxJ`o$(7djcyDN-;K%vpeFjo{*fAEN%R6pTnv7$=Te|tPnk49fl?Hs@c#mR; zhV98KHtTZSqa023pZ2(o`saL_M$d()GE%BNFml`Oqo&$pr6ARR)I>D>bc!l?TIF67 zi+FZK*TrIoSN^M4{?fZA2js&NR{Ao3|FBN9>2J%-MwgNAcv{miaVk9JkGi~ zI+WvSjW<>LSyK(fL)W1jK!FRsH<$X!5T5L|F3RKkzjgPZ(i3ubXe7`GZ&Jy!pipaa<_m-T; z!2c)tXVk8TBl0iuuc%!^ZK3>|#g|g~qb06t!rv{vBqDE>|FFar5N?$}S>lIq6cxrS z{z6o4kYQ$t%OGu#6F_=CDwoNL%(7y!XJb^K#4P-q8Q+9;W3q016~;dTF@?!#%!7Y_qm@(LA?$g3Dw6IHIEGA}A0mDdS~NJit0$S@W6$_aR74c$%T zauN|9lar~)$tn0{xsAz9$*CObr*S#mm79TgWX^^+y}CUnXL8vA@hGXA1O77z$cN>8 z;$Ot&VqTl}1(^2J;Ux=RSwOs5;#kh*Dp%uL=x8N+7q^n-?5hX+`_9tv&0RWg|bA@eX_z4CP@tbemTl(VM0c@COjZ_!h}R~ zjKU-)@ebc^v=!P(g$c#md3(vkeO&IR5^vQVr1B7#hhgiwn8XWiN2#xc@)*W_1mlj# z9$Vcua{5r$T@zIU*Z&rU*YdAb;pRrn8b@#mqE}ycT(4Qi|Pt>b(PC& z(0mx0pTIxO1C>uoW};Xj!=@$Ehr%&A!ITqC@J=%2WP&1cid#%|i)mEMmD5d&C&_U$ zsGX5B-J-)~nnT4?a=zPIM8#^kn2MD$3&}KQ%9N|n=el*!@|avpSVneH&+AROk(y7+ zZYoyDKJZ2H@6pIDrYsRDA}gk49S9$Uy`!ewNwm3g7ie~C6nuD-WDhm&u+i9S%6(AV zvM=1yvL6EbO?kky?DpN^_C4Xfa79KQBn3~)LvC@HjKB*SN2&R!JmzXYPQ|VA1Qi?P zNh%`plv|v3i!)R_B+o)SezA;o9E0!i2E%zxNGg>i0$xqHNIl?HgRd~K=UQ4YW2U@J z!cWK_s8}hl(2$#9J$})DjYwE?W$Pu~8`rjq8=p+CD literal 0 HcmV?d00001 diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir index 55d669cb402..37e378e47dc 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir @@ -2729,6 +2729,30 @@ func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "type_f4E2M1FN" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f4E2M1FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f6E2M3FN" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f6E2M3FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f6E3M2FN" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f6E3M2FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: "type_f8E3M4" // CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) func.func @type_f8E3M4(%arg0: tensor, %arg1: tensor) -> tensor { @@ -2785,6 +2809,14 @@ func.func @type_f8E5M2FNUZ(%arg0: tensor, %arg1: tensor) func.return %0 : tensor } +// CHECK-LABEL: "type_f8E8M0FNU" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E8M0FNU(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: "type_bf16" // CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) func.func @type_bf16(%arg0: tensor, %arg1: tensor) -> tensor { diff --git a/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_7_0.mlir b/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_7_0.mlir new file mode 100644 index 00000000000..e5a47a84780 --- /dev/null +++ b/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_7_0.mlir @@ -0,0 +1,35 @@ +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=1.7.0' --verify-diagnostics --split-input-file %s + +// expected-error @-3 {{failed to convert VHLO to v1.7.0}} +// expected-error @+1 {{failed to legalize operation 'vhlo.func_v1' that was explicitly marked illegal}} +func.func @type_f4E2M1FN(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// expected-error @-3 {{failed to convert VHLO to v1.7.0}} +// expected-error @+1 {{failed to legalize operation 'vhlo.func_v1' that was explicitly marked illegal}} +func.func @type_f6E2M3FN(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// expected-error @-3 {{failed to convert VHLO to v1.7.0}} +// expected-error @+1 {{failed to legalize operation 'vhlo.func_v1' that was explicitly marked illegal}} +func.func @type_f6E3M2FN(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// expected-error @-3 {{failed to convert VHLO to v1.7.0}} +// expected-error @+1 {{failed to legalize operation 'vhlo.func_v1' that was explicitly marked illegal}} +func.func @type_f8E8M0FNU(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +}