diff --git a/BUILD.bazel b/BUILD.bazel index 58df9f183b6..f435b5532dc 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -310,6 +310,21 @@ gentbl_cc_library( ], ) +gentbl_cc_library( + name = "stablehlo_aggressive_simplification_inc_gen", + tbl_outs = [ + ( + ["--gen-rewriters"], + "stablehlo/transforms/StablehloAggressiveSimplificationPatterns.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td", + deps = [ + ":stablehlo_ops_td_files", + ], +) + gentbl_cc_library( name = "stablehlo_legalize_deprecated_ops_inc_gen", tbl_outs = [ @@ -1127,6 +1142,7 @@ cc_library( ":chlo_ops", ":chlo_rewriters_inc_gen", ":linalg_passes", + ":stablehlo_aggressive_simplification_inc_gen", ":stablehlo_create_compatibility_expander_inc_gen", ":stablehlo_legalize_deprecated_ops_inc_gen", ":stablehlo_ops", diff --git a/stablehlo/testdata/igamma_float64_20_20_float64_20_20_chlo.mlir b/stablehlo/testdata/igamma_float64_20_20_float64_20_20_chlo.mlir index 16e2cfa2e22..7d19428fddd 100644 --- a/stablehlo/testdata/igamma_float64_20_20_float64_20_20_chlo.mlir +++ b/stablehlo/testdata/igamma_float64_20_20_float64_20_20_chlo.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED(timeout in debug builds): stablehlo-opt --chlo-pre-serialization-pipeline -inline %s | stablehlo-translate --interpret +// RUN-DISABLED(#2590, timeout): stablehlo-opt --chlo-pre-serialization-pipeline -inline %s | stablehlo-translate --interpret // RUN: stablehlo-opt --chlo-pre-serialization-pipeline %s | stablehlo-translate --serialize --target=current | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt --chlo-pre-serialization-pipeline %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/testdata/igammac_float64_20_20_float64_20_20_chlo.mlir b/stablehlo/testdata/igammac_float64_20_20_float64_20_20_chlo.mlir index 03e60172d75..42d097dc18f 100644 --- a/stablehlo/testdata/igammac_float64_20_20_float64_20_20_chlo.mlir +++ 
b/stablehlo/testdata/igammac_float64_20_20_float64_20_20_chlo.mlir @@ -1,4 +1,4 @@ -// RUN: stablehlo-opt --chlo-pre-serialization-pipeline -inline %s | stablehlo-translate --interpret +// RUN-DISABLED(#2590, timeout): stablehlo-opt --chlo-pre-serialization-pipeline -inline %s | stablehlo-translate --interpret // RUN: stablehlo-opt --chlo-pre-serialization-pipeline %s | stablehlo-translate --serialize --target=current | stablehlo-translate --deserialize | stablehlo-opt > %t.0 // RUN: stablehlo-opt --chlo-pre-serialization-pipeline %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir index 914063131e8..5b21a10d1b2 100644 --- a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir +++ b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir @@ -1,5 +1,143 @@ // RUN: stablehlo-opt --stablehlo-aggressive-folder --split-input-file --verify-diagnostics %s | FileCheck %s +//////// +// AddOp + +// CHECK-LABEL: @add_fold_cst +func.func @add_fold_cst() -> (tensor, tensor) { + %cst = stablehlo.constant dense<1> : tensor + %cst_1 = stablehlo.constant dense<1.0> : tensor + // CHECK: stablehlo.constant dense<2> : tensor + // CHECK: stablehlo.constant dense<2.0{{.*}}> : tensor + %0 = stablehlo.add %cst, %cst : tensor + %1 = stablehlo.add %cst_1, %cst_1 : tensor + return %0, %1 : tensor, tensor +} + +// ----- + +//////// +// BroadcastInDimOp + +// CHECK-LABEL: func.func @broadcast_in_dim_fold_splat +// CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>) +func.func @broadcast_in_dim_fold_splat(%arg0: tensor<3x3xi32>) + -> (tensor<6xi32>, tensor<3xf32>, tensor<3x3xi32>) { + %c0 = stablehlo.constant dense<5> : tensor + %c1 = stablehlo.constant dense<3.0> : tensor + %c2 = stablehlo.constant dense<1> : tensor<1x3xi32> + + %0 = stablehlo.broadcast_in_dim %c0, dims = [] : (tensor) -> tensor<6xi32> + %1 = stablehlo.broadcast_in_dim %c1, dims = [] : (tensor) -> 
tensor<3xf32> + %2 = stablehlo.broadcast_in_dim %c2, dims = [1, 0] : (tensor<1x3xi32>) -> tensor<3x3xi32> + + // CHECK-DAG: [[R0:%.+]] = stablehlo.constant dense<5> : tensor<6xi32> + // CHECK-DAG: [[R1:%.+]] = stablehlo.constant dense<3.000000e+00> : tensor<3xf32> + // CHECK-DAG: [[R2:%.+]] = stablehlo.constant dense<1> : tensor<3x3xi32> + + // CHECK-NEXT: return [[R0]], [[R1]], [[R2]] + return %0, %1, %2 : tensor<6xi32>, tensor<3xf32>, tensor<3x3xi32> +} + +// ----- + +//////// +// CompareOp + +// CHECK-LABEL: func.func @compare_folds +func.func @compare_folds() + -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) { + %cn1 = stablehlo.constant dense<-1> : tensor + %c0 = stablehlo.constant dense<0> : tensor + %c4 = stablehlo.constant dense<4> : tensor + %c5 = stablehlo.constant dense<5> : tensor + + %0 = stablehlo.compare EQ, %cn1, %cn1, SIGNED : (tensor, tensor) -> tensor + %1 = stablehlo.compare GT, %c5, %c5, SIGNED : (tensor, tensor) -> tensor + %2 = stablehlo.compare GE, %c4, %cn1, SIGNED : (tensor, tensor) -> tensor + %3 = stablehlo.compare LE, %c4, %c5, SIGNED : (tensor, tensor) -> tensor + + %4 = stablehlo.compare EQ, %cn1, %cn1, UNSIGNED : (tensor, tensor) -> tensor + %5 = stablehlo.compare GT, %c5, %cn1, UNSIGNED : (tensor, tensor) -> tensor + %6 = stablehlo.compare GE, %c5, %c4, UNSIGNED : (tensor, tensor) -> tensor + %7 = stablehlo.compare LE, %cn1, %c5, UNSIGNED : (tensor, tensor) -> tensor + + // CHECK-DAG: [[FALSE:%.+]] = stablehlo.constant dense : tensor + // CHECK-DAG: [[TRUE:%.+]] = stablehlo.constant dense : tensor + + // CHECK-NEXT: return [[TRUE]], [[FALSE]], [[TRUE]], [[TRUE]], [[TRUE]], [[FALSE]], [[TRUE]], [[FALSE]] + return %0, %1, %2, %3, %4, %5, %6, %7 : + tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor +} + + +// ----- + +//////// +// ConcatenateOp + +// CHECK-LABEL: func.func @concatenate_fold +func.func @concatenate_fold() -> (tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, tensor<2x5xi32>) { + %c0 = 
stablehlo.constant dense<[0, 1]> : tensor<2xi32> + %c1 = stablehlo.constant dense<[2, 3, 4]> : tensor<3xi32> + %c2 = stablehlo.constant dense<[5]> : tensor<1xi32> + + %c3 = stablehlo.constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32> + %c4 = stablehlo.constant dense<[[6, 7, 8]]> : tensor<1x3xi32> + %c5 = stablehlo.constant dense<[[11, 12], [13, 14]]> : tensor<2x2xi32> + + %0 = stablehlo.concatenate %c0, %c1, %c2, dim = 0 : (tensor<2xi32>, tensor<3xi32>, tensor<1xi32>) -> tensor<6xi32> + %1 = stablehlo.concatenate %c0, %c2, dim = 0 : (tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32> + + %2 = stablehlo.concatenate %c3, %c4, dim = 0 : (tensor<2x3xi32>, tensor<1x3xi32>) -> tensor<3x3xi32> + %3 = stablehlo.concatenate %c3, %c5, dim = 1 : (tensor<2x3xi32>, tensor<2x2xi32>) -> tensor<2x5xi32> + + // CHECK-DAG: [[R0:%.+]] = stablehlo.constant dense<[0, 1, 2, 3, 4, 5]> : tensor<6xi32> + // CHECK-DAG: [[R1:%.+]] = stablehlo.constant dense<[0, 1, 5]> : tensor<3xi32> + // CHECK-DAG: [[R2:%.+]] = stablehlo.constant dense<{{\[\[0, 1, 2\], \[3, 4, 5\], \[6, 7, 8\]\]}}> : tensor<3x3xi32> + // CHECK-DAG: [[R3:%.+]] = stablehlo.constant dense<{{\[\[0, 1, 2, 11, 12\], \[3, 4, 5, 13, 14\]\]}}> : tensor<2x5xi32> + // CHECK-NEXT: return [[R0]], [[R1]], [[R2]], [[R3]] + return %0, %1, %2, %3 : tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, tensor<2x5xi32> +} + +// ----- + +//////// +// MulOp + +// CHECK-LABEL: @mul_fold_cst +func.func @mul_fold_cst() -> (tensor, tensor) { + %cst = stablehlo.constant dense<2> : tensor + %cst_1 = stablehlo.constant dense<2.0> : tensor + // CHECK: stablehlo.constant dense<4> : tensor + // CHECK: stablehlo.constant dense<4.0{{.*}}> : tensor + %0 = stablehlo.multiply %cst, %cst : tensor + %1 = stablehlo.multiply %cst_1, %cst_1 : tensor + return %0, %1 : tensor, tensor +} + +// ----- + +//////// +// SubtractOp + +// CHECK-LABEL: @subtract_fold_cst +func.func @subtract_fold_cst() -> (tensor, tensor) { + %cst = stablehlo.constant dense<1> : tensor + %cst_1 = 
stablehlo.constant dense<3> : tensor + %cst_2 = stablehlo.constant dense<1.0> : tensor + %cst_3 = stablehlo.constant dense<3.0> : tensor + // CHECK: stablehlo.constant dense<2> : tensor + // CHECK: stablehlo.constant dense<2.0{{.*}}> : tensor + %0 = stablehlo.subtract %cst_1, %cst : tensor + %1 = stablehlo.subtract %cst_3, %cst_2 : tensor + return %0, %1 : tensor, tensor +} + +// ----- + +//////// +// IotaOp // CHECK-LABEL: func @eval_iota func.func @eval_iota() -> (tensor<3x4x5xi32>, tensor<3x4x5xi32>, tensor<3x4x5xi32>) { @@ -41,6 +179,24 @@ func.func @eval_iota_zero_dimension() -> (tensor<0xi32>, tensor<5x0x2xi32>) { // ----- +//////// +// ReshapeOp + +// CHECK-LABEL: func @reshape +func.func @reshape_fold() -> (tensor<1xi32>, tensor<2x2xi32>) { + %c0 = stablehlo.constant dense<2> : tensor + %c1 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + %0 = stablehlo.reshape %c0 : (tensor) -> tensor<1xi32> + %1 = stablehlo.reshape %c1 : (tensor<4xi32>) -> tensor<2x2xi32> + + // CHECK-DAG: [[CST1:%.+]] = stablehlo.constant dense<2> : tensor<1xi32> + // CHECK-DAG: [[CST2:%.+]] = stablehlo.constant dense<{{\[\[1, 2\], \[3, 4\]\]}}> : tensor<2x2xi32> + // CHECK-NEXT: return [[CST1]], [[CST2]] + return %0, %1 : tensor<1xi32>, tensor<2x2xi32> +} + +// ----- + // CHECK-LABEL: func @eval_convert_f32_to_i64 func.func @eval_convert_f32_to_i64() -> tensor<2xi64> { // CHECK-NOT: stablehlo.convert diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir index 9c7dbfa7c72..809c070012d 100644 --- a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +++ b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir @@ -1,114 +1,124 @@ // RUN: stablehlo-opt --stablehlo-aggressive-simplification --allow-unregistered-dialect --split-input-file %s | FileCheck %s -// CHECK-LABEL: func.func @add -// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xi32>, [[ARG1:%.+]]: tensor) 
-func.func @add(%arg0: tensor<2xi32>, %arg1: tensor) - -> (tensor, tensor, tensor, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) { - %c0 = stablehlo.constant dense<0> : tensor - %cn0 = stablehlo.constant dense<-0.0> : tensor - %c0_2 = stablehlo.constant dense<0> : tensor<2xi32> - %c1 = stablehlo.constant dense<5> : tensor - %c2 = stablehlo.constant dense<3.0> : tensor - %c3 = stablehlo.constant dense<[1, 2]> : tensor<2xi32> - - %0 = stablehlo.add %c0, %c1 : tensor - %1 = stablehlo.add %c1, %c1 : tensor - %2 = stablehlo.add %c2, %c2 : tensor - %3 = stablehlo.add %arg1, %cn0 : tensor - - %4 = stablehlo.add %c0_2, %arg0 : tensor<2xi32> - %5 = stablehlo.add %c3, %arg0 : tensor<2xi32> - %6 = stablehlo.add %c3, %c3 : tensor<2xi32> - - // CHECK-DAG: [[C0:%.+]] = stablehlo.constant dense<5> : tensor - // CHECK-DAG: [[C1:%.+]] = stablehlo.constant dense<10> : tensor - // CHECK-DAG: [[C2:%.+]] = stablehlo.constant dense<6.000000e+00> : tensor - // CHECK-DAG: [[C3:%.+]] = stablehlo.constant dense<[2, 4]> : tensor<2xi32> - // CHECK-DAG: [[C4:%.+]] = stablehlo.constant dense<[1, 2]> : tensor<2xi32> +///////// +// AddOp + +// CHECK-LABEL: @add_cst_on_rhs +func.func @add_cst_on_rhs(%arg0: tensor) -> tensor { + %cst = stablehlo.constant dense<1.0> : tensor + // CHECK: stablehlo.add %arg0, %cst : tensor + %0 = stablehlo.add %cst, %arg0 : tensor + return %0 : tensor +} - // CHECK-DAG: [[A0:%.+]] = stablehlo.add [[ARG0]], [[C4]] : tensor<2xi32> +// CHECK-LABEL: @add_zero_like_lhs +func.func @add_zero_like_lhs(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<0> : tensor + %1 = stablehlo.add %0, %arg0 : tensor + // CHECK-NOT: stablehlo.constant + // CHECK: return %arg0 + return %1 : tensor +} - // CHECK-NEXT: return [[C0]], [[C1]], [[C2]], [[ARG1]], [[ARG0]], [[A0]], [[C3]] - return %0, %1, %2, %3, %4, %5, %6 : tensor, tensor, tensor, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32> +// CHECK-LABEL: @add_zero_like_rhs +func.func @add_zero_like_rhs(%arg0: tensor) -> 
tensor { + %0 = stablehlo.constant dense<0.0> : tensor + %1 = stablehlo.add %arg0, %0 : tensor + // CHECK-NOT: stablehlo.constant + // CHECK: return %arg0 + return %1 : tensor } // ----- -// CHECK-LABEL: func.func @subtract -// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xi32>, [[ARG1:%.+]]: tensor) -func.func @subtract(%arg0: tensor<2xi32>, %arg1: tensor) - -> (tensor, tensor, tensor, tensor, tensor<2xi32>, tensor<2xi32>) { - %c0 = stablehlo.constant dense<0> : tensor - %cp0 = stablehlo.constant dense<0.0> : tensor - %c0_2 = stablehlo.constant dense<0> : tensor<2xi32> - %c1 = stablehlo.constant dense<5> : tensor - %c2 = stablehlo.constant dense<3.0> : tensor - %c3 = stablehlo.constant dense<[1, 2]> : tensor<2xi32> - %c4 = stablehlo.constant dense<4> : tensor - %c5 = stablehlo.constant dense<[0, 1]> : tensor<2xi32> +///////// +// AndOp + +// CHECK-LABEL: @and_cst_on_rhs +func.func @and_cst_on_rhs(%arg0: tensor<2xi1>) -> tensor<2xi1> { + %cst = stablehlo.constant dense : tensor<2xi1> + %0 = stablehlo.and %cst, %arg0 : tensor<2xi1> + // Check that constant canonicalized to RHS, then other patterns apply + // CHECK-NOT: stablehlo.and + // CHECK: return %arg0 + return %0 : tensor<2xi1> +} - %0 = stablehlo.subtract %c1, %c0 : tensor - %1 = stablehlo.subtract %c1, %c4 : tensor +// CHECK-LABEL: @and_zero +func.func @and_zero(%arg0: tensor<2xi1>) -> tensor<2xi1> { + %0 = stablehlo.constant dense : tensor<2xi1> + %1 = stablehlo.and %0, %arg0 : tensor<2xi1> + // CHECK-NOT: stablehlo.and + // CHECK: [[FALSE:%.+]] = stablehlo.constant dense : tensor<2xi1> + // CHECK: return [[FALSE]] + return %1 : tensor<2xi1> +} - %2 = stablehlo.subtract %arg1, %cp0 : tensor - %3 = stablehlo.subtract %arg1, %arg1 : tensor +// CHECK-LABEL: @and_one +func.func @and_one(%arg0: tensor<2xi1>) -> tensor<2xi1> { + %0 = stablehlo.constant dense : tensor<2xi1> + %1 = stablehlo.and %0, %arg0 : tensor<2xi1> + // CHECK-NOT: stablehlo.and + // CHECK: return %arg0 + return %1 : tensor<2xi1> +} - %4 = stablehlo.subtract 
%arg0, %arg0 : tensor<2xi32> +// ----- - %5 = stablehlo.subtract %c3, %c5 : tensor<2xi32> +///////// +// BroadcastInDim - // CHECK-DAG: [[C0:%.+]] = stablehlo.constant dense<5> : tensor - // CHECK-DAG: [[C1:%.+]] = stablehlo.constant dense<1> : tensor - // CHECK-DAG: [[C2:%.+]] = stablehlo.constant dense<0> : tensor<2xi32> - // CHECK-DAG: [[C3:%.+]] = stablehlo.constant dense<1> : tensor<2xi32> +// CHECK-LABEL: func.func @broadcast_in_dim_transpose +// CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>) +func.func @broadcast_in_dim_transpose(%arg0: tensor<3x3xi32>) + -> (tensor<3x3xi32>, tensor<3x3xi32>) { + %3 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<3x3xi32>) -> tensor<3x3xi32> + %4 = stablehlo.broadcast_in_dim %arg0, dims = [1, 0] : (tensor<3x3xi32>) -> tensor<3x3xi32> - // CHECK-DAG: [[S0:%.+]] = stablehlo.subtract [[ARG1]], [[ARG1]] : tensor + // CHECK: [[R4:%.+]] = stablehlo.transpose [[ARG0]], dims = [1, 0] : (tensor<3x3xi32>) -> tensor<3x3xi32> - // CHECK-NEXT: return [[C0]], [[C1]], [[ARG1]], [[S0]], [[C2]], [[C3]] - return %0, %1, %2, %3, %4, %5 : tensor, tensor, tensor, tensor, tensor<2xi32>, tensor<2xi32> + // CHECK-NEXT: return [[ARG0]], [[R4]] + return %3, %4 : tensor<3x3xi32>, tensor<3x3xi32> } -// ----- - -// CHECK-LABEL: func.func @multiply -// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xi32>, [[ARG1:%.+]]: tensor) -func.func @multiply(%arg0: tensor<2xi32>, %arg1: tensor) - -> (tensor, tensor, tensor, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) { - %c0 = stablehlo.constant dense<0> : tensor - %cp0 = stablehlo.constant dense<0.0> : tensor - %c0_2 = stablehlo.constant dense<0> : tensor<2xi32> - %c1 = stablehlo.constant dense<5> : tensor - %c2 = stablehlo.constant dense<3.0> : tensor - %c3 = stablehlo.constant dense<[1, 2]> : tensor<2xi32> - %c4 = stablehlo.constant dense<4> : tensor - %c5 = stablehlo.constant dense<1> : tensor<2xi32> - - %0 = stablehlo.multiply %c1, %c0 : tensor - %1 = stablehlo.multiply %c4, %c4 : tensor +// CHECK-LABEL: 
func.func @broadcast_in_dim_nested +// CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>) +func.func @broadcast_in_dim_nested(%arg0: tensor<3x3xi32>) + -> (tensor<3x2x3x3xi32>) { + %6 = stablehlo.broadcast_in_dim %arg0, dims = [1, 0] : (tensor<3x3xi32>) -> tensor<3x3x2xi32> + %7 = stablehlo.broadcast_in_dim %6, dims = [0, 2, 1] : (tensor<3x3x2xi32>) -> tensor<3x2x3x3xi32> + // CHECK: [[R6:%.+]] = stablehlo.broadcast_in_dim [[ARG0]], dims = [2, 0] : (tensor<3x3xi32>) -> tensor<3x2x3x3xi32> - %2 = stablehlo.multiply %arg1, %cp0 : tensor - %3 = stablehlo.multiply %c2, %c2 : tensor + // CHECK-NEXT: return [[R6]] + return %7 : tensor<3x2x3x3xi32> +} - %4 = stablehlo.multiply %arg0, %c0_2 : tensor<2xi32> - %5 = stablehlo.multiply %arg0, %c5 : tensor<2xi32> - %6 = stablehlo.multiply %c3, %arg0 : tensor<2xi32> +// CHECK-LABEL: func.func @broadcast_in_dim_reshape +// CHECK-SAME: ([[ARG0:%.+]]: tensor<3x6xi32>) +func.func @broadcast_in_dim_reshape(%arg0: tensor<3x6xi32>) + -> (tensor<1x3x6xi32>, tensor<3x6x1xi32>) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [1, 2] : (tensor<3x6xi32>) -> tensor<1x3x6xi32> + %5 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<3x6xi32>) -> tensor<3x6x1xi32> - // CHECK-DAG: [[C0:%.+]] = stablehlo.constant dense<0> : tensor - // CHECK-DAG: [[C1:%.+]] = stablehlo.constant dense<16> : tensor - // CHECK-DAG: [[C2:%.+]] = stablehlo.constant dense<9.000000e+00> : tensor - // CHECK-DAG: [[C3:%.+]] = stablehlo.constant dense<0> : tensor<2xi32> - // CHECK-DAG: [[C4:%.+]] = stablehlo.constant dense<[1, 2]> : tensor<2xi32> - // CHECK-DAG: [[CP0:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: [[R0:%.+]] = stablehlo.reshape [[ARG0]] : (tensor<3x6xi32>) -> tensor<1x3x6xi32> + // CHECK-DAG: [[R5:%.+]] = stablehlo.reshape [[ARG0]] : (tensor<3x6xi32>) -> tensor<3x6x1xi32> - // CHECK-DAG: [[M0:%.+]] = stablehlo.multiply [[ARG1]], [[CP0]] : tensor - // CHECK-DAG: [[M1:%.+]] = stablehlo.multiply [[ARG0]], [[C4]] : tensor<2xi32> + 
// CHECK-NEXT: return [[R0]], [[R5]] + return %0, %5 : tensor<1x3x6xi32>, tensor<3x6x1xi32> +} - // CHECK-NEXT: return [[C0]], [[C1]], [[M0]], [[C2]], [[C3]], [[ARG0]], [[M1]] - return %0, %1, %2, %3, %4, %5, %6 : tensor, tensor, tensor, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32> +// CHECK-LABEL: func.func @broadcast_in_dim_not_identity_broadcasts +func.func @broadcast_in_dim_not_identity_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> { + // CHECK: stablehlo.broadcast_in_dim + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<1x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> } // ----- +///////// +// CompareOp + // CHECK-LABEL: func.func @compare_signed_arg // CHECK-SAME: ([[ARG0:%.+]]: tensor) func.func @compare_signed_arg(%arg0: tensor) @@ -141,8 +151,6 @@ func.func @compare_signed_arg(%arg0: tensor) tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor } -// ----- - // CHECK-LABEL: func.func @compare_unsigned_arg // CHECK-SAME: ([[ARG0:%.+]]: tensor) func.func @compare_unsigned_arg(%arg0: tensor) @@ -177,245 +185,8 @@ func.func @compare_unsigned_arg(%arg0: tensor) // ----- -// CHECK-LABEL: func.func @compare_folds -func.func @compare_folds() - -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) { - %cn1 = stablehlo.constant dense<-1> : tensor - %c0 = stablehlo.constant dense<0> : tensor - %c4 = stablehlo.constant dense<4> : tensor - %c5 = stablehlo.constant dense<5> : tensor - - %0 = stablehlo.compare EQ, %cn1, %cn1, SIGNED : (tensor, tensor) -> tensor - %1 = stablehlo.compare GT, %c5, %c5, SIGNED : (tensor, tensor) -> tensor - %2 = stablehlo.compare GE, %c4, %cn1, SIGNED : (tensor, tensor) -> tensor - %3 = stablehlo.compare LE, %c4, %c5, SIGNED : (tensor, tensor) -> tensor - - %4 = stablehlo.compare EQ, %cn1, %cn1, UNSIGNED : (tensor, tensor) -> tensor - %5 = stablehlo.compare GT, %c5, %cn1, UNSIGNED : (tensor, tensor) -> tensor - %6 = stablehlo.compare GE, %c5, %c4, UNSIGNED : (tensor, tensor) -> 
tensor - %7 = stablehlo.compare LE, %cn1, %c5, UNSIGNED : (tensor, tensor) -> tensor - - // CHECK-DAG: [[C0:%.+]] = stablehlo.constant dense : tensor - // CHECK-DAG: [[C1:%.+]] = stablehlo.constant dense : tensor - - // CHECK-NEXT: return [[C1]], [[C0]], [[C1]], [[C1]], [[C1]], [[C0]], [[C1]], [[C0]] - return %0, %1, %2, %3, %4, %5, %6, %7 : - tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor -} - -// ----- - -// CHECK-LABEL: func.func @select -// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xi32>, [[ARG1:%.+]]: tensor<2xi32>, [[ARGC:%.+]]: tensor<2xi1>) -func.func @select(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %argC: tensor<2xi1>) - -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<4xi32>) { - %c0 = stablehlo.constant dense : tensor - %c1 = stablehlo.constant dense : tensor - - %c0_2 = stablehlo.constant dense : tensor<2xi1> - %c1_2 = stablehlo.constant dense : tensor<2xi1> - - %cond = stablehlo.constant dense<[false, true, false, true]> : tensor<4xi1> - %foo = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32> - %bar = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi32> - - %0 = stablehlo.select %argC, %arg0, %arg0 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %1 = stablehlo.select %c0, %arg0, %arg1 : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %2 = stablehlo.select %c1, %arg0, %arg1 : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %3 = stablehlo.select %c0_2, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %4 = stablehlo.select %c1_2, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %5 = stablehlo.select %argC, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - - %6 = stablehlo.select %cond, %foo, %bar : (tensor<4xi1>, tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - - // CHECK-DAG: [[R0:%.+]] = stablehlo.select [[ARGC]], [[ARG0]], [[ARG1]] - // CHECK-DAG: 
[[C0:%.+]] = stablehlo.constant dense<[5, 2, 7, 4]> : tensor<4xi32> - - // CHECK-NEXT: return [[ARG0]], [[ARG1]], [[ARG0]], [[ARG1]], [[ARG0]], [[R0]], [[C0]] - return %0, %1, %2, %3, %4, %5, %6 : - tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<4xi32> -} - -// ----- - -// CHECK-LABEL: func.func @select_into_minmax1 -// CHECK-SAME: [[ARG0:%.+]]: tensor<2xi32>, [[ARG1:%.+]]: tensor<2xi32>, [[ARG2:%.+]]: tensor<2xi32>, [[ARG3:%.+]]: tensor<2xi32>) -func.func @select_into_minmax1(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, - %arg2: tensor<2xi32>, %arg3: tensor<2xi32>) - -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) { - - %0 = stablehlo.compare EQ, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - %1 = stablehlo.compare NE, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - %2 = stablehlo.compare GE, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - %3 = stablehlo.compare GT, %arg0, %arg2, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - %4 = stablehlo.compare LE, %arg1, %arg2, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - %5 = stablehlo.compare LT, %arg1, %arg3, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - - %s0 = stablehlo.select %0, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %s1 = stablehlo.select %1, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %s2 = stablehlo.select %2, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %s3 = stablehlo.select %3, %arg0, %arg2 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %s4 = stablehlo.select %4, %arg1, %arg2 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %s5 = stablehlo.select %5, %arg1, %arg3 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - - // CHECK-DAG: [[C0:%.+]] = 
stablehlo.compare EQ, [[ARG0]], [[ARG1]], SIGNED - // CHECK-DAG: [[C1:%.+]] = stablehlo.compare NE, [[ARG0]], [[ARG1]], SIGNED - - // CHECK-DAG: [[S0:%.+]] = stablehlo.select [[C0]], [[ARG0]], [[ARG1]] - // CHECK-DAG: [[S1:%.+]] = stablehlo.select [[C1]], [[ARG0]], [[ARG1]] - // CHECK-DAG: [[S2:%.+]] = stablehlo.maximum [[ARG0]], [[ARG1]] - // CHECK-DAG: [[S3:%.+]] = stablehlo.maximum [[ARG0]], [[ARG2]] - // CHECK-DAG: [[S4:%.+]] = stablehlo.minimum [[ARG1]], [[ARG2]] - // CHECK-DAG: [[S5:%.+]] = stablehlo.minimum [[ARG1]], [[ARG3]] - - // CHECK-NEXT: return [[S0]], [[S1]], [[S2]], [[S3]], [[S4]], [[S5]] - return %s0, %s1, %s2, %s3, %s4, %s5 : - tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: func.func @select_into_minmax2 -// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor, [[ARG3:%.+]]: tensor) -func.func @select_into_minmax2(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) - -> (tensor, tensor, tensor, tensor, - tensor, tensor, tensor, tensor) { - - %0 = stablehlo.compare GT, %arg1, %arg0, SIGNED : (tensor, tensor) -> tensor - %1 = stablehlo.compare GT, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor - %2 = stablehlo.compare GE, %arg1, %arg3, SIGNED : (tensor, tensor) -> tensor - %3 = stablehlo.compare GE, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor - - %s0 = stablehlo.select %0, %arg0, %arg1 : (tensor, tensor, tensor) -> tensor - %s1 = stablehlo.select %1, %arg0, %arg1 : (tensor, tensor, tensor) -> tensor - %s2 = stablehlo.select %2, %arg3, %arg1 : (tensor, tensor, tensor) -> tensor - %s3 = stablehlo.select %3, %arg0, %arg2 : (tensor, tensor, tensor) -> tensor - - %4 = stablehlo.compare LT, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor - %5 = stablehlo.compare LT, %arg0, %arg2, SIGNED : (tensor, tensor) -> tensor - %6 = stablehlo.compare LE, %arg2, %arg3, SIGNED : (tensor, tensor) -> tensor - %7 = stablehlo.compare LE, %arg0, %arg2, 
SIGNED : (tensor, tensor) -> tensor - - %s4 = stablehlo.select %4, %arg2, %arg1 : (tensor, tensor, tensor) -> tensor - %s5 = stablehlo.select %5, %arg1, %arg2 : (tensor, tensor, tensor) -> tensor - %s6 = stablehlo.select %6, %arg3, %arg2 : (tensor, tensor, tensor) -> tensor - %s7 = stablehlo.select %7, %arg2, %arg3 : (tensor, tensor, tensor) -> tensor - - // CHECK-DAG: [[C1:%.+]] = stablehlo.compare GT, [[ARG1]], [[ARG2]], SIGNED - // CHECK-DAG: [[C3:%.+]] = stablehlo.compare GE, [[ARG1]], [[ARG2]], SIGNED - - // CHECK-DAG: [[S0:%.+]] = stablehlo.minimum [[ARG0]], [[ARG1]] - // CHECK-DAG: [[S1:%.+]] = stablehlo.select [[C1]], [[ARG0]], [[ARG1]] - // CHECK-DAG: [[S2:%.+]] = stablehlo.minimum [[ARG3]], [[ARG1]] - // CHECK-DAG: [[S3:%.+]] = stablehlo.select [[C3]], [[ARG0]], [[ARG2]] - - // CHECK-DAG: [[C5:%.+]] = stablehlo.compare LT, [[ARG0]], [[ARG2]], SIGNED - // CHECK-DAG: [[C7:%.+]] = stablehlo.compare LE, [[ARG0]], [[ARG2]], SIGNED - - // CHECK-DAG: [[S4:%.+]] = stablehlo.maximum [[ARG2]], [[ARG1]] - // CHECK-DAG: [[S5:%.+]] = stablehlo.select [[C5]], [[ARG1]], [[ARG2]] - // CHECK-DAG: [[S6:%.+]] = stablehlo.maximum [[ARG3]], [[ARG2]] - // CHECK-DAG: [[S7:%.+]] = stablehlo.select [[C7]], [[ARG2]], [[ARG3]] - - // CHECK-NEXT: return [[S0]], [[S1]], [[S2]], [[S3]], [[S4]], [[S5]], [[S6]], [[S7]] - return %s0, %s1, %s2, %s3, %s4, %s5, %s6, %s7 : tensor, tensor, tensor, tensor, - tensor, tensor, tensor, tensor -} - -// ----- - -// CHECK-LABEL: func.func @broadcast_in_dim_splat -// CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>) -func.func @broadcast_in_dim_splat(%arg0: tensor<3x3xi32>) - -> (tensor<6xi32>, tensor<3xf32>, tensor<3x3xi32>) { - %c0 = stablehlo.constant dense<5> : tensor - %c1 = stablehlo.constant dense<3.0> : tensor - %c2 = stablehlo.constant dense<1> : tensor<1x3xi32> - - %0 = stablehlo.broadcast_in_dim %c0, dims = [] : (tensor) -> tensor<6xi32> - %1 = stablehlo.broadcast_in_dim %c1, dims = [] : (tensor) -> tensor<3xf32> - %2 = stablehlo.broadcast_in_dim 
%c2, dims = [1, 0] : (tensor<1x3xi32>) -> tensor<3x3xi32> - - // CHECK-DAG: [[R0:%.+]] = stablehlo.constant dense<5> : tensor<6xi32> - // CHECK-DAG: [[R1:%.+]] = stablehlo.constant dense<3.000000e+00> : tensor<3xf32> - // CHECK-DAG: [[R2:%.+]] = stablehlo.constant dense<1> : tensor<3x3xi32> - - // CHECK-NEXT: return [[R0]], [[R1]], [[R2]] - return %0, %1, %2 : tensor<6xi32>, tensor<3xf32>, tensor<3x3xi32> -} - -// ----- - -// CHECK-LABEL: func.func @broadcast_in_dim_transpose -// CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>) -func.func @broadcast_in_dim_transpose(%arg0: tensor<3x3xi32>) - -> (tensor<3x3xi32>, tensor<3x3xi32>) { - %3 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<3x3xi32>) -> tensor<3x3xi32> - %4 = stablehlo.broadcast_in_dim %arg0, dims = [1, 0] : (tensor<3x3xi32>) -> tensor<3x3xi32> - - // CHECK: [[R4:%.+]] = stablehlo.transpose [[ARG0]], dims = [1, 0] : (tensor<3x3xi32>) -> tensor<3x3xi32> - - // CHECK-NEXT: return [[ARG0]], [[R4]] - return %3, %4 : tensor<3x3xi32>, tensor<3x3xi32> -} - -// ----- - -// CHECK-LABEL: func.func @broadcast_in_dim_nested -// CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>) -func.func @broadcast_in_dim_nested(%arg0: tensor<3x3xi32>) - -> (tensor<3x2x3x3xi32>) { - %6 = stablehlo.broadcast_in_dim %arg0, dims = [1, 0] : (tensor<3x3xi32>) -> tensor<3x3x2xi32> - %7 = stablehlo.broadcast_in_dim %6, dims = [0, 2, 1] : (tensor<3x3x2xi32>) -> tensor<3x2x3x3xi32> - // CHECK: [[R6:%.+]] = stablehlo.broadcast_in_dim [[ARG0]], dims = [2, 0] : (tensor<3x3xi32>) -> tensor<3x2x3x3xi32> - - // CHECK-NEXT: return [[R6]] - return %7 : tensor<3x2x3x3xi32> -} - -// ----- - -// CHECK-LABEL: func.func @broadcast_in_dim_reshape -// CHECK-SAME: ([[ARG0:%.+]]: tensor<3x6xi32>) -func.func @broadcast_in_dim_reshape(%arg0: tensor<3x6xi32>) - -> (tensor<1x3x6xi32>, tensor<3x6x1xi32>) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [1, 2] : (tensor<3x6xi32>) -> tensor<1x3x6xi32> - %5 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : 
(tensor<3x6xi32>) -> tensor<3x6x1xi32> - - // CHECK-DAG: [[R0:%.+]] = stablehlo.reshape [[ARG0]] : (tensor<3x6xi32>) -> tensor<1x3x6xi32> - // CHECK-DAG: [[R5:%.+]] = stablehlo.reshape [[ARG0]] : (tensor<3x6xi32>) -> tensor<3x6x1xi32> - - // CHECK-NEXT: return [[R0]], [[R5]] - return %0, %5 : tensor<1x3x6xi32>, tensor<3x6x1xi32> -} - -// ----- - -// CHECK-LABEL: func.func @concatenate -func.func @concatenate() -> (tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, tensor<2x5xi32>) { - %c0 = stablehlo.constant dense<[0, 1]> : tensor<2xi32> - %c1 = stablehlo.constant dense<[2, 3, 4]> : tensor<3xi32> - %c2 = stablehlo.constant dense<[5]> : tensor<1xi32> - - %c3 = stablehlo.constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32> - %c4 = stablehlo.constant dense<[[6, 7, 8]]> : tensor<1x3xi32> - %c5 = stablehlo.constant dense<[[11, 12], [13, 14]]> : tensor<2x2xi32> - - %0 = stablehlo.concatenate %c0, %c1, %c2, dim = 0 : (tensor<2xi32>, tensor<3xi32>, tensor<1xi32>) -> tensor<6xi32> - %1 = stablehlo.concatenate %c0, %c2, dim = 0 : (tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32> - - %2 = stablehlo.concatenate %c3, %c4, dim = 0 : (tensor<2x3xi32>, tensor<1x3xi32>) -> tensor<3x3xi32> - %3 = stablehlo.concatenate %c3, %c5, dim = 1 : (tensor<2x3xi32>, tensor<2x2xi32>) -> tensor<2x5xi32> - - // CHECK-DAG: [[R0:%.+]] = stablehlo.constant dense<[0, 1, 2, 3, 4, 5]> : tensor<6xi32> - // CHECK-DAG: [[R1:%.+]] = stablehlo.constant dense<[0, 1, 5]> : tensor<3xi32> - // CHECK-DAG: [[R2:%.+]] = stablehlo.constant dense<{{\[\[0, 1, 2\], \[3, 4, 5\], \[6, 7, 8\]\]}}> : tensor<3x3xi32> - // CHECK-DAG: [[R3:%.+]] = stablehlo.constant dense<{{\[\[0, 1, 2, 11, 12\], \[3, 4, 5, 13, 14\]\]}}> : tensor<2x5xi32> - // CHECK-NEXT: return [[R0]], [[R1]], [[R2]], [[R3]] - return %0, %1, %2, %3 : tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, tensor<2x5xi32> -} - -// ----- +///////// +// ConvertOp // CHECK-LABEL: func.func @convert // CHECK-SAME: ([[ARG0:%.+]]: tensor<2xf32>) @@ -428,6 +199,31 @@ 
func.func @convert(%arg0: tensor<2xf32>) -> tensor<2xf32> { // ----- +///////// +// DynamicBroadcastInDimOp + +// CHECK-LABEL: func @dynamic_broadcast_in_dim_all_dims_non_expanding +func.func @dynamic_broadcast_in_dim_all_dims_non_expanding(%arg0: tensor, %arg1: tensor<1xindex>) -> tensor { + // CHECK-SAME: %[[ARG:.*]]: tensor + // CHECK-NEXT: return %[[ARG]] + %1 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = array, + known_expanding_dimensions = array, + known_nonexpanding_dimensions = array + } : (tensor, tensor<1xindex>) -> tensor + func.return %1 : tensor +} + +// CHECK-LABEL: @dynamic_broadcast_of_dynamic_reshape_same_shape +func.func @dynamic_broadcast_of_dynamic_reshape_same_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor { + %0 = stablehlo.dynamic_reshape %arg0, %arg1 : (tensor, tensor<2xi64>) -> tensor + %1 = stablehlo.dynamic_broadcast_in_dim %0, %arg1, dims = [0, 1] : (tensor, tensor<2xi64>) -> tensor + + // CHECK-NOT: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.dynamic_reshape %arg0, %arg1 + return %1 : tensor +} + // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> { // CHECK: %[[RESULT:.+]] = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<4xf32>) -> tensor<5x4xf32> @@ -436,8 +232,6 @@ func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32> func.return %0 : tensor<5x4xf32> } -// ----- - // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape(%arg0: tensor) -> tensor<4x32xi32> { %0 = stablehlo.constant dense<[4, 32]> : tensor<2xi32> @@ -448,8 +242,6 @@ func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape(%arg0 func.return %2 : tensor<4x32xi32> } -// ----- - // CHECK-LABEL: func 
@dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_index_shape func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_index_shape(%arg0: tensor) -> tensor<4x32xf32> { %0 = shape.const_shape [4, 32] : tensor<2xindex> @@ -460,20 +252,16 @@ func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_index_shape func.return %2 : tensor<4x32xf32> } -// ----- - // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_requires_cast func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_requires_cast(%arg0: tensor) -> tensor { %0 = shape.const_shape [4, 32] : tensor<2xindex> // CHECK: %[[BCAST:.+]] = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor) -> tensor<4x32xf32> %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [] : (tensor, tensor<2xindex>) -> tensor - // CHECK: %[[RESULT:.*]] = tensor.cast %[[BCAST]] : tensor<4x32xf32> to tensor + // CHECK: %[[RESULT:.*]] = stablehlo.convert %[[BCAST]] : (tensor<4x32xf32>) -> tensor // CHECK: return %[[RESULT]] : tensor func.return %1 : tensor } -// ----- - // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_almost_not_actually_dynamic func.func @dynamic_broadcast_in_dim_op_almost_not_actually_dynamic(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<5x4xf32> { // CHECK: %[[RESULT:.+]] = stablehlo.dynamic_broadcast_in_dim %arg0, %arg1, dims = [1] : (tensor, tensor<2xi64>) -> tensor<5x4xf32> @@ -484,19 +272,8 @@ func.func @dynamic_broadcast_in_dim_op_almost_not_actually_dynamic(%arg0: tensor // ----- -// CHECK-LABEL: func @dynamic_broadcast_in_dim_all_dims_non_expanding -func.func @dynamic_broadcast_in_dim_all_dims_non_expanding(%arg0: tensor, %arg1: tensor<1xindex>) -> tensor { - // CHECK-SAME: %[[ARG:.*]]: tensor - // CHECK-NEXT: return %[[ARG]] - %1 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { - broadcast_dimensions = array, - known_expanding_dimensions = array, - known_nonexpanding_dimensions = array - } : (tensor, tensor<1xindex>) -> tensor - 
func.return %1 : tensor -} - -// ----- +///////// +// DynamicReshapeOp // CHECK-LABEL: func.func @dynamic_reshape // CHECK-SAME: ([[ARG0:%.+]]: tensor<1xf32>, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor<2xi32>) @@ -517,58 +294,8 @@ func.func @dynamic_reshape(%arg0: tensor<1xf32>, %arg1: tensor, %arg2: // ----- -// CHECK-LABEL: func.func @get_tuple_element -// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tuple, tensor>) -func.func @get_tuple_element(%arg0: tensor, %arg1: tensor, %arg2: tuple, tensor>) - -> (tensor, tensor, tensor) { - %t = stablehlo.tuple %arg0, %arg1 : tuple, tensor> - - %a = stablehlo.get_tuple_element %t[0] : (tuple, tensor>) -> tensor - %b = stablehlo.get_tuple_element %t[1] : (tuple, tensor>) -> tensor - - %c = stablehlo.get_tuple_element %arg2[1] : (tuple, tensor>) -> tensor - - // CHECK: [[GTE:%.+]] = stablehlo.get_tuple_element [[ARG2]][1] : (tuple, tensor>) -> tensor - // CHECK-NEXT: return [[ARG0]], [[ARG1]], [[GTE]] - return %a, %b, %c : tensor, tensor, tensor -} - -// ----- - -// CHECK-LABEL: func.func @complex -// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xf32>, [[ARG1:%.+]]: tensor<2xf32>) -func.func @complex(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %c = stablehlo.complex %arg0, %arg1 : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - %r = stablehlo.real %c : (tensor<2xcomplex>) -> (tensor<2xf32>) - %i = stablehlo.imag %c : (tensor<2xcomplex>) -> (tensor<2xf32>) - - // CHECK: return [[ARG0]], [[ARG1]] - return %r, %i : tensor<2xf32>, tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func.func @get_dimension_size -// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x2x3xf32>, [[ARG1:%.+]]: tensor) -func.func @get_dimension_size(%arg0: tensor<1x2x3xf32>, %arg1: tensor) - -> (tensor, tensor, tensor, tensor, tensor) { - %a = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<1x2x3xf32>) -> tensor - %b = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<1x2x3xf32>) -> tensor - %c = 
stablehlo.get_dimension_size %arg0, dim = 2 : (tensor<1x2x3xf32>) -> tensor - - %d = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor) -> tensor - %e = stablehlo.get_dimension_size %arg1, dim = 1 : (tensor) -> tensor - - // CHECK-DAG: [[CST1:%.+]] = stablehlo.constant dense<1> : tensor - // CHECK-DAG: [[CST2:%.+]] = stablehlo.constant dense<2> : tensor - // CHECK-DAG: [[CST3:%.+]] = stablehlo.constant dense<3> : tensor - // CHECK-DAG: [[DYN:%.+]] = stablehlo.get_dimension_size [[ARG1]], dim = 0 : (tensor) -> tensor - // CHECK-NEXT: return [[CST1]], [[CST2]], [[CST3]], [[DYN]], [[CST2]] - return %a, %b, %c, %d, %e : tensor, tensor, tensor, tensor, tensor -} - -// ----- +///////// +// GatherOp // CHECK-LABEL: func.func @gather_to_slice func.func @gather_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<3x6x5xf32> { @@ -587,8 +314,6 @@ func.func @gather_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<3x6x5xf32> { // CHECK-NEXT: return %[[RET]] : tensor<3x6x5xf32> } -// ----- - // CHECK-LABEL: func.func @gather_scalar_index_to_slice func.func @gather_scalar_index_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<5x6x4xf32> { %0 = arith.constant dense<1> : tensor @@ -606,8 +331,6 @@ func.func @gather_scalar_index_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<5x6x // CHECK-NEXT: return %[[RET]] : tensor<5x6x4xf32> } -// ----- - // CHECK-LABEL: func.func @gather_to_slice_reshape func.func @gather_to_slice_reshape(%arg0: tensor<5x6x7xf32>) -> tensor<3x6xf32> { %0 = arith.constant dense<[1, 2]> : tensor<2xi32> @@ -627,8 +350,6 @@ func.func @gather_to_slice_reshape(%arg0: tensor<5x6x7xf32>) -> tensor<3x6xf32> // CHECK-NEXT: return %[[V1]] : tensor<3x6xf32> } -// ----- - // CHECK-LABEL: func.func @gather_to_slice_indices_clamp_upperbound func.func @gather_to_slice_indices_clamp_upperbound(%arg0 : tensor<4x2xui32>) -> tensor<2xui32> { %0 = arith.constant dense<4> : tensor<1xi32> @@ -647,8 +368,6 @@ func.func @gather_to_slice_indices_clamp_upperbound(%arg0 : tensor<4x2xui32>) -> // 
CHECK-NEXT: return %[[V1]] : tensor<2xui32> } -// ----- - // CHECK-LABEL: func.func @gather_to_slice_indices_clamp_lowerbound func.func @gather_to_slice_indices_clamp_lowerbound(%arg0 : tensor<4x2xui32>) -> tensor<2xui32> { %0 = arith.constant dense<-1> : tensor<1xi32> @@ -669,87 +388,177 @@ func.func @gather_to_slice_indices_clamp_lowerbound(%arg0 : tensor<4x2xui32>) -> // ----- -// CHECK-LABEL: func.func @reshape -// CHECK-SAME: ([[ARG0:%.+]]: tensor<1xf32>) -func.func @reshape(%arg0: tensor<1xf32>) - -> (tensor<1xf32>, tensor<1xi32>, tensor, tensor<2x2xi32>) { - %c0 = stablehlo.constant dense<2> : tensor - %c1 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32> +///////// +// GetDimensionSizeOp + +// CHECK-LABEL: func.func @get_dimension_size +// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x2x3xf32>, [[ARG1:%.+]]: tensor) +func.func @get_dimension_size(%arg0: tensor<1x2x3xf32>, %arg1: tensor) + -> (tensor, tensor, tensor, tensor, tensor) { + %a = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<1x2x3xf32>) -> tensor + %b = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<1x2x3xf32>) -> tensor + %c = stablehlo.get_dimension_size %arg0, dim = 2 : (tensor<1x2x3xf32>) -> tensor + + %d = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor) -> tensor + %e = stablehlo.get_dimension_size %arg1, dim = 1 : (tensor) -> tensor + + // CHECK-DAG: [[CST1:%.+]] = stablehlo.constant dense<1> : tensor + // CHECK-DAG: [[CST2:%.+]] = stablehlo.constant dense<2> : tensor + // CHECK-DAG: [[CST3:%.+]] = stablehlo.constant dense<3> : tensor + // CHECK-DAG: [[DYN:%.+]] = stablehlo.get_dimension_size [[ARG1]], dim = 0 : (tensor) -> tensor + // CHECK-NEXT: return [[CST1]], [[CST2]], [[CST3]], [[DYN]], [[CST2]] + return %a, %b, %c, %d, %e : tensor, tensor, tensor, tensor, tensor +} + +// ----- + +///////// +// GetTupleElementOp + +// CHECK-LABEL: func.func @get_tuple_element +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tuple, tensor>) +func.func 
@get_tuple_element(%arg0: tensor, %arg1: tensor, %arg2: tuple, tensor>) + -> (tensor, tensor, tensor) { + %t = stablehlo.tuple %arg0, %arg1 : tuple, tensor> + + %a = stablehlo.get_tuple_element %t[0] : (tuple, tensor>) -> tensor + %b = stablehlo.get_tuple_element %t[1] : (tuple, tensor>) -> tensor + + %c = stablehlo.get_tuple_element %arg2[1] : (tuple, tensor>) -> tensor + + // CHECK: [[GTE:%.+]] = stablehlo.get_tuple_element [[ARG2]][1] : (tuple, tensor>) -> tensor + // CHECK-NEXT: return [[ARG0]], [[ARG1]], [[GTE]] + return %a, %b, %c : tensor, tensor, tensor +} + +// ----- + +///////// +// ImagOp / RealOp + +// CHECK-LABEL: func.func @complex +// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xf32>, [[ARG1:%.+]]: tensor<2xf32>) +func.func @complex(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %c = stablehlo.complex %arg0, %arg1 : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + + %r = stablehlo.real %c : (tensor<2xcomplex>) -> (tensor<2xf32>) + %i = stablehlo.imag %c : (tensor<2xcomplex>) -> (tensor<2xf32>) + + // CHECK: return [[ARG0]], [[ARG1]] + return %r, %i : tensor<2xf32>, tensor<2xf32> +} + +// ----- + +///////// +// MaxOp + +// CHECK-LABEL: @maximum_cst_on_rhs +func.func @maximum_cst_on_rhs(%arg0: tensor) -> tensor { + %cst = stablehlo.constant dense<2.0> : tensor + // CHECK: stablehlo.maximum %arg0, %cst : tensor + %0 = stablehlo.maximum %cst, %arg0 : tensor + return %0 : tensor +} + +// ----- + +///////// +// MinOp + +// CHECK-LABEL: @minimum_cst_on_rhs +func.func @minimum_cst_on_rhs(%arg0: tensor) -> tensor { + %cst = stablehlo.constant dense<2.0> : tensor + // CHECK: stablehlo.minimum %arg0, %cst : tensor + %0 = stablehlo.minimum %cst, %arg0 : tensor + return %0 : tensor +} + +// ----- + +///////// +// MulOp - %0 = stablehlo.reshape %arg0 : (tensor<1xf32>) -> tensor<1xf32> - %1 = stablehlo.reshape %c0 : (tensor) -> tensor<1xi32> - %2 = stablehlo.reshape %1 : (tensor<1xi32>) -> tensor - %3 = stablehlo.reshape %c1 : 
(tensor<4xi32>) -> tensor<2x2xi32> - - // CHECK-DAG: [[CST1:%.+]] = stablehlo.constant dense<2> : tensor - // CHECK-DAG: [[CST2:%.+]] = stablehlo.constant dense<2> : tensor<1xi32> - // CHECK-DAG: [[CST3:%.+]] = stablehlo.constant dense<{{\[\[1, 2\], \[3, 4\]\]}}> : tensor<2x2xi32> - // CHECK-NEXT: return [[ARG0]], [[CST2]], [[CST1]], [[CST3]] - return %0, %1, %2, %3 : tensor<1xf32>, tensor<1xi32>, tensor, tensor<2x2xi32> +// CHECK-LABEL: @multiply_cst_on_rhs +func.func @multiply_cst_on_rhs(%arg0: tensor) -> tensor { + %cst = stablehlo.constant dense<2.0> : tensor + // CHECK: stablehlo.multiply %arg0, %cst : tensor + %0 = stablehlo.multiply %cst, %arg0 : tensor + return %0 : tensor } -// ----- +// CHECK-LABEL: @multiply_by_zero +func.func @multiply_by_zero(%arg0: tensor) -> tensor { + %cst = stablehlo.constant dense<0> : tensor + // CHECK: stablehlo.constant dense<0> : tensor + %0 = stablehlo.multiply %cst, %arg0 : tensor + return %0 : tensor +} -// CHECK-LABEL: @merge_consecutive_reshapes -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]] -func.func @merge_consecutive_reshapes(%arg0: tensor<4x4xi32>) -> tensor<16xi32> { - %0 = stablehlo.reshape %arg0 : (tensor<4x4xi32>) -> tensor<2x8xi32> - %1 = stablehlo.reshape %0 : (tensor<2x8xi32>) -> tensor<16xi32> - // CHECK: [[R0:%.+]] = stablehlo.reshape %[[ARG0]] : (tensor<4x4xi32>) -> tensor<16xi32> - return %1 : tensor<16xi32> +// CHECK-LABEL: @multiply_by_one +func.func @multiply_by_one(%arg0: tensor) -> tensor { + %cst = stablehlo.constant dense<1> : tensor + %0 = stablehlo.multiply %cst, %arg0 : tensor + // CHECK-NOT: stablehlo.constant + // CHECK: return %arg0 : tensor + return %0 : tensor } // ----- -// CHECK-LABEL: func.func @transpose -// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xf32>, [[ARG1:%.+]]: tensor<3x2xf32>, [[ARG2:%.+]]: tensor) -func.func @transpose(%arg0: tensor<2xf32>, %arg1: tensor<3x2xf32>, %arg2: tensor) - -> (tensor<2xf32>, tensor<3x2xf32>, tensor<2x3xf32>, tensor) { - %a = stablehlo.transpose %arg0, dims = [0] 
: (tensor<2xf32>) -> tensor<2xf32> - %b = stablehlo.transpose %arg1, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> - %c = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3x2xf32>) -> tensor<2x3xf32> - %d = stablehlo.transpose %arg2, dims = [] : (tensor) -> tensor - - // CHECK-NEXT: [[X:%.+]] = stablehlo.transpose [[ARG1]], dims = [1, 0] - // CHECK-NEXT: return [[ARG0]], [[ARG1]], [[X]], [[ARG2]] - return %a, %b, %c, %d : tensor<2xf32>, tensor<3x2xf32>, tensor<2x3xf32>, tensor +///////// +// OrOp + +// CHECK-LABEL: @or_cst_on_rhs +func.func @or_cst_on_rhs(%arg0: tensor<2xi1>) -> tensor<2xi1> { + %cst = stablehlo.constant dense : tensor<2xi1> + %0 = stablehlo.or %cst, %arg0 : tensor<2xi1> + // Check that constant canonicalized to RHS, then other patterns apply + // CHECK-NOT: stablehlo.or + // CHECK: return %arg0 + return %0 : tensor<2xi1> } -// ----- +// CHECK-LABEL: @or_zero +func.func @or_zero(%arg0: tensor<2xi1>) -> tensor<2xi1> { + %0 = stablehlo.constant dense : tensor<2xi1> + %1 = stablehlo.or %0, %arg0 : tensor<2xi1> -// CHECK-LABEL: @transpose_is_reshape -func.func @transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> { - // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> - %0 = stablehlo.transpose %arg0, dims = [3, 1, 0, 2] : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> - return %0 : tensor<1x4x1x5xf32> + // CHECK-NOT: stablehlo.or + // CHECK: return %arg0 + return %1 : tensor<2xi1> } -// ----- +// CHECK-LABEL: @or_one +func.func @or_one(%arg0: tensor<2xi1>) -> tensor<2xi1> { + %0 = stablehlo.constant dense : tensor<2xi1> + %1 = stablehlo.or %0, %arg0 : tensor<2xi1> -// CHECK-LABEL: @transpose_is_not_reshape -func.func @transpose_is_not_reshape(%arg0: tensor<1x4x5x2xf32>) -> tensor<2x4x1x5xf32> { - // CHECK-NOT: stablehlo.reshape - %0 = stablehlo.transpose %arg0, dims = [3, 1, 0, 2] : (tensor<1x4x5x2xf32>) -> tensor<2x4x1x5xf32> - return %0 : tensor<2x4x1x5xf32> + // CHECK-NOT: 
stablehlo.or + // CHECK: [[TRUE:%.+]] = stablehlo.constant dense : tensor<2xi1> + // CHECK: return [[TRUE]] + return %1 : tensor<2xi1> } // ----- -// CHECK-LABEL: func.func @reduce_noop_1 -// CHECK-SAME: ([[ARG0:%.+]]: tensor<4x8xf32>) -func.func @reduce_noop_1(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { - %0 = stablehlo.constant dense<0.000000e+00> : tensor - %1 = stablehlo.reduce(%arg0 init: %0) across dimensions = [] : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> - reducer(%arg1: tensor, %arg2: tensor) { - %4 = stablehlo.add %arg1, %arg2 : tensor - stablehlo.return %4 : tensor +///////// +// ReduceOp + +// CHECK-LABEL: @reduce_no_dimensions +func.func @reduce_no_dimensions(%arg0: tensor<8xi64>, %arg1: tensor<8xi64>) -> (tensor<8xi64>, tensor<8xi64>) { + %c = stablehlo.constant dense<1> : tensor + %0:2 = stablehlo.reduce(%arg0 init: %c), (%arg1 init: %c) across dimensions = [] : (tensor<8xi64>, tensor<8xi64>, tensor, tensor) -> (tensor<8xi64>, tensor<8xi64>) + reducer(%arg2: tensor, %arg4: tensor) (%arg3: tensor, %arg5: tensor) { + %1 = stablehlo.add %arg2, %arg4 : tensor + %2 = stablehlo.subtract %arg3, %arg5 : tensor + stablehlo.return %1, %2 : tensor, tensor } - // CHECK: return [[ARG0]] : tensor<4x8xf32> - func.return %1 : tensor<4x8xf32> + // CHECK-NOT: stablehlo.reduce + // CHECK: return %arg0, %arg1 + return %0#0, %0#1 : tensor<8xi64>, tensor<8xi64> } -// ----- - // CHECK-LABEL: func.func @reduce_noop_2 // CHECK-SAME: ([[ARG0:%.+]]: tensor<4x8xi32>, [[ARG1:%.+]]: tensor) func.func @reduce_noop_2(%arg0: tensor<4x8xi32>, %arg1: tensor) -> tensor { @@ -762,27 +571,6 @@ func.func @reduce_noop_2(%arg0: tensor<4x8xi32>, %arg1: tensor) -> tensor } -// ----- - -// CHECK-LABEL: func.func @reduce_zero_ext -func.func @reduce_zero_ext(%arg0: tensor<0xi1>) -> tensor { - %0 = stablehlo.constant dense : tensor - %1 = stablehlo.constant dense : tensor<0xi1> - %2 = stablehlo.compare NE, %arg0, %1, UNSIGNED : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1> - %3 = 
stablehlo.convert %2 : (tensor<0xi1>) -> tensor<0xi32> - %4 = stablehlo.constant dense<0> : tensor - %5 = stablehlo.reduce(%3 init: %4) across dimensions = [0] : (tensor<0xi32>, tensor) -> tensor - reducer(%arg1: tensor, %arg2: tensor) { - %6 = stablehlo.add %arg1, %arg2 : tensor - stablehlo.return %6 : tensor - } - - // CHECK: [[CST:%.+]] = stablehlo.constant dense<0> : tensor - // CHECK: return [[CST]] : tensor - return %5 : tensor -} - -// ----- // Each reduce_unused_* test case is accompanied by an ASCII diagram that // represents the surveyed reduce operation in a compact form: @@ -1240,22 +1028,256 @@ func.func @reduce_unused_case10(%arg0: tensor<8xi64>, // ----- +///////// +// ReshapeOp + +// CHECK-LABEL: func @reshape_identity +func.func @reshape_identity(%arg0: tensor<4xf32>) -> (tensor<4xf32>) { + %0 = stablehlo.reshape %arg0 : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NOT: stablehlo.reshape + // CHECK: return %arg0 + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: @reshape_reshape +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]] +func.func @reshape_reshape(%arg0: tensor<4x4xi32>) -> tensor<16xi32> { + %0 = stablehlo.reshape %arg0 : (tensor<4x4xi32>) -> tensor<2x8xi32> + %1 = stablehlo.reshape %0 : (tensor<2x8xi32>) -> tensor<16xi32> + // CHECK: [[R0:%.+]] = stablehlo.reshape %[[ARG0]] : (tensor<4x4xi32>) -> tensor<16xi32> + return %1 : tensor<16xi32> +} + +// ----- + +///////// +// SubtractOp + +// CHECK-LABEL: @subtract_same_lhs_rhs +func.func @subtract_same_lhs_rhs(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK: stablehlo.constant dense<0> : tensor<2xi32> + %0 = stablehlo.subtract %arg0, %arg0 : tensor<2xi32> + return %0 : tensor<2xi32> +} + +// CHECK-LABEL: @subtract_zero +func.func @subtract_zero(%arg0: tensor<2xi32>, %arg1: tensor<2xf32>) -> (tensor<2xi32>, tensor<2xf32>) { + %0 = stablehlo.constant dense<0> : tensor<2xi32> + %1 = stablehlo.subtract %arg0, %0 : tensor<2xi32> + %2 = stablehlo.constant dense<0.0> : tensor<2xf32> + %3 = 
stablehlo.subtract %arg1, %2 : tensor<2xf32> + // CHECK-NOT: stablehlo.constant + // CHECK: return %arg0, %arg1 + return %1, %3: tensor<2xi32>, tensor<2xf32> +} + +// ----- + +///////// +// SelectOp + +// CHECK-LABEL: func.func @select +// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xi32>, [[ARG1:%.+]]: tensor<2xi32>, [[ARGC:%.+]]: tensor<2xi1>) +func.func @select(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %argC: tensor<2xi1>) + -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<4xi32>) { + %c0 = stablehlo.constant dense : tensor + %c1 = stablehlo.constant dense : tensor + + %c0_2 = stablehlo.constant dense : tensor<2xi1> + %c1_2 = stablehlo.constant dense : tensor<2xi1> + + %cond = stablehlo.constant dense<[false, true, false, true]> : tensor<4xi1> + %foo = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + %bar = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi32> + + %0 = stablehlo.select %argC, %arg0, %arg0 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %1 = stablehlo.select %c0, %arg0, %arg1 : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %2 = stablehlo.select %c1, %arg0, %arg1 : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %3 = stablehlo.select %c0_2, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %4 = stablehlo.select %c1_2, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %5 = stablehlo.select %argC, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + + %6 = stablehlo.select %cond, %foo, %bar : (tensor<4xi1>, tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + + // CHECK-DAG: [[R0:%.+]] = stablehlo.select [[ARGC]], [[ARG0]], [[ARG1]] + // CHECK-DAG: [[C0:%.+]] = stablehlo.constant dense<[5, 2, 7, 4]> : tensor<4xi32> + + // CHECK-NEXT: return [[ARG0]], [[ARG1]], [[ARG0]], [[ARG1]], [[ARG0]], [[R0]], [[C0]] + return %0, %1, %2, %3, %4, %5, %6 : + tensor<2xi32>, tensor<2xi32>, 
tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<4xi32> +} + +// CHECK-LABEL: func.func @select_into_minmax1 +// CHECK-SAME: [[ARG0:%.+]]: tensor<2xi32>, [[ARG1:%.+]]: tensor<2xi32>, [[ARG2:%.+]]: tensor<2xi32>, [[ARG3:%.+]]: tensor<2xi32>) +func.func @select_into_minmax1(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, + %arg2: tensor<2xi32>, %arg3: tensor<2xi32>) + -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) { + + %0 = stablehlo.compare EQ, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %1 = stablehlo.compare NE, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %2 = stablehlo.compare GE, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %3 = stablehlo.compare GT, %arg0, %arg2, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %4 = stablehlo.compare LE, %arg1, %arg2, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %5 = stablehlo.compare LT, %arg1, %arg3, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + + %s0 = stablehlo.select %0, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %s1 = stablehlo.select %1, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %s2 = stablehlo.select %2, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %s3 = stablehlo.select %3, %arg0, %arg2 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %s4 = stablehlo.select %4, %arg1, %arg2 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %s5 = stablehlo.select %5, %arg1, %arg3 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + + // CHECK-DAG: [[C0:%.+]] = stablehlo.compare EQ, [[ARG0]], [[ARG1]], SIGNED + // CHECK-DAG: [[C1:%.+]] = stablehlo.compare NE, [[ARG0]], [[ARG1]], SIGNED + + // CHECK-DAG: [[S0:%.+]] = stablehlo.select [[C0]], [[ARG0]], [[ARG1]] + // CHECK-DAG: [[S1:%.+]] = stablehlo.select [[C1]], 
[[ARG0]], [[ARG1]] + // CHECK-DAG: [[S2:%.+]] = stablehlo.maximum [[ARG0]], [[ARG1]] + // CHECK-DAG: [[S3:%.+]] = stablehlo.maximum [[ARG0]], [[ARG2]] + // CHECK-DAG: [[S4:%.+]] = stablehlo.minimum [[ARG1]], [[ARG2]] + // CHECK-DAG: [[S5:%.+]] = stablehlo.minimum [[ARG1]], [[ARG3]] + + // CHECK-NEXT: return [[S0]], [[S1]], [[S2]], [[S3]], [[S4]], [[S5]] + return %s0, %s1, %s2, %s3, %s4, %s5 : + tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32> +} + +// CHECK-LABEL: func.func @select_into_minmax2 +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor, [[ARG3:%.+]]: tensor) +func.func @select_into_minmax2(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) + -> (tensor, tensor, tensor, tensor, + tensor, tensor, tensor, tensor) { + + %0 = stablehlo.compare GT, %arg1, %arg0, SIGNED : (tensor, tensor) -> tensor + %1 = stablehlo.compare GT, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor + %2 = stablehlo.compare GE, %arg1, %arg3, SIGNED : (tensor, tensor) -> tensor + %3 = stablehlo.compare GE, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor + + %s0 = stablehlo.select %0, %arg0, %arg1 : (tensor, tensor, tensor) -> tensor + %s1 = stablehlo.select %1, %arg0, %arg1 : (tensor, tensor, tensor) -> tensor + %s2 = stablehlo.select %2, %arg3, %arg1 : (tensor, tensor, tensor) -> tensor + %s3 = stablehlo.select %3, %arg0, %arg2 : (tensor, tensor, tensor) -> tensor + + %4 = stablehlo.compare LT, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor + %5 = stablehlo.compare LT, %arg0, %arg2, SIGNED : (tensor, tensor) -> tensor + %6 = stablehlo.compare LE, %arg2, %arg3, SIGNED : (tensor, tensor) -> tensor + %7 = stablehlo.compare LE, %arg0, %arg2, SIGNED : (tensor, tensor) -> tensor + + %s4 = stablehlo.select %4, %arg2, %arg1 : (tensor, tensor, tensor) -> tensor + %s5 = stablehlo.select %5, %arg1, %arg2 : (tensor, tensor, tensor) -> tensor + %s6 = stablehlo.select %6, %arg3, %arg2 : (tensor, tensor, tensor) -> 
tensor + %s7 = stablehlo.select %7, %arg2, %arg3 : (tensor, tensor, tensor) -> tensor + + // CHECK-DAG: [[C1:%.+]] = stablehlo.compare GT, [[ARG1]], [[ARG2]], SIGNED + // CHECK-DAG: [[C3:%.+]] = stablehlo.compare GE, [[ARG1]], [[ARG2]], SIGNED + + // CHECK-DAG: [[S0:%.+]] = stablehlo.minimum [[ARG0]], [[ARG1]] + // CHECK-DAG: [[S1:%.+]] = stablehlo.select [[C1]], [[ARG0]], [[ARG1]] + // CHECK-DAG: [[S2:%.+]] = stablehlo.minimum [[ARG3]], [[ARG1]] + // CHECK-DAG: [[S3:%.+]] = stablehlo.select [[C3]], [[ARG0]], [[ARG2]] + + // CHECK-DAG: [[C5:%.+]] = stablehlo.compare LT, [[ARG0]], [[ARG2]], SIGNED + // CHECK-DAG: [[C7:%.+]] = stablehlo.compare LE, [[ARG0]], [[ARG2]], SIGNED + + // CHECK-DAG: [[S4:%.+]] = stablehlo.maximum [[ARG2]], [[ARG1]] + // CHECK-DAG: [[S5:%.+]] = stablehlo.select [[C5]], [[ARG1]], [[ARG2]] + // CHECK-DAG: [[S6:%.+]] = stablehlo.maximum [[ARG3]], [[ARG2]] + // CHECK-DAG: [[S7:%.+]] = stablehlo.select [[C7]], [[ARG2]], [[ARG3]] + + // CHECK-NEXT: return [[S0]], [[S1]], [[S2]], [[S3]], [[S4]], [[S5]], [[S6]], [[S7]] + return %s0, %s1, %s2, %s3, %s4, %s5, %s6, %s7 : tensor, tensor, tensor, tensor, + tensor, tensor, tensor, tensor +} + +// ----- + +///////// +// TransposeOp + +// CHECK-LABEL: @transpose_identity +func.func @transpose_identity(%arg0: tensor<2xf32>, %arg1: tensor<3x2xf32>, %arg2: tensor) + -> (tensor<2xf32>, tensor<3x2xf32>, tensor<2x3xf32>, tensor) { + %a = stablehlo.transpose %arg0, dims = [0] : (tensor<2xf32>) -> tensor<2xf32> + %b = stablehlo.transpose %arg1, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %c = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3x2xf32>) -> tensor<2x3xf32> + %d = stablehlo.transpose %arg2, dims = [] : (tensor) -> tensor + + // CHECK-NEXT: [[X:%.+]] = stablehlo.transpose %arg1, dims = [1, 0] + // CHECK-NEXT: return %arg0, %arg1, [[X]], %arg2 + return %a, %b, %c, %d : tensor<2xf32>, tensor<3x2xf32>, tensor<2x3xf32>, tensor +} + +// CHECK-LABEL: @transpose_is_reshape +func.func 
@transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> { + // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> + %0 = stablehlo.transpose %arg0, dims = [3, 1, 0, 2] : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> + return %0 : tensor<1x4x1x5xf32> +} + +// CHECK-LABEL: @transpose_is_not_reshape +func.func @transpose_is_not_reshape(%arg0: tensor<1x4x5x2xf32>) -> tensor<2x4x1x5xf32> { + // CHECK-NOT: stablehlo.reshape + %0 = stablehlo.transpose %arg0, dims = [3, 1, 0, 2] : (tensor<1x4x5x2xf32>) -> tensor<2x4x1x5xf32> + return %0 : tensor<2x4x1x5xf32> +} + +// ----- + +///////// +// Generic Zero Extent Ops + +// CHECK-LABEL: func.func @reduce_zero_ext +func.func @reduce_zero_ext(%arg0: tensor<0xi1>) -> tensor { + %0 = stablehlo.constant dense : tensor + %1 = stablehlo.constant dense : tensor<0xi1> + %2 = stablehlo.compare NE, %arg0, %1, UNSIGNED : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1> + %3 = stablehlo.convert %2 : (tensor<0xi1>) -> tensor<0xi32> + %4 = stablehlo.constant dense<0> : tensor + %5 = stablehlo.reduce(%3 init: %4) across dimensions = [0] : (tensor<0xi32>, tensor) -> tensor + reducer(%arg1: tensor, %arg2: tensor) { + %6 = stablehlo.add %arg1, %arg2 : tensor + stablehlo.return %6 : tensor + } + + // CHECK: [[CST:%.+]] = stablehlo.constant dense<0> : tensor + // CHECK: return [[CST]] : tensor + return %5 : tensor +} + +// ----- + +///////// +// XorOp + +// CHECK-LABEL: @xor_cst_on_rhs +func.func @xor_cst_on_rhs(%arg0: tensor<2xi1>) -> tensor<2xi1> { + %cst = stablehlo.constant dense : tensor<2xi1> + %0 = stablehlo.xor %cst, %arg0 : tensor<2xi1> + // CHECK: stablehlo.xor %arg0, %c : tensor<2xi1> + return %0 : tensor<2xi1> +} + +// ----- + +///////// +// Zero Extents + // CHECK-LABEL: func.func @add_zero_ext func.func @add_zero_ext(%arg0 : tensor<5x0xi32>, %arg1 : tensor<5x0xi32>) -> tensor<5x0xi32> { + // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<5x0xi32> + // CHECK: return %[[EMPTY]] %0 
= stablehlo.add %arg0, %arg1 : tensor<5x0xi32> func.return %0 : tensor<5x0xi32> } -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<5x0xi32> -// CHECK: return %[[EMPTY]] // ----- // CHECK-LABEL: func.func @add_zero_ext_dynamic func.func @add_zero_ext_dynamic(%arg0 : tensor, %arg1 : tensor) -> tensor { %0 = stablehlo.add %arg0, %arg1 : tensor + // CHECK-NOT: tensor.empty() func.return %0 : tensor } -// CHECK-NOT: tensor.empty() // ----- @@ -1275,28 +1297,28 @@ func.func @scatter_zero_ext(%arg0 : tensor, %arg1 : tensor<1x0xi32>, %arg2 indices_are_sorted = true, unique_indices = true } : (tensor, tensor<1x0xi32>, tensor<1xf32>) -> tensor + // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x0xi32> + // CHECK: %[[SCATTER:.+]] = "stablehlo.scatter"(%arg0, %0, %arg2) + // CHECK: return %[[SCATTER]] func.return %0 : tensor } -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x0xi32> -// CHECK: %[[SCATTER:.+]] = "stablehlo.scatter"(%arg0, %0, %arg2) -// CHECK: return %[[SCATTER]] - // ----- - func.func public @sort_zero_extent(%arg0: tensor<0xi16> {jax.arg_info = "a", mhlo.sharding = "{replicated}"}) -> (tensor<0xi32> {jax.result_info = ""}) { - %0 = stablehlo.iota dim = 0 : tensor<0xi32> - %1:2 = "stablehlo.sort"(%arg0, %0) ({ - ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): - %2 = stablehlo.compare LT, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor - stablehlo.return %2 : tensor - }) {dimension = 0 : i64, is_stable = true} : (tensor<0xi16>, tensor<0xi32>) -> (tensor<0xi16>, tensor<0xi32>) - return %1#1 : tensor<0xi32> - } - // CHECK-LABEL: @sort_zero_extent -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<0xi32> -// CHECK: return %[[EMPTY]] +func.func public @sort_zero_extent(%arg0: tensor<0xi16> {jax.arg_info = "a", mhlo.sharding = "{replicated}"}) -> (tensor<0xi32> {jax.result_info = ""}) { + %0 = stablehlo.iota dim = 0 : tensor<0xi32> + %1:2 = "stablehlo.sort"(%arg0, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): + 
%2 = stablehlo.compare LT, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor + stablehlo.return %2 : tensor + }) {dimension = 0 : i64, is_stable = true} : (tensor<0xi16>, tensor<0xi32>) -> (tensor<0xi16>, tensor<0xi32>) + + // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<0xi32> + // CHECK: return %[[EMPTY]] + return %1#1 : tensor<0xi32> +} + // ----- @@ -1324,7 +1346,17 @@ func.func public @while_zero_extent(%arg0: tensor, %arg1: tensor<3xf32>, %a // ----- +///////// +// Generic Shape Ops + +// CHECK-LABEL: @push_shape_ops_to_end func.func @push_shape_ops_to_end(%arg0 : tensor<12xf32>) -> tensor<3x4x2x1xf32> { + // CHECK: %[[COS:.+]] = stablehlo.cosine %arg0 : tensor<12xf32> + // CHECK: %[[ABS:.+]] = stablehlo.abs %[[COS]] : tensor<12xf32> + // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %[[ABS]] : (tensor<12xf32>) -> tensor<3x4xf32> + // CHECK: %[[BROADCAST:.+]] = stablehlo.broadcast %[[RESHAPE]], sizes = [1, 2] : (tensor<3x4xf32>) -> tensor<1x2x3x4xf32> + // CHECK: %[[TRANSPOSE:.+]] = stablehlo.transpose %[[BROADCAST]], dims = [2, 3, 1, 0] : (tensor<1x2x3x4xf32>) -> tensor<3x4x2x1xf32> + // CHECK: return %[[TRANSPOSE]] %0 = stablehlo.reshape %arg0 : (tensor<12xf32>) -> tensor<3x4xf32> %1 = stablehlo.broadcast %0, sizes = [1, 2] : (tensor<3x4xf32>) -> tensor<1x2x3x4xf32> %2 = stablehlo.cosine %1 : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> @@ -1333,30 +1365,26 @@ func.func @push_shape_ops_to_end(%arg0 : tensor<12xf32>) -> tensor<3x4x2x1xf32> return %4 : tensor<3x4x2x1xf32> } -// CHECK-LABEL: @push_shape_ops_to_end -// CHECK: %[[COS:.+]] = stablehlo.cosine %arg0 : tensor<12xf32> -// CHECK: %[[ABS:.+]] = stablehlo.abs %[[COS]] : tensor<12xf32> -// CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %[[ABS]] : (tensor<12xf32>) -> tensor<3x4xf32> -// CHECK: %[[BROADCAST:.+]] = stablehlo.broadcast %[[RESHAPE]], sizes = [1, 2] : (tensor<3x4xf32>) -> tensor<1x2x3x4xf32> -// CHECK: %[[TRANSPOSE:.+]] = stablehlo.transpose %[[BROADCAST]], dims = [2, 3, 1, 0] : 
(tensor<1x2x3x4xf32>) -> tensor<3x4x2x1xf32> -// CHECK: return %[[TRANSPOSE]] // ----- +// CHECK-LABEL: @reorder_with_type_change func.func @reorder_with_type_change(%arg0 : tensor<3x4xi32>) -> tensor<12xi64> { + // CHECK: %[[CONVERT:.+]] = stablehlo.convert %arg0 : (tensor<3x4xi32>) -> tensor<3x4xi64> + // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %[[CONVERT]] : (tensor<3x4xi64>) -> tensor<12xi64> + // CHECK: return %[[RESHAPE]] %0 = stablehlo.reshape %arg0 : (tensor<3x4xi32>) -> tensor<12xi32> %1 = stablehlo.convert %0 : (tensor<12xi32>) -> tensor<12xi64> return %1 : tensor<12xi64> } -// CHECK-LABEL: @reorder_with_type_change -// CHECK: %[[CONVERT:.+]] = stablehlo.convert %arg0 : (tensor<3x4xi32>) -> tensor<3x4xi64> -// CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %[[CONVERT]] : (tensor<3x4xi64>) -> tensor<12xi64> -// CHECK: return %[[RESHAPE]] // ----- +// CHECK-LABEL: @do_not_reorder_with_other_uses func.func @do_not_reorder_with_other_uses(%arg0: tensor<2x2xf64>, %arg1: tensor<4xf32>, %arg2: tensor) -> (tensor, tensor<4xf32>) { + // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<2x2xf64>) -> tensor<4xf64> + // CHECK: %[[CONVERT:.+]] = stablehlo.convert %[[RESHAPE]] : (tensor<4xf64>) -> tensor<4xf32> %0 = stablehlo.reshape %arg0 : (tensor<2x2xf64>) -> tensor<4xf64> %1 = stablehlo.convert %0 : (tensor<4xf64>) -> tensor<4xf32> %2 = stablehlo.subtract %arg1, %1 : tensor<4xf32> @@ -1368,19 +1396,15 @@ func.func @do_not_reorder_with_other_uses(%arg0: tensor<2x2xf64>, %arg1: tensor< return %3, %2 : tensor, tensor<4xf32> } -// CHECK-LABEL: @do_not_reorder_with_other_uses -// CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<2x2xf64>) -> tensor<4xf64> -// CHECK: %[[CONVERT:.+]] = stablehlo.convert %[[RESHAPE]] : (tensor<4xf64>) -> tensor<4xf32> // ----- // Make sure we do not crash on unregistered dialects. 
+// CHECK-LABEL: func.func @generic_op func.func @generic_op(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // CHECK-NEXT: "test_dialect.op" + // CHECK-NEXT: return %0 = "test_dialect.op"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xf32>) return %0 : tensor<2xf32> } - -// CHECK-LABEL: func.func @generic_op -// CHECK-NEXT: "test_dialect.op" -// CHECK-NEXT: return diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt index ccdcc26ed03..c69f7ff58cd 100644 --- a/stablehlo/transforms/CMakeLists.txt +++ b/stablehlo/transforms/CMakeLists.txt @@ -20,6 +20,10 @@ set(LLVM_TARGET_DEFINITIONS ChloDecompositionPatterns.td) mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters) add_public_tablegen_target(ChloDecompositionPatternsIncGen) +set(LLVM_TARGET_DEFINITIONS StablehloAggressiveSimplificationPatterns.td) +mlir_tablegen(StablehloAggressiveSimplificationPatterns.h.inc --gen-rewriters) +add_public_tablegen_target(StablehloAggressiveSimplificationPatternsIncGen) + set(LLVM_TARGET_DEFINITIONS StablehloCompatibilityExpanderPatterns.td) mlir_tablegen(StablehloCompatibilityExpanderPatterns.h.inc --gen-rewriters) add_public_tablegen_target(StablehloCompatibilityExpanderPatternsIncGen) @@ -57,10 +61,11 @@ add_mlir_dialect_library(StablehloPasses DEPENDS ChloDecompositionPatternsIncGen - StablehloLegalizeDeprecatedOpsPatternsIncGen PassesIncGen - VhloToVersionPatterns + StablehloAggressiveSimplificationPatternsIncGen StablehloCompatibilityExpanderPatternsIncGen + StablehloLegalizeDeprecatedOpsPatternsIncGen + VhloToVersionPatterns LINK_LIBS PUBLIC ChloOps diff --git a/stablehlo/transforms/PassUtils.cpp b/stablehlo/transforms/PassUtils.cpp index e7325613f3a..c54a4ddd7fb 100644 --- a/stablehlo/transforms/PassUtils.cpp +++ b/stablehlo/transforms/PassUtils.cpp @@ -12,6 +12,7 @@ limitations under the License. 
#include "stablehlo/transforms/PassUtils.h" +#include "llvm/Support/ErrorHandling.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" #include "mlir/IR/TypeUtilities.h" @@ -23,11 +24,39 @@ limitations under the License. namespace mlir { namespace stablehlo { +namespace { +// Need some extra handling to generate a DenseElementsAttr from a complex +// scalar, so add a helper function. +DenseElementsAttr getSplatFromScalar(OpBuilder &b, Attribute scalar, + ShapedType type) { + if (auto complexScalar = dyn_cast(scalar)) { + return DenseElementsAttr::get(type, complexScalar.getValue()); + } + return DenseElementsAttr::get(type, scalar); +} +} // namespace + +// Returns `stablehlo::ConstantOp` if value type if static, +// else returns `chlo::ConstantLikeOp`. +Value getConstantLikeImpl(OpBuilder &b, Location loc, Attribute scalar, + Value val) { + if (!llvm::isa(scalar)) + llvm::report_fatal_error("unhandled constant like element type"); + + auto shapedTy = cast(val.getType()); + if (shapedTy.hasStaticShape()) { + Attribute splat = getSplatFromScalar(b, scalar, shapedTy); + return b.create(loc, splat); + } + + return b.create(loc, cast(scalar), + val); +} + Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant, Value val) { Type ty = getElementTypeOrSelf(val.getType()); - return b.create(loc, b.getFloatAttr(ty, constant), - val); + return getConstantLikeImpl(b, loc, b.getFloatAttr(ty, constant), val); } bool isAnyQuantizedTypes(TypeRange types) { diff --git a/stablehlo/transforms/PassUtils.h b/stablehlo/transforms/PassUtils.h index f7a448579d4..ea7b9f9c73a 100644 --- a/stablehlo/transforms/PassUtils.h +++ b/stablehlo/transforms/PassUtils.h @@ -15,6 +15,7 @@ limitations under the License. 
#include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" @@ -24,30 +25,44 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" namespace mlir { namespace stablehlo { -// Add utility functions common across passes. + +// Utility functions common across passes. + +template +Attribute getScalarLike(OpBuilder &b, T constant, Type type) { + Type element = getElementTypeOrSelf(type); + if (isa(element)) return b.getIntegerAttr(element, constant); + if (isa(element)) return b.getFloatAttr(element, constant); + if (auto complexTy = dyn_cast(element)) { + return complex::NumberAttr::get(complexTy, constant, 0); + } + llvm_unreachable("unhandled element type"); +} + +// Creates a constant with using IntegerAttr, FloatAttr, or ComplexAttr stored +// in `scalar`. +// Returns stablehlo::ConstantOp if value type if static, else returns +// chlo::ConstantLikeOp. +Value getConstantLikeImpl(OpBuilder &b, Location loc, Attribute scalar, + Value val); // Creates a chlo::ConstantLikeOp using a splat `constant` of the same shape // as `val`. 
template Value getConstantLike(OpBuilder &b, Location loc, T constant, Value val) { - Type ty = getElementTypeOrSelf(val.getType()); - auto getAttr = [&]() -> Attribute { - if (isa(ty)) return b.getIntegerAttr(ty, constant); - if (isa(ty)) return b.getFloatAttr(ty, constant); - if (auto complexTy = dyn_cast(ty)) { - return complex::NumberAttr::get(complexTy, constant, 0); - } - llvm_unreachable("unhandled element type"); - }; - return b.create(loc, cast(getAttr()), - val); + auto shapedTy = cast(val.getType()); + Attribute scalar = getScalarLike(b, constant, shapedTy); + return getConstantLikeImpl(b, loc, scalar, val); } // Creates a chlo::ConstantLikeOp using a APFloat splat `constant` of the // same shape as `val`. +// The distinction between double and APFloat causes issues so need this +// explicit template specialization. Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant, Value val); diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h index 6c61e50926e..43ee22f3550 100644 --- a/stablehlo/transforms/Passes.h +++ b/stablehlo/transforms/Passes.h @@ -62,6 +62,11 @@ void populateVhloToVersionPatterns(RewritePatternSet *patterns, void populateChloToStablehloPatterns(MLIRContext *context, RewritePatternSet *patterns); +/// CHLO ConstantLikeOp to StableHLO ConstantOp +/// May require dynamic shape broadcasting. +void populateChloConstantLikePattern(MLIRContext *context, + RewritePatternSet *patterns); + /// Collection of folding patterns for StableHLO. 
void populateStablehloAggressiveFolderPatterns(RewritePatternSet *patterns, MLIRContext *context, diff --git a/stablehlo/transforms/Passes.td b/stablehlo/transforms/Passes.td index 1ad1b57defc..3044dd0d451 100644 --- a/stablehlo/transforms/Passes.td +++ b/stablehlo/transforms/Passes.td @@ -40,6 +40,7 @@ def StablehloAggressiveFolderPass : Pass<"stablehlo-aggressive-folder", "func::FuncOp"> { let summary = "Folds StableHLO operations"; let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", "mlir::tensor::TensorDialect", ]; let options = [ @@ -52,6 +53,7 @@ def StablehloAggressiveSimplificationPass : Pass<"stablehlo-aggressive-simplification", "func::FuncOp"> { let summary = "Canonicalizes StableHLO operations"; let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", "mlir::tensor::TensorDialect", ]; } diff --git a/stablehlo/transforms/StablehloAggressiveFolder.cpp b/stablehlo/transforms/StablehloAggressiveFolder.cpp index 2aee5398bc3..a9107514b8b 100644 --- a/stablehlo/transforms/StablehloAggressiveFolder.cpp +++ b/stablehlo/transforms/StablehloAggressiveFolder.cpp @@ -14,6 +14,7 @@ limitations under the License. #include #include +#include #include #include "llvm/ADT/APInt.h" @@ -56,6 +57,11 @@ namespace stablehlo { namespace { +// This is an upper limit on how many elements can be folded by an op folder. +// This limit doesn't apply to some special cases like adding a zero, +// multiplying by one, doing many operations with splats. +constexpr int64_t kFoldOpEltLimit = 65536; + // DenseElementsAttr can be constructed from ArrayRef but not from // ArrayRef. This helper bridges the gap. DenseIntElementsAttr getTensorAttr(ShapedType type, ArrayRef values) { @@ -85,6 +91,26 @@ LogicalResult validateResultTypeForEval(PatternRewriter& rewriter, return success(); } +/// Binary constant folder that used a generic folder function to handle both +/// ints and floats. 
+template +static TypedAttr foldBinaryOpIntOrFloat(TypedAttr lhs, TypedAttr rhs, + Fn&& folder) { + Attribute operands[2] = {lhs, rhs}; + Type elemTy = getElementTypeOrSelf(lhs); + + Attribute res; + if (isa(elemTy)) + res = constFoldBinaryOp(operands, + folder); + if (isa(elemTy)) + res = constFoldBinaryOp(operands, + folder); + if (res) return cast(res); + + return nullptr; +} + template LogicalResult evalConvertHelper(PatternRewriter& rewriter, OpType op, @@ -226,7 +252,31 @@ LogicalResult evalElementwise(PatternRewriter& rewriter, OpType op, return success(); } -struct EvalAddOpPattern : public OpRewritePattern { +struct FoldAddOpPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op, + PatternRewriter& rewriter) const override { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + // Pattern: add(cst,cst) -> cst + TypedAttr lhsAttr, rhsAttr; + matchPattern(lhs, m_Constant(&lhsAttr)); + matchPattern(rhs, m_Constant(&rhsAttr)); + + if (TypedAttr res; + lhsAttr && rhsAttr && + (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::plus<>{}))) { + rewriter.replaceOpWithNewOp(op, res); + return success(); + } + + return failure(); + } +}; + +struct EvalAddOpShapePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AddOp op, PatternRewriter& rewriter) const override { @@ -249,6 +299,26 @@ struct EvalAndOpPattern : public OpRewritePattern { } }; +// Pattern: broadcast_in_dim(splat, _) -> constant(splat) +struct FoldBroadcastInDimSplatPattern final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op, + PatternRewriter& rewriter) const override { + TypedValue operand = op.getOperand(); + + if (SplatElementsAttr cstAttr; + matchPattern(operand, m_Constant(&cstAttr))) { + rewriter.replaceOpWithNewOp( + op, SplatElementsAttr::get(op.getType(), + 
cstAttr.getSplatValue())); + return success(); + } + return failure(); + } +}; + struct EvalBroadcastInDimOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(BroadcastInDimOp op, @@ -290,7 +360,8 @@ struct EvalCompareOpPattern : public OpRewritePattern { LogicalResult matchAndRewrite(CompareOp op, PatternRewriter& rewriter) const override { auto resultType = op.getType(); - return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { + auto kind = op.getCompareType(); + return evalElementwise(rewriter, op, [&](APInt lhs, APInt rhs) { bool result = false; switch (op.getComparisonDirection()) { case ComparisonDirection::EQ: @@ -300,16 +371,16 @@ struct EvalCompareOpPattern : public OpRewritePattern { result = lhs != rhs; break; case ComparisonDirection::GE: - result = lhs >= rhs; + result = kind == ComparisonType::SIGNED ? lhs.sge(rhs) : lhs.uge(rhs); break; case ComparisonDirection::GT: - result = lhs > rhs; + result = kind == ComparisonType::SIGNED ? lhs.sgt(rhs) : lhs.ugt(rhs); break; case ComparisonDirection::LE: - result = lhs <= rhs; + result = kind == ComparisonType::SIGNED ? lhs.sle(rhs) : lhs.ule(rhs); break; case ComparisonDirection::LT: - result = lhs < rhs; + result = kind == ComparisonType::SIGNED ? 
lhs.slt(rhs) : lhs.ult(rhs); break; } return getAPSInt(resultType.getElementType(), result); @@ -317,6 +388,52 @@ struct EvalCompareOpPattern : public OpRewritePattern { } }; +////////////////////////////////// +// ConcatenateOp +///////////////////////////////// + +struct FoldConcatenateOpPattern final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op, + PatternRewriter& rewriter) const override { + RankedTensorType type = op.getType(); + if (!type.hasStaticShape()) return failure(); + + size_t numElems = type.getNumElements(); + if (numElems > kFoldOpEltLimit) return failure(); + + // Fold concatenate when all inputs are constants. + OperandRange inputs = op.getInputs(); + SmallVector constants(inputs.size()); + for (auto [input, constant] : llvm::zip_equal(inputs, constants)) { + if (!matchPattern(input, m_Constant(&constant))) return failure(); + } + + uint64_t dim = op.getDimension(); + ArrayRef shape = type.getShape(); + int64_t topSize = std::accumulate(shape.begin(), shape.begin() + dim, + int64_t{1}, std::multiplies<>{}); + + SmallVector newElems; + newElems.reserve(numElems); + + for (int64_t i = 0; i != topSize; ++i) { + for (ElementsAttr attr : constants) { + size_t bottomSize = attr.getNumElements() / topSize; + auto begin = attr.value_begin() + (i * bottomSize); + newElems.append(begin, begin + bottomSize); + } + } + + assert(newElems.size() == numElems); + rewriter.replaceOpWithNewOp( + op, DenseElementsAttr::get(op.getType(), newElems)); + return success(); + } +}; + struct EvalConcatenateOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ConcatenateOp op, @@ -426,6 +543,40 @@ struct EvalMinOpPattern : public OpRewritePattern { } }; +struct FoldMulOpPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, + PatternRewriter& 
rewriter) const override { + auto elemType = op.getType().getElementType(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + TypedAttr lhsAttr; + matchPattern(lhs, m_Constant(&lhsAttr)); + + TypedAttr rhsAttr; + matchPattern(rhs, m_Constant(&rhsAttr)); + + // The canonical form has the constant operand as the RHS. + if (isa(elemType) && lhsAttr && !rhsAttr) { + rewriter.modifyOpInPlace(op, [op, lhs, rhs] { + op->setOperands(ValueRange{rhs, lhs}); + }); + return success(); + } + + if (TypedAttr res; + lhsAttr && rhsAttr && + (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{}))) { + rewriter.replaceOpWithNewOp(op, res); + return success(); + } + + return failure(); + } +}; + struct EvalMulOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(MulOp op, @@ -466,6 +617,7 @@ struct EvalReshapeOpPattern : public OpRewritePattern { if (failed(validateResultTypeForEval(rewriter, op, resultType))) return failure(); + // Pattern: reshape(cst, shape) -> cst DenseIntElementsAttr attr; if (!matchPattern(op.getOperand(), m_Constant(&attr))) return rewriter.notifyMatchFailure(op, "expected constant operand"); @@ -589,6 +741,30 @@ struct EvalSliceOpPattern : public OpRewritePattern { } }; +struct FoldSubtractOpPattern final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::SubtractOp op, + PatternRewriter& rewriter) const override { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + TypedAttr lhsAttr, rhsAttr; + matchPattern(lhs, m_Constant(&lhsAttr)); + matchPattern(rhs, m_Constant(&rhsAttr)); + + if (TypedAttr res; + lhsAttr && rhsAttr && + (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::minus<>{}))) { + rewriter.replaceOpWithNewOp(op, res); + return success(); + } + + return failure(); + } +}; + struct EvalSubtractOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult 
matchAndRewrite(SubtractOp op, @@ -727,12 +903,19 @@ void populateStablehloAggressiveFolderPatterns(RewritePatternSet* patterns, populateStablehloShapeFolderPatterns(patterns, context, foldFloat); patterns->add(context); patterns->add(context); + + // TODO: Consolidate FoldOp patterns + // One is used by Shape Refinement, the other is a generic folder. + patterns + ->add( + context); } void populateStablehloShapeFolderPatterns(RewritePatternSet* patterns, MLIRContext* context, bool foldFloat) { - patterns->add(context); + patterns->add(context); patterns->add(context); patterns->add(context); patterns->add(context); diff --git a/stablehlo/transforms/StablehloAggressiveSimplification.cpp b/stablehlo/transforms/StablehloAggressiveSimplification.cpp index 6445f72d565..4ccdd7ded2d 100644 --- a/stablehlo/transforms/StablehloAggressiveSimplification.cpp +++ b/stablehlo/transforms/StablehloAggressiveSimplification.cpp @@ -47,6 +47,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "stablehlo/dialect/Base.h" #include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/PassUtils.h" #include "stablehlo/transforms/Passes.h" using llvm::SmallBitVector; @@ -58,8 +59,9 @@ namespace stablehlo { #include "stablehlo/transforms/Passes.h.inc" namespace { -// This is an upper limit on how many elements canonicalization patterns are -// allowed to materialize as new constants. +// This is an upper limit on how many elements can be folded by an op folder. +// This limit doesn't apply to some special cases like adding a zero, +// multiplying by one, doing many operations with splats. constexpr int64_t kFoldOpEltLimit = 65536; static bool isIotaRange(ArrayRef dims) { @@ -82,158 +84,25 @@ struct m_AnyOf { template m_AnyOf(MatcherA, MatcherB) -> m_AnyOf; -/// Binary constant folder that used a generic folder function to handle both -/// ints and floats. 
-template -static TypedAttr foldBinaryOpIntOrFloat(TypedAttr lhs, TypedAttr rhs, - Fn &&folder) { - Attribute operands[2] = {lhs, rhs}; - Type elemTy = getElementTypeOrSelf(lhs); - - Attribute res; - if (isa(elemTy)) - res = constFoldBinaryOp(operands, - folder); - if (isa(elemTy)) - res = constFoldBinaryOp(operands, - folder); - if (res) return cast(res); - - return nullptr; -} - -struct AddOpCanon final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op, - PatternRewriter &rewriter) const override { - auto elemType = op.getType().getElementType(); - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); - - if (matchPattern(lhs, m_Zero())) { - rewriter.replaceOp(op, rhs); - return success(); - } - - if (matchPattern(rhs, m_AnyOf(m_Zero(), m_NegZeroFloat()))) { - rewriter.replaceOp(op, lhs); - return success(); - } - - TypedAttr lhsAttr; - matchPattern(lhs, m_Constant(&lhsAttr)); - - TypedAttr rhsAttr; - matchPattern(rhs, m_Constant(&rhsAttr)); - - // The canonical form has the constant operand as the RHS. - if (isa(elemType) && lhsAttr && !rhsAttr) { - rewriter.modifyOpInPlace(op, [op, lhs, rhs] { - op->setOperands(ValueRange{rhs, lhs}); - }); - return success(); - } - - if (TypedAttr res; - lhsAttr && rhsAttr && - (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::plus<>{}))) { - rewriter.replaceOpWithNewOp(op, res); - return success(); - } +/// Matches when either of the submatchers match. 
+template +struct m_AnyAttrOf { + m_AnyAttrOf(MatcherA a, MatcherB b) : matcherA(a), matcherB(b) {} - return failure(); + bool match(Attribute attr) { + return matcherA.match(attr) || matcherB.match(attr); } -}; - -struct SubtractOpCanon final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mlir::stablehlo::SubtractOp op, - PatternRewriter &rewriter) const override { - auto elemType = op.getType().getElementType(); - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); - - if (isa(elemType) && lhs == rhs) { - rewriter.replaceOpWithNewOp( - op, rewriter.getZeroAttr(op.getType())); - return success(); - } - - // Subtraction of 0. - if (matchPattern(rhs, m_AnyOf(m_Zero(), m_PosZeroFloat()))) { - rewriter.replaceOp(op, lhs); - return success(); - } - - TypedAttr lhsAttr; - matchPattern(lhs, m_Constant(&lhsAttr)); - - TypedAttr rhsAttr; - matchPattern(rhs, m_Constant(&rhsAttr)); - - if (TypedAttr res; - lhsAttr && rhsAttr && - (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::minus<>{}))) { - rewriter.replaceOpWithNewOp(op, res); - return success(); - } - - return failure(); - } + MatcherA matcherA; + MatcherB matcherB; }; -struct MulOpCanon final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, - PatternRewriter &rewriter) const override { - auto elemType = op.getType().getElementType(); - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); - - // Multiplication by 0. This fold is not trivial for floats in presence of - // NaN values. - if (matchPattern(lhs, m_Zero())) { - rewriter.replaceOp(op, lhs); - return success(); - } - if (matchPattern(rhs, m_Zero())) { - rewriter.replaceOp(op, rhs); - return success(); - } - - // Multiplication by 1. 
- if (matchPattern(rhs, m_One())) { - rewriter.replaceOp(op, lhs); - return success(); - } - - TypedAttr lhsAttr; - matchPattern(lhs, m_Constant(&lhsAttr)); - - TypedAttr rhsAttr; - matchPattern(rhs, m_Constant(&rhsAttr)); - - // The canonical form has the constant operand as the RHS. - if (isa(elemType) && lhsAttr && !rhsAttr) { - rewriter.modifyOpInPlace(op, [op, lhs, rhs] { - op->setOperands(ValueRange{rhs, lhs}); - }); - return success(); - } - - if (TypedAttr res; - lhsAttr && rhsAttr && - (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{}))) { - rewriter.replaceOpWithNewOp(op, res); - return success(); - } +template +m_AnyAttrOf(MatcherA, MatcherB) -> m_AnyAttrOf; - return failure(); - } -}; +////////////////////////////////// +// CompareOp +///////////////////////////////// static mlir::stablehlo::ComparisonDirection invertDirection( mlir::stablehlo::ComparisonDirection direction) { @@ -256,41 +125,6 @@ static mlir::stablehlo::ComparisonDirection invertDirection( llvm::report_fatal_error("Unhandled case"); } -static APInt calculateComp(mlir::stablehlo::ComparisonType kind, - mlir::stablehlo::ComparisonDirection direction, - const APInt &lhs, const APInt &rhs) { - using mlir::stablehlo::ComparisonDirection; - using mlir::stablehlo::ComparisonType; - assert(llvm::is_contained({ComparisonType::SIGNED, ComparisonType::UNSIGNED}, - kind) && - "Not an integer comparison"); - - auto asBit = [](bool value) { - return value ? APInt::getAllOnes(1) : APInt::getZero(1); - }; - - switch (direction) { - case ComparisonDirection::EQ: - return asBit(lhs == rhs); - case ComparisonDirection::NE: - return asBit(lhs != rhs); - case ComparisonDirection::GE: - return asBit(kind == ComparisonType::SIGNED ? lhs.sge(rhs) - : lhs.uge(rhs)); - case ComparisonDirection::GT: - return asBit(kind == ComparisonType::SIGNED ? lhs.sgt(rhs) - : lhs.ugt(rhs)); - case ComparisonDirection::LE: - return asBit(kind == ComparisonType::SIGNED ? 
lhs.sle(rhs) - : lhs.ule(rhs)); - case ComparisonDirection::LT: - return asBit(kind == ComparisonType::SIGNED ? lhs.slt(rhs) - : lhs.ult(rhs)); - } - - llvm_unreachable("Unhandled case"); -} - struct CompareOpCanon final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -313,6 +147,8 @@ struct CompareOpCanon final : OpRewritePattern { Value lhs = op.getLhs(); Value rhs = op.getRhs(); + // Pattern: compare(X, X, [EQ,GE,LE]) -> true + // Pattern: compare(X, X, [NE,GT,LT]) -> false if (lhs == rhs) { switch (direction) { case ComparisonDirection::EQ: @@ -333,10 +169,10 @@ struct CompareOpCanon final : OpRewritePattern { llvm_unreachable("Unhandled case"); } - TypedAttr lhsAttr; + // Pattern: compare(cst, X, comparator) -> compare(X, cst, + // inverse(comparator)) + TypedAttr lhsAttr, rhsAttr; matchPattern(lhs, m_Constant(&lhsAttr)); - - TypedAttr rhsAttr; matchPattern(rhs, m_Constant(&rhsAttr)); // The canonical form has the constant operand as the RHS. @@ -348,21 +184,14 @@ struct CompareOpCanon final : OpRewritePattern { return success(); } - if (Attribute res; - lhsAttr && rhsAttr && - (res = constFoldBinaryOp( - ArrayRef({lhsAttr, rhsAttr}), op.getType(), - [direction, kind = *compType](const APInt &a, const APInt &b) { - return calculateComp(kind, direction, a, b); - }))) { - rewriter.replaceOpWithNewOp(op, res); - return success(); - } - return failure(); } }; +////////////////////////////////// +// SelectOp +///////////////////////////////// + struct SelectOpCanon final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -463,118 +292,22 @@ struct CompareSelectIntoMinMax final } }; -struct BroadcastInDimOpCanon final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op, - PatternRewriter &rewriter) const override { - RankedTensorType type = op.getType(); - - TypedValue operand = op.getOperand(); - RankedTensorType operandTy = operand.getType(); - - 
// Fold when broadcast is a noop. - auto dims = op.getBroadcastDimensions(); - if (type == operandTy && isIotaRange(dims)) { - rewriter.replaceOp(op, operand); - return success(); - } - - // Handle splat broadcasts. - if (SplatElementsAttr cstAttr; - matchPattern(operand, m_Constant(&cstAttr))) { - rewriter.replaceOpWithNewOp( - op, SplatElementsAttr::get(op.getType(), - cstAttr.getSplatValue())); - return success(); - } - - if (operandTy.hasStaticShape() && type.hasStaticShape() && - type.getNumElements() == operandTy.getNumElements()) { - // BroadcastInDim equivalent to reshape. - if (llvm::is_sorted(dims)) { - rewriter.replaceOpWithNewOp(op, type, - operand); - return success(); - } - // BroadcastInDim equivalent to transpose. - if (type.getRank() == operandTy.getRank()) { - rewriter.replaceOpWithNewOp( - op, type, operand, dims); - return success(); - } - } - - // Eliminate redundant nested BroadcastInDim. - if (auto definingOp = - operand.getDefiningOp()) { - auto newIndices = - llvm::map_to_vector(definingOp.getBroadcastDimensions(), - [&dims](int64_t dim) { return dims[dim]; }); - rewriter.replaceOpWithNewOp( - op, type, definingOp.getOperand(), newIndices); - return success(); - } - - return failure(); - } -}; - -struct ConcatenateOpCanon final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op, - PatternRewriter &rewriter) const override { - RankedTensorType type = op.getType(); - if (!type.hasStaticShape()) return failure(); - - size_t numElems = type.getNumElements(); - if (numElems > kFoldOpEltLimit) return failure(); - - // Fold concatenate when all inputs are constants. 
- OperandRange inputs = op.getInputs(); - SmallVector constants(inputs.size()); - for (auto [input, constant] : llvm::zip_equal(inputs, constants)) { - if (!matchPattern(input, m_Constant(&constant))) return failure(); - } - - uint64_t dim = op.getDimension(); - ArrayRef shape = type.getShape(); - int64_t topSize = std::accumulate(shape.begin(), shape.begin() + dim, - int64_t{1}, std::multiplies<>{}); - - SmallVector newElems; - newElems.reserve(numElems); - - for (int64_t i = 0; i != topSize; ++i) { - for (ElementsAttr attr : constants) { - size_t bottomSize = attr.getNumElements() / topSize; - auto begin = attr.value_begin() + (i * bottomSize); - newElems.append(begin, begin + bottomSize); - } - } - - assert(newElems.size() == numElems); - rewriter.replaceOpWithNewOp( - op, DenseElementsAttr::get(op.getType(), newElems)); - return success(); - } -}; - -struct ConvertOpCanon final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::ConvertOp op, - PatternRewriter &rewriter) const override { - // Check if this convert is a noop. - if (op.getOperand().getType() != op.getType()) return failure(); +////////////////////////////////// +// BroadcastInDimOp +///////////////////////////////// + +// Used in DRR file. +DenseI64ArrayAttr getMergedBroadcastDimensions(OpBuilder &b, + ArrayRef dims, + ArrayRef dimsParent) { + auto mergedDims = llvm::map_to_vector( + dimsParent, [&dims](int64_t dim) { return dims[dim]; }); + return b.getDenseI64ArrayAttr(mergedDims); +} - rewriter.replaceOp(op, op.getOperand()); - return success(); - } -}; +////////////////////////////////// +// DynamicBroadcastInDimOp +///////////////////////////////// /// Does the same as PatternRewriter::replaceOpWithNewOp, but with a twist. /// @@ -585,7 +318,7 @@ struct ConvertOpCanon final : OpRewritePattern { /// Oftentimes, this works just fine because HLO is designed to accommodate /// this kind of type refinements. 
But sometimes, this doesn't work - when /// the op is used outside of the HLO dialect (e.g. in func.return). In these -/// cases, we insert a tensor.cast to smooth things out. +/// cases, we insert a stablehlo.convert to smooth things out. template static OpTy refineOpWithNewOp(PatternRewriter &rewriter, Operation *op, Args &&...args) { @@ -600,7 +333,7 @@ static OpTy refineOpWithNewOp(PatternRewriter &rewriter, Operation *op, if (llvm::any_of(opResult.getUsers(), [&](Operation *user) { return user->getDialect() != op->getDialect(); })) - replacementResult = rewriter.create( + replacementResult = rewriter.create( op->getLoc(), opResult.getType(), newOpResult); replacementResults.push_back(replacementResult); } @@ -643,66 +376,37 @@ struct DynamicBroadcastInDimOpNotActuallyDynamic final } }; -struct ChainedDynamicBroadcastInDimCanonicalization final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +////////////////////////////////// +// DynamicReshapeOp +///////////////////////////////// - LogicalResult matchAndRewrite(mlir::stablehlo::DynamicBroadcastInDimOp bcast, - PatternRewriter &rewriter) const override { - auto precedingBcast = - bcast.getOperand() - .getDefiningOp(); - if (!precedingBcast) return failure(); - - // Compose broadcast dimensions. - SmallVector composition; - for (int64_t precedingDim : precedingBcast.getBroadcastDimensions()) { - composition.push_back(bcast.getBroadcastDimensions()[precedingDim]); - } - auto composedBcastDims = rewriter.getDenseI64ArrayAttr(composition); - - rewriter.replaceOpWithNewOp( - bcast, bcast.getType(), precedingBcast.getOperand(), - bcast.getOutputDimensions(), composedBcastDims); - return success(); - } -}; - -// If all dimensions are known to be nonexpanding from the attribute, replace -// the dynamic broadcast with a cast. 
-struct DynamicBroadcastInDimAllDimsNonExpanding final - : OpRewritePattern { +struct DynamicReshapeOpCanon final + : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mlir::stablehlo::DynamicBroadcastInDimOp op, + LogicalResult matchAndRewrite(mlir::stablehlo::DynamicReshapeOp op, PatternRewriter &rewriter) const override { + // This is a noop when the output type is already a static shape. RankedTensorType type = op.getType(); + if (!type.hasStaticShape()) return failure(); - if (!op.getKnownNonexpandingDimensions() || - static_cast(op.getKnownNonexpandingDimensions()->size()) != - type.getRank()) { - return rewriter.notifyMatchFailure( - op, "known_nonexpanding_dimensions don't cover all output dims"); - } - - auto cast = rewriter.createOrFold(op.getLoc(), type, - op.getOperand()); - rewriter.replaceOp(op, cast); + rewriter.replaceOpWithNewOp(op, type, + op.getOperand()); return success(); } }; -struct NoopReduceOpCanon final : OpRewritePattern { +////////////////////////////////// +// ReduceOp +///////////////////////////////// + +// Pattern: reduce[A](_, _, fn:return A) -> A... +struct ReduceNoopVariableReturn final + : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op, PatternRewriter &rewriter) const override { - // No dimensions to reduce. - if (op.getDimensions().empty()) { - rewriter.replaceOp(op, op.getInputs()); - return success(); - } - // If all returned values in the ReduceOp region exists outside the // region, replace the ReduceOp with those values. if (auto retOp = dyn_cast( @@ -721,6 +425,7 @@ struct NoopReduceOpCanon final : OpRewritePattern { } }; +// Pattern: reduce(empty_0, empty_1, ...) -> [broadcast_in_dim(empty_i)...] 
struct EmptyReduceOpCanon final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -761,6 +466,7 @@ struct EmptyReduceOpCanon final : OpRewritePattern { } }; +// Pattern: reduce(in_1, in_2, _, _) -> reduce(in_1, _, _) [if unused(in_2)] struct UnusedResultReduceOpCanon final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -872,63 +578,13 @@ struct UnusedResultReduceOpCanon final } }; -struct DynamicReshapeOpCanon final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::DynamicReshapeOp op, - PatternRewriter &rewriter) const override { - // This is a noop when the output type is already a static shape. - RankedTensorType type = op.getType(); - if (!type.hasStaticShape()) return failure(); - - rewriter.replaceOpWithNewOp(op, type, - op.getOperand()); - return success(); - } -}; - -struct GetTupleElementOpCanon final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::GetTupleElementOp op, - PatternRewriter &rewriter) const override { - auto tuple = op.getOperand().getDefiningOp(); - if (!tuple) return failure(); - - Value result = tuple.getOperand(op.getIndex()); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct RealOpCanon final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::RealOp op, - PatternRewriter &rewriter) const override { - auto complex = op.getOperand().getDefiningOp(); - if (!complex) return failure(); - - rewriter.replaceOp(op, complex.getLhs()); - return success(); - } -}; - -struct ImagOpCanon final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::ImagOp op, - PatternRewriter &rewriter) const override { - auto complex = op.getOperand().getDefiningOp(); - if (!complex) return failure(); - - rewriter.replaceOp(op, complex.getRhs()); - 
return success(); - } -}; +///////////////////////////////// +// GetDimensionSizeOp +///////////////////////////////// +// TODO: This is duplicated with a pattern in shape refinement, consider +// consolidating. +// Pattern: get_dimension_size(X, i) -> X.shape[i] struct GetDimensionSizeOpCanon final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -949,8 +605,13 @@ struct GetDimensionSizeOpCanon final } }; +////////////////////////////////// +// GatherOp +///////////////////////////////// + /// Converts gather ops to slice ops in case we have a single set of constant /// indices. +// Pattern: gather(X, cst_start_indices) -> slice(X, slice_start, slice_end) struct GatherOpCanon final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1020,60 +681,11 @@ struct GatherOpCanon final : OpRewritePattern { } }; -struct ReshapeOpCanon final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::ReshapeOp op, - PatternRewriter &rewriter) const override { - // Fold noop reshape. - if (op.getType() == op.getOperand().getType()) { - rewriter.replaceOp(op, op.getOperand()); - return success(); - } - - // Fold reshape of a constant. - ElementsAttr cstAttr; - if (!matchPattern(op.getOperand(), m_Constant(&cstAttr))) return failure(); - - if (auto splat = dyn_cast(cstAttr)) { - rewriter.replaceOpWithNewOp( - op, SplatElementsAttr::get(op.getType(), - splat.getSplatValue())); - return success(); - } - - auto elements = - llvm::to_vector_of(cstAttr.getValues()); - rewriter.replaceOpWithNewOp( - op, DenseElementsAttr::get(op.getType(), elements)); - return success(); - } -}; - -struct MergeConsecutiveReshapes final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::ReshapeOp op, - PatternRewriter &rewriter) const override { - // Fold noop reshape. 
- auto operand = op.getOperand(); - if (op.getType() == operand.getType()) { - rewriter.replaceOp(op, op.getOperand()); - return success(); - } - - // Fold reshape(reshape(x)). - auto reshapeOp = operand.getDefiningOp(); - if (!reshapeOp) - return rewriter.notifyMatchFailure( - op, "requires defining op of operand to be Reshape"); - - op.setOperand(reshapeOp->getOperand(0)); - return success(); - } -}; +////////////////////////////////// +// TransposeOp +///////////////////////////////// +// Pattern: transpose(X, [no_mem_layout_change...]) -> reshape(X) struct TransposeIsReshape final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1083,11 +695,6 @@ struct TransposeIsReshape final auto input = op.getOperand(); auto permutation = op.getPermutation(); - if (isIotaRange(permutation)) { - rewriter.replaceOp(op, op.getOperand()); - return success(); - } - RankedTensorType inputTy = input.getType(); if (!inputTy.hasStaticShape() || !op.getType().hasStaticShape()) return rewriter.notifyMatchFailure( @@ -1095,15 +702,14 @@ struct TransposeIsReshape final "requires input and output to be of a statically-shaped ranked " "tensor type"); - SmallVector permValues(permutation); + // Check that the permutation is a valid memory layout change. + // All non-zero/one dimensions must be in increasing order. 
SmallVector nonZeroPerms; - nonZeroPerms.reserve(permValues.size()); - for (auto idx : permValues) { - auto sz = inputTy.getDimSize(idx); - if (sz != 1) nonZeroPerms.push_back(idx); - } + nonZeroPerms.reserve(permutation.size()); + for (auto idx : permutation) + if (inputTy.getDimSize(idx) != 1) nonZeroPerms.push_back(idx); - for (int i = 1, s = nonZeroPerms.size(); i < s; ++i) + for (size_t i = 1; i < nonZeroPerms.size(); ++i) if (nonZeroPerms[i - 1] > nonZeroPerms[i]) return rewriter.notifyMatchFailure(op, "memory layout change"); @@ -1113,6 +719,10 @@ struct TransposeIsReshape final } }; +////////////////////////////////// +// Generic and Elementwise Ops +///////////////////////////////// + /// Check if a `t` is a tensor with zero extents. static std::optional isZeroExtent(Type t) { auto type = dyn_cast(t); @@ -1120,8 +730,8 @@ static std::optional isZeroExtent(Type t) { return std::nullopt; } -// Replace instances of zero extent tensors with empty tensors of the same -// type. +// Replace instances of zero extent tensors with empty tensors +// Pattern: op(X : zero_extent_tensor) -> tensor.empty() struct ZeroExtentTensorCanon final : RewritePattern { ZeroExtentTensorCanon(MLIRContext *context, PatternBenefit benefit) : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {} @@ -1220,28 +830,24 @@ struct StablehloAggressiveSimplificationPass final private: FrozenRewritePatternSet patterns; }; + +#include "stablehlo/transforms/StablehloAggressiveSimplificationPatterns.h.inc" } // namespace void populateStablehloCanonicalizationPatterns(MLIRContext *context, RewritePatternSet *patterns, PatternBenefit benefit) { + populateWithGenerated(*patterns); patterns->add< // Arithmetic ops. - AddOpCanon, SubtractOpCanon, MulOpCanon, CompareOpCanon, SelectOpCanon, - CompareSelectIntoMinMax, - // Complex ops. - RealOpCanon, ImagOpCanon, - // Query ops. - GetDimensionSizeOpCanon, GetTupleElementOpCanon, - // Broadcast ops. 
- BroadcastInDimOpCanon, DynamicBroadcastInDimOpNotActuallyDynamic, - ChainedDynamicBroadcastInDimCanonicalization, - DynamicBroadcastInDimAllDimsNonExpanding, + CompareOpCanon, SelectOpCanon, CompareSelectIntoMinMax, + // TODO: Dynamism Refinements, consider merging with canonicalize dynamism + GetDimensionSizeOpCanon, DynamicBroadcastInDimOpNotActuallyDynamic, + DynamicReshapeOpCanon, // Reduce op. - NoopReduceOpCanon, EmptyReduceOpCanon, UnusedResultReduceOpCanon, + ReduceNoopVariableReturn, EmptyReduceOpCanon, UnusedResultReduceOpCanon, // Shape manipulation(-ish) ops. - ConcatenateOpCanon, ConvertOpCanon, DynamicReshapeOpCanon, GatherOpCanon, - ReshapeOpCanon, MergeConsecutiveReshapes, TransposeIsReshape, + GatherOpCanon, TransposeIsReshape, // Types. ZeroExtentTensorCanon>(context, benefit); patterns->add(context); diff --git a/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td b/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td new file mode 100644 index 00000000000..31f1f475c0e --- /dev/null +++ b/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td @@ -0,0 +1,281 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// This is the legalization pattern definition file for CHLO to StableHLO. +// These are included in the populateDecompositionPatterns factory +// and should only include canonical expansions which are not actually +// ambiguous/different for various backends. Avoid patterns that are actually +// lowering to non-canonical forms. 
+ +include "mlir/IR/OpBase.td" +include "stablehlo/dialect/StablehloOps.td" + +//// Utilities +def NotConstantOp : Constraint< + CPred<"llvm::isa($0) || !llvm::isa($0.getDefiningOp())">, + "is not a constant.">; + +def OperandsEqual : Constraint, "operands are equal">; + +def TypesEqual : Constraint, "operands are equal">; + +def NumberOfElementsEqual : Constraint< + CPred<"llvm::cast($0.getType()).getNumElements() == llvm::cast($1.getType()).getNumElements()">, + "same number of elements">; + +def RankEqual : Constraint< + CPred<"llvm::cast($0.getType()).getRank() == llvm::cast($1.getType()).getRank()">, + "same rank">; + +def EmptyI64Array : AttrConstraint< + CPred<"cast($_self).empty()">, "is empty i64 array">; + +def CommutativeOp : Constraint< + CPred<"$0.getDefiningOp()->hasTrait()">, "op is commutative">; + +def AnySplat : AttrConstraint, "is any splat">; + +def AnyZero : AttrConstraint< + CPred<"::mlir::matchPattern($_self, m_AnyAttrOf(m_Zero(), m_AnyZeroFloat()))">, "is int or float zero">; + +def IntZero : AttrConstraint< + CPred<"::mlir::matchPattern($_self, m_Zero())">, "is integer zero">; + +def IntOne : AttrConstraint< + CPred<"::mlir::matchPattern($_self, m_One())">, "is integer one">; + +def IotaDims : AttrConstraint< + CPred<"isIotaRange(cast($_self).asArrayRef())">, "is iota dimensions">; + +def SortedDims : AttrConstraint< + CPred<"llvm::is_sorted(cast($_self).asArrayRef())">, "is sorted dimensions">; + +def AllDimsNonExpanding : Constraint< + CPred<"$0 && cast($0).size() == llvm::cast($1.getType()).getRank()">, "all dims are non-expanding">; + +def GetOperandN : NativeCodeCall<"$0.getDefiningOp()->getOperand($1.getInt())">; + +def GetEmptyI64Array : NativeCodeCall<"$_builder.getDenseI64ArrayAttr({})">; + +def MergeBroadcastDims : NativeCodeCall<"getMergedBroadcastDimensions($_builder, $0, $1)">; + +def StableHLO_ConvertOpWithShape : NativeCodeCall< + "$_builder.create($_loc, $0.getType(), $1)">; + +def StableHLO_ReshapeOpWithShape : 
NativeCodeCall< + "$_builder.create($_loc, $0.getType(), $1)">; + +class StableHLO_ConstantLike : NativeCodeCall< + "::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; + +//////////////////////////// +// Generic BinaryOp Patterns + +// op(cst, X) -> op(X, cst) +class CanonicalizeConstantToRhs + : Pat<(StableHLO_OpType:$op (StableHLO_ConstantOp:$lhs $value), $rhs), + (StableHLO_OpType $rhs, $lhs), + [(NotConstantOp $rhs), (CommutativeOp $op)]>; + +//////// +// AddOp + +// Pattern: add(cst, X) -> add(X, cst) +def : CanonicalizeConstantToRhs; + +// Pattern: add(X, 0) -> X +def : Pat<(StableHLO_AddOp $lhs, (ConstantLikeMatcher AnyZero:$value)), + (replaceWithValue $lhs)>; + +//////// +// AndOp + +// Pattern: and(cst, X) -> and(X, cst) +def : CanonicalizeConstantToRhs; + +// Pattern: and(X, 0) -> 0 +def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$zero IntZero:$value)), + (replaceWithValue $zero)>; + +// Pattern: and(X, 1) -> X +def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$one IntOne:$value)), + (replaceWithValue $lhs)>; + +//////// +// BroadcastInDimOp + +// Pattern: broadcast_in_dim(X, [iota...]) -> X +def : Pat<(StableHLO_BroadcastInDimOp:$op $operand, IotaDims:$dims), + (replaceWithValue $operand), + [(TypesEqual $op, $operand)]>; + +// Pattern: broadcast_in_dim(broadcast_in_dim(X, [dimsA...]), [dimsB...]) -> broadcast_in_dim(X, merge(dimsA, dimsB)) +def : Pat<(StableHLO_BroadcastInDimOp + (StableHLO_BroadcastInDimOp $operand, $dims_parent), $dims), + (StableHLO_BroadcastInDimOp $operand, (MergeBroadcastDims $dims, $dims_parent))>; + +// Pattern: broadcast_in_dim(X, [sorted...]) -> reshape(X, [sorted...]) [if same numel] +def : Pat<(StableHLO_BroadcastInDimOp:$op $operand, SortedDims:$dims), + (StableHLO_ReshapeOpWithShape $op, $operand), + [(NumberOfElementsEqual $op, $operand)]>; + +// Pattern: broadcast_in_dim(X, [dims...]) -> transpose(X, [dims...]) [if same numel & rank] +def : Pat<(StableHLO_BroadcastInDimOp:$op 
$operand, $dims), + (StableHLO_TransposeOp $operand, $dims), + [(NumberOfElementsEqual $op, $operand), (RankEqual $op, $operand)]>; + +//////// +// ConvertOp + +// Pattern: convert(X, [X.type]) -> X +def : Pat<(StableHLO_ConvertOp:$convert $operand), + (replaceWithValue $operand), + [(TypesEqual $convert, $operand)]>; + +//////// +// DynamicBroadcastInDimOp + +// Pattern: dynamic_broadcast_in_dim(dynamic_broadcast_in_dim(X, _, [dimsA...]), shape, [dimsB...]) -> dynamic_broadcast_in_dim(X, shape, merge(dimsA, dimsB)) +// TODO: Think more if the values of known_expanding_dimensions and known_non_expanding_dimensions can be preserved. +def : Pat<(StableHLO_DynamicBroadcastInDimOp + (StableHLO_DynamicBroadcastInDimOp $operand, $shape_p, $dims_p, $expanding_p, $nonexpanding_p), + $shape, $dims, $expanding, $nonexpanding), + (StableHLO_DynamicBroadcastInDimOp $operand, $shape, (MergeBroadcastDims $dims, $dims_p), (GetEmptyI64Array), (GetEmptyI64Array))>; + +// Pattern: dynamic_broadcast_in_dim(X, _, _, [all_nonexpanding...]) -> cast(X) +// No-op, but wrap in ConvertOp to preserve dynamic output shape, can be +// important if this result is returned, where refining type would require +// also updating the function signature. +def : Pat<(StableHLO_DynamicBroadcastInDimOp:$op $operand, $shape, $dims, $expanding, $nonexpanding), + (StableHLO_ConvertOpWithShape $op, $operand), + [(AllDimsNonExpanding $nonexpanding, $op)]>; + +// Pattern: dynamic_broadcast_in_dim(dynamic_reshape(X, shape), shape) -> dynamic_reshape(X, shape) +// If sharing same shape operand, is dynamic reshape. 
+def : Pat<(StableHLO_DynamicBroadcastInDimOp + (StableHLO_DynamicReshapeOp $operand, $shape), $shape, $dims, $expanding, $nonexpanding), + (StableHLO_DynamicReshapeOp $operand, $shape)>; + + +//////// +// DynamicReshapeOp + +// Pattern: dynamic_reshape(dynamic_reshape(X, _), shape) -> dynamic_reshape(X, shape) +def : Pat<(StableHLO_DynamicReshapeOp (StableHLO_DynamicReshapeOp $operand, $shape_p), $shape), + (StableHLO_DynamicReshapeOp $operand, $shape)>; + +//////// +// ImagOp + +// Pattern: imag(complex(R,I)) -> I +def : Pat<(StableHLO_ImagOp (StableHLO_ComplexOp $lhs, $rhs)), + (replaceWithValue $rhs)>; + +//////// +// MaxOp + +// Pattern: max(cst, X) -> max(X, cst) +def : CanonicalizeConstantToRhs; + +//////// +// MinOp + +// Pattern: minimum(cst, X) -> minimum(X, cst) +def : CanonicalizeConstantToRhs; + +//////// +// MulOp + +// Pattern: multiply(cst, X) -> multiply(X, cst) +def : CanonicalizeConstantToRhs; + +// Pattern: multiply(X, 0i) -> 0i +// Multiplication by 0. This fold is not trivial for floats in presence of NaNs +def : Pat<(StableHLO_MulOp $lhs, (StableHLO_ConstantOp:$zero IntZero:$value)), + (replaceWithValue $zero)>; + +// Pattern: multiply(X, 1i) -> X +def : Pat<(StableHLO_MulOp $lhs, (StableHLO_ConstantOp IntOne:$value)), + (replaceWithValue $lhs)>; + +//////// +// OrOp + +// Pattern: or(cst, X) -> or(X, cst) +def : CanonicalizeConstantToRhs; + +// Pattern: or(X, 1) -> 1 +def : Pat<(StableHLO_OrOp $lhs, (StableHLO_ConstantOp:$one IntOne:$value)), + (replaceWithValue $one)>; + +// Pattern: or(X, 0) -> X +def : Pat<(StableHLO_OrOp $lhs, (StableHLO_ConstantOp:$zero IntZero:$value)), + (replaceWithValue $lhs)>; + +//////// +// RealOp + +// Pattern: real(complex(R,I)) -> R +def : Pat<(StableHLO_RealOp (StableHLO_ComplexOp $lhs, $rhs)), + (replaceWithValue $lhs)>; + +//////// +// ReduceOp +// Note: If modifying region is required, must write pattern in C++ + +// Pattern: reduce(X..., dims=[], add) -> X... 
+def : Pat<(StableHLO_ReduceOp $operands, $init, EmptyI64Array:$dims), + (replaceWithValue $operands)>; + +//////// +// ReshapeOp + +// Pattern: reshape(reshape(X, _), [shape]) -> reshape(X, [shape]) +def : Pat<(StableHLO_ReshapeOp:$reshape (StableHLO_ReshapeOp $operand)), + (StableHLO_ReshapeOpWithShape $reshape, $operand)>; + +// Pattern: reshape(X, [X.shape]) -> X +def : Pat<(StableHLO_ReshapeOp:$reshape $operand), + (replaceWithValue $operand), + [(TypesEqual $reshape, $operand)]>; + +//////// +// SubtractOp + +// Pattern: subtract(X, X) -> 0 +// Must be static shape, otherwise would require broadcasting via CHLO_ConstantLike +def : Pat<(StableHLO_SubtractOp AnyStaticShapeTensor:$operand, $operand), + (StableHLO_ConstantLike<"0"> $operand)>; + +// Pattern: subtract(X, 0) -> X +def : Pat<(StableHLO_SubtractOp $lhs, (StableHLO_ConstantOp AnyZero:$value)), + (replaceWithValue $lhs)>; + +//////// +// TransposeOp + +// Pattern: transpose(X, [iota...]) -> X +def : Pat<(StableHLO_TransposeOp $lhs, IotaDims:$dims), + (replaceWithValue $lhs)>; + +//////// +// GetTupleElementOp + +// Pattern: get_tuple_element(tuple(X_0, X_1, ...), i) -> X_i +def : Pat<(StableHLO_GetTupleElementOp (StableHLO_TupleOp:$tuple $operands), $idx), + (GetOperandN $tuple, $idx)>; + +//////// +// XorOp + +// Pattern: xor(cst, X) -> xor(X, cst) +def : CanonicalizeConstantToRhs; + +// To consider: xor(X, X) -> 0 +// Unclear if this is beneficial on hardware vs adding another constant +// +// def : Pat<(StableHLO_XorOp AnyStaticShapeTensor:$operand, $operand), +// (StableHLO_ConstantLike<"0"> $operand)>;