diff --git a/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir index 76aa10e488..0f46888d5a 100644 --- a/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir +++ b/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir @@ -69,7 +69,7 @@ func.func @gather_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor index_vector_dim = 3 >, slice_sizes = array, - indices_are_sorted = true + indices_are_sorted = false } : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> func.return %0 : tensor<4x3x5x8xi32> } @@ -77,9 +77,9 @@ func.func @gather_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor // ----- // CHECK-LABEL: @gather_with_batching_no_index_vector_dim +// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> // CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> // CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> // CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> // CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ // CHECK-SAME: dimension_numbers = #stablehlo.gather< @@ -102,7 +102,7 @@ func.func @gather_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, index_vector_dim = 3 >, slice_sizes = array, - indices_are_sorted = true + indices_are_sorted = false }> : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> func.return %0 : tensor<4x3x5x8xi32> } @@ -133,13 +133,309 @@ func.func @gather_with_batching_dim_size_zero(%arg0: tensor<0x2x9xi32>, %arg1: t index_vector_dim = 3 >, slice_sizes = array, - indices_are_sorted = true + indices_are_sorted = false }> : (tensor<0x2x9xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> func.return %0 : tensor<0x3x5x8xi32> } // ----- +// CHECK-LABEL: @gather_batching_dims_indices_become_unsorted +// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<3x4x5x1xi32> +// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 1 : tensor<3x4x5x1xi32> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<3x4x5x1xi32>, tensor<3x4x5x1xi32>, tensor<3x4x5x2xi32>) -> tensor<3x4x5x4xi32> +// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ +// CHECK-SAME: dimension_numbers = #stablehlo.gather< +// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], +// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, +// CHECK-SAME: indices_are_sorted = false, +// CHECK-SAME: slice_sizes = array +// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<3x4x5x4xi32>) -> tensor<3x4x5x8xi32> +// CHECK-NEXT: return %[[gather]] : tensor<3x4x5x8xi32> +func.func @gather_batching_dims_indices_become_unsorted(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<3x4x5x2xi32>) -> tensor<3x4x5x8xi32> { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [1, 3], + operand_batching_dims = [0, 2], + start_indices_batching_dims = [0, 1], + start_index_map = [1, 3], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = true + } : 
(tensor<3x2x4x7x9xi32>, tensor<3x4x5x2xi32>) -> tensor<3x4x5x8xi32> + func.return %0 : tensor<3x4x5x8xi32> +} + +// ----- + +// CHECK-LABEL: @gather_batching_dims_indices_become_unsorted_2 +// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<2x3x5x1xi32> +// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> +// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ +// CHECK-SAME: dimension_numbers = #stablehlo.gather< +// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], +// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, +// CHECK-SAME: indices_are_sorted = false, +// CHECK-SAME: slice_sizes = array +// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> +// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> +func.func @gather_batching_dims_indices_become_unsorted_2(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [2, 3], + operand_batching_dims = [0, 1], + start_indices_batching_dims = [1, 0], + start_index_map = [2, 3], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = true + } : (tensor<3x2x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> + func.return %0 : tensor<2x3x5x8xi32> +} + +// ----- + +// CHECK-LABEL: @gather_batching_dims_indices_remain_sorted +// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> +// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> +// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ +// CHECK-SAME: dimension_numbers = #stablehlo.gather< +// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], +// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, +// CHECK-SAME: indices_are_sorted = true, +// CHECK-SAME: slice_sizes = array +// CHECK-SAME: }> : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> +// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> +func.func @gather_batching_dims_indices_remain_sorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [2, 3], + operand_batching_dims = [0, 1], + start_indices_batching_dims = [0, 2], + start_index_map = [2, 3], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = true + } : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> + func.return %0 : tensor<2x3x5x8xi32> +} + +// ----- + +// CHECK-LABEL: @gather_batching_dims_indices_remain_unsorted +// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> +// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> +// 
CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ +// CHECK-SAME: dimension_numbers = #stablehlo.gather< +// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], +// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, +// CHECK-SAME: indices_are_sorted = false, +// CHECK-SAME: slice_sizes = array +// CHECK-SAME: }> : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> +// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> +func.func @gather_batching_dims_indices_remain_unsorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [2, 3], + operand_batching_dims = [0, 1], + start_indices_batching_dims = [0, 2], + start_index_map = [2, 3], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> + func.return %0 : tensor<2x3x5x8xi32> +} + +// ----- + +// CHECK-LABEL: @gather_batching_dims_does_not_overflow_indices_type +// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x127x5x1xi8> +// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x127x5x1xi8> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x127x5x1xi8>, tensor<4x127x5x1xi8>, tensor<4x127x5x2xi8>) -> tensor<4x127x5x4xi8> +// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ +// CHECK-SAME: dimension_numbers = #stablehlo.gather< +// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], +// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, +// CHECK-SAME: indices_are_sorted = false, +// CHECK-SAME: slice_sizes = array +// CHECK-SAME: }> : (tensor<127x2x4x7x9xi32>, tensor<4x127x5x4xi8>) -> tensor<4x127x5x8xi32> +// CHECK-NEXT: return %[[gather]] : tensor<4x127x5x8xi32> +func.func @gather_batching_dims_does_not_overflow_indices_type(%arg0: tensor<127x2x4x7x9xi32>, %arg1: tensor<4x127x5x2xi8>) -> tensor<4x127x5x8xi32> { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [1, 3], + operand_batching_dims = [0, 2], + start_indices_batching_dims = [1, 0], + start_index_map = [1, 3], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<127x2x4x7x9xi32>, tensor<4x127x5x2xi8>) -> tensor<4x127x5x8xi32> + func.return %0 : tensor<4x127x5x8xi32> +} + +// ----- + +// CHECK-LABEL: @gather_batching_dim_overflows_signless_indices_type +// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x128x5x2xi8>) -> tensor<4x128x5x2xi32> +// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x128x5x1xi32> +// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x128x5x1xi32> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[convert]], dim = 3 : (tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>, tensor<4x128x5x2xi32>) -> tensor<4x128x5x4xi32> +// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ +// CHECK-SAME: dimension_numbers = #stablehlo.gather< +// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], +// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, +// CHECK-SAME: indices_are_sorted = false, +// CHECK-SAME: slice_sizes = array +// CHECK-SAME: }> : 
(tensor<128x2x4x7x9xi32>, tensor<4x128x5x4xi32>) -> tensor<4x128x5x8xi32> +// CHECK-NEXT: return %[[gather]] : tensor<4x128x5x8xi32> +func.func @gather_batching_dim_overflows_signless_indices_type(%arg0: tensor<128x2x4x7x9xi32>, %arg1: tensor<4x128x5x2xi8>) -> tensor<4x128x5x8xi32> { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [1, 3], + operand_batching_dims = [0, 2], + start_indices_batching_dims = [1, 0], + start_index_map = [1, 3], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<128x2x4x7x9xi32>, tensor<4x128x5x2xi8>) -> tensor<4x128x5x8xi32> + func.return %0 : tensor<4x128x5x8xi32> +} + +// ----- + +// CHECK-LABEL: @gather_batching_dim_overflows_unsigned_indices_type +// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<256x4x5x2xui8>) -> tensor<256x4x5x2xi32> +// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<256x4x5x1xi32> +// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<256x4x5x1xi32> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim0]], %[[iota_dim1]], %[[convert]], dim = 3 : (tensor<256x4x5x1xi32>, tensor<256x4x5x1xi32>, tensor<256x4x5x2xi32>) -> tensor<256x4x5x4xi32> +// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ +// CHECK-SAME: dimension_numbers = #stablehlo.gather< +// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], +// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, +// CHECK-SAME: indices_are_sorted = false, +// CHECK-SAME: slice_sizes = array +// CHECK-SAME: }> : (tensor<256x2x4x7x9xi32>, tensor<256x4x5x4xi32>) -> tensor<256x4x5x8xi32> +// CHECK-NEXT: return %[[gather]] : tensor<256x4x5x8xi32> +func.func @gather_batching_dim_overflows_unsigned_indices_type(%arg0: tensor<256x2x4x7x9xi32>, %arg1: tensor<256x4x5x2xui8>) -> tensor<256x4x5x8xi32> { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [1, 3], + operand_batching_dims = [0, 2], + start_indices_batching_dims = [0, 1], + start_index_map = [1, 3], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<256x2x4x7x9xi32>, tensor<256x4x5x2xui8>) -> tensor<256x4x5x8xi32> + func.return %0 : tensor<256x4x5x8xi32> +} + +// ----- + +// CHECK-LABEL: @gather_batching_dim_overflows_indices_type_and_i32 +// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x2xi64> +// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x2147483648x5x1xi64> +// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x2147483648x5x1xi64> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[convert]], dim = 3 : (tensor<4x2147483648x5x1xi64>, tensor<4x2147483648x5x1xi64>, tensor<4x2147483648x5x2xi64>) -> tensor<4x2147483648x5x4xi64> +// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ +// CHECK-SAME: dimension_numbers = #stablehlo.gather< +// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], +// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, +// CHECK-SAME: indices_are_sorted = false, +// CHECK-SAME: slice_sizes = array +// CHECK-SAME: }> : (tensor<2147483648x2x4x7x9xi32>, tensor<4x2147483648x5x4xi64>) -> tensor<4x2147483648x5x8xi32> +// CHECK-NEXT: return %[[gather]] : tensor<4x2147483648x5x8xi32> 
+func.func @gather_batching_dim_overflows_indices_type_and_i32(%arg0: tensor<2147483648x2x4x7x9xi32>, %arg1: tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x8xi32> { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [1, 3], + operand_batching_dims = [0, 2], + start_indices_batching_dims = [1, 0], + start_index_map = [1, 3], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<2147483648x2x4x7x9xi32>, tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x8xi32> + func.return %0 : tensor<4x2147483648x5x8xi32> +} + +// ----- + +// CHECK-LABEL: @gather_batching_dim_dynamic_size +// CHECK: operand_batching_dims = [0, 2] +// CHECK: start_indices_batching_dims = [1, 0] +func.func @gather_batching_dim_dynamic_size(%arg0: tensor, %arg1: tensor<4x?x5x2xi8>) -> tensor<4x?x5x8xi32> { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [1, 3], + operand_batching_dims = [0, 2], + start_indices_batching_dims = [1, 0], + start_index_map = [1, 3], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = false + } : (tensor, tensor<4x?x5x2xi8>) -> tensor<4x?x5x8xi32> + func.return %0 : tensor<4x?x5x8xi32> +} + +// ----- + +// CHECK-LABEL: @gather_batching_dim_overflows_and_no_index_vector_dim +// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x128x5xi8>) -> tensor<4x128x5xi32> +// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %[[convert]] : (tensor<4x128x5xi32>) -> tensor<4x128x5x1xi32> +// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x128x5x1xi32> +// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x128x5x1xi32> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>) -> tensor<4x128x5x3xi32> +// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ +// CHECK-SAME: dimension_numbers = #stablehlo.gather< +// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2], +// CHECK-SAME: start_index_map = [0, 2, 1], index_vector_dim = 3>, +// CHECK-SAME: indices_are_sorted = false, +// CHECK-SAME: slice_sizes = array +// CHECK-SAME: }> : (tensor<128x2x4x9xi32>, tensor<4x128x5x3xi32>) -> tensor<4x128x5x8xi32> +// CHECK-NEXT: return %[[gather]] : tensor<4x128x5x8xi32> +func.func @gather_batching_dim_overflows_and_no_index_vector_dim(%arg0: tensor<128x2x4x9xi32>, %arg1: tensor<4x128x5xi8>) -> tensor<4x128x5x8xi32> { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [1], + operand_batching_dims = [0, 2], + start_indices_batching_dims = [1, 0], + start_index_map = [1], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<128x2x4x9xi32>, tensor<4x128x5xi8>) -> tensor<4x128x5x8xi32> + func.return %0 : tensor<4x128x5x8xi32> +} + +// ----- + // CHECK-LABEL: @scatter_with_batching_dims // CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> // CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> @@ -156,7 +452,7 @@ func.func @scatter_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tenso // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] %0 = "stablehlo.scatter"(%arg0, 
%arg1, %arg2) <{ - indices_are_sorted = true, + indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [3], inserted_window_dims = [1, 3], @@ -176,9 +472,9 @@ func.func @scatter_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tenso // ----- // CHECK-LABEL: @scatter_with_batching_no_index_vector_dim +// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> // CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> // CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> // CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> // CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ // CHECK-SAME: indices_are_sorted = false, @@ -192,7 +488,7 @@ func.func @scatter_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ - indices_are_sorted = true, + indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [3], inserted_window_dims = [1], @@ -208,3 +504,60 @@ func.func @scatter_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, }) : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> func.return %0 : tensor<3x2x4x9xi32> } + +// ----- + +// CHECK-LABEL: @scatter_batching_dims_indices_remain_sorted +// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> +// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> +// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ +// CHECK-SAME: indices_are_sorted = true, +// CHECK-SAME: dimension_numbers = #stablehlo.scatter< +// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], +// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1, 2, 3], index_vector_dim = 3>, +// CHECK-SAME: unique_indices = false}> +// CHECK: (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>, tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> +// CHECK-NEXT: return %[[scatter]] : tensor<2x5x4x7x9xi32> +func.func @scatter_batching_dims_indices_remain_sorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>, %arg2: tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> { + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ + indices_are_sorted = true, + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3], + inserted_window_dims = [2, 3], + input_batching_dims = [0, 1], + scatter_indices_batching_dims = [0, 2], + scatter_dims_to_operand_dims = [2, 3], + index_vector_dim = 3 + >, + unique_indices = false + }> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>, tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> + func.return %0 : tensor<2x5x4x7x9xi32> +} + +// ----- + +// CHECK-LABEL: @scatter_batching_dim_dynamic_scatter_indices +// CHECK: input_batching_dims = [0, 2] +// CHECK: 
scatter_indices_batching_dims = [1, 0]
+func.func @scatter_batching_dim_dynamic_scatter_indices(%arg0: tensor<?x2x4x7x9xi32>, %arg1: tensor<4x?x5x2xi32>, %arg2: tensor<4x?x5x8xi32>) -> tensor<?x2x4x7x9xi32> {
+  %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{
+    indices_are_sorted = false,
+    scatter_dimension_numbers = #stablehlo.scatter<
+      update_window_dims = [3],
+      inserted_window_dims = [1, 3],
+      input_batching_dims = [0, 2],
+      scatter_indices_batching_dims = [1, 0],
+      scatter_dims_to_operand_dims = [1, 3],
+      index_vector_dim = 3
+    >,
+    unique_indices = false
+  }> ({
+    ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+      stablehlo.return %arg4 : tensor<i32>
+  }) : (tensor<?x2x4x7x9xi32>, tensor<4x?x5x2xi32>, tensor<4x?x5x8xi32>) -> tensor<?x2x4x7x9xi32>
+  func.return %0 : tensor<?x2x4x7x9xi32>
+}
diff --git a/stablehlo/transforms/StablehloCompatibilityExpander.cpp b/stablehlo/transforms/StablehloCompatibilityExpander.cpp
index 03bb810067..e883411290 100644
--- a/stablehlo/transforms/StablehloCompatibilityExpander.cpp
+++ b/stablehlo/transforms/StablehloCompatibilityExpander.cpp
@@ -22,8 +22,11 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MathExtras.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/PatternMatch.h"
@@ -75,6 +78,42 @@ SmallVector<int64_t> mergeSortedDims(ArrayRef<int64_t> dims1,
   return result;
 }

+bool fitsInIntegralType(int64_t size, IntegerType type) {
+  if (type.isUnsigned()) {
+    return llvm::isUIntN(type.getWidth(), size);
+  } else {
+    return llvm::isIntN(type.getWidth(), size);
+  }
+}
+
+// If `type` is an integer type in which `size` doesn't fit, promote it to i32
+// or i64 (depending on `size`).
+Type promoteTypeForSize(Type type, int64_t size, OpBuilder &builder) {
+  // Gather/Scatter should have an integer type, but we check just in case.
+  auto intType = dyn_cast<IntegerType>(type);
+  if (!intType || fitsInIntegralType(size, intType)) {
+    return type;
+  }
+  if (fitsInIntegralType(size, builder.getI32Type())) {
+    return builder.getI32Type();
+  }
+  return builder.getI64Type();
+}
+
+// If `indices_batching_dims` and `updated_index_map` are both sorted, then the
+// `indices_are_sorted` property is preserved.
+//
+// This is because each concatenated iota is monotonically increasing, sorted
+// indices batching dims mean their order corresponds to the order of batching
+// dims in the operand, and a sorted updated start index map means the order of
+// the index vector dim corresponds to the order of operand dims.
+bool getUpdatedIndicesAreSorted(bool indices_are_sorted,
+                                ArrayRef<int64_t> indices_batching_dims,
+                                ArrayRef<int64_t> updated_index_map) {
+  return indices_are_sorted && llvm::is_sorted(indices_batching_dims) &&
+         llvm::is_sorted(updated_index_map);
+}
+
 // Returns an updated indices tensor such that an `IotaOp` is prepended for each
 // dim in `indicesBatchingDims` with a `ConcatenateOp`.
 //
@@ -85,16 +124,31 @@ Value createConcatIndices(Value indices, int64_t indexVectorDim,
                           PatternRewriter &rewriter) {
   Location loc = indices.getLoc();
   auto indicesType = cast<RankedTensorType>(indices.getType());
-  bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank();
+  Type elementType = indicesType.getElementType();
+  // The batching dim sizes might not fit in the existing element type,
+  // in which case we need to promote it.
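+  // For example, with i8 start indices a batching dim of size 128 exceeds the
+  // i8 range (max 127), so the indices are converted to i32; a batching dim
+  // whose size exceeds the i32 range is promoted to i64 instead.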
+  for (int64_t batchingDim : indicesBatchingDims) {
+    elementType = promoteTypeForSize(
+        elementType, indicesType.getDimSize(batchingDim), rewriter);
+  }
+  if (elementType != indicesType.getElementType()) {
+    indicesType = RankedTensorType::get(indicesType.getShape(), elementType);
+    indices = rewriter.create<ConvertOp>(loc, indicesType, indices);
+  }
+
+  bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank();

   SmallVector<int64_t> iotaShape(indicesType.getShape());
   if (indexVectorDimOnLastDim) {
     iotaShape.push_back(1);
   } else {
     iotaShape[indexVectorDim] = 1;
   }
-  auto iotaType =
-      RankedTensorType::get(iotaShape, indicesType.getElementType());
+  auto iotaType = RankedTensorType::get(iotaShape, elementType);
+
+  if (indexVectorDimOnLastDim) {
+    indices = rewriter.create<ReshapeOp>(loc, iotaType, indices);
+  }

   SmallVector<Value> indicesToConcat;
   indicesToConcat.reserve(indicesBatchingDims.size() + 1);
@@ -102,12 +156,7 @@ Value createConcatIndices(Value indices, int64_t indexVectorDim,
     indicesToConcat.push_back(
         rewriter.create<IotaOp>(loc, iotaType, batchingDim));
   }
-  if (indexVectorDimOnLastDim) {
-    indicesToConcat.push_back(
-        rewriter.create<ReshapeOp>(loc, iotaType, indices));
-  } else {
-    indicesToConcat.push_back(indices);
-  }
+  indicesToConcat.push_back(indices);
   return rewriter.create<ConcatenateOp>(loc, indicesToConcat, indexVectorDim);
 }

@@ -125,26 +174,37 @@ class GatherWithBatchingDimsExpander : public OpRewritePattern<GatherOp> {
                                 PatternRewriter &rewriter) const override {
     GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers();
     ArrayRef<int64_t> operandBatchingDims = dimNumbers.getOperandBatchingDims();
-    if (operandBatchingDims.empty())
+    ArrayRef<int64_t> startIndicesBatchingDims =
+        dimNumbers.getStartIndicesBatchingDims();
+    if (operandBatchingDims.empty()) {
       return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
         diag << "gather op has no batching dims";
       });
+    }
+
+    if (!op.getStartIndices().getType().hasStaticShape()) {
+      return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
+        diag << "gather op has start indices with dynamic shape, can't expand";
+      });
+    }

     SmallVector<int64_t> newCollapsedSliceDims = mergeSortedDims(
         operandBatchingDims, dimNumbers.getCollapsedSliceDims());
     SmallVector<int64_t> newStartIndexMap =
         llvm::to_vector(llvm::concat<const int64_t>(
             operandBatchingDims, dimNumbers.getStartIndexMap()));
-    Value newIndices = createConcatIndices(
-        op.getStartIndices(), dimNumbers.getIndexVectorDim(),
-        dimNumbers.getStartIndicesBatchingDims(), rewriter);
+    Value newIndices = createConcatIndices(op.getStartIndices(),
+                                           dimNumbers.getIndexVectorDim(),
+                                           startIndicesBatchingDims, rewriter);
     rewriter.replaceOpWithNewOp<GatherOp>(
         op, op.getOperand(), newIndices,
         GatherDimensionNumbersAttr::get(
             op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims,
             /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{},
             newStartIndexMap, dimNumbers.getIndexVectorDim()),
-        op.getSliceSizes(), /*indicesAreSorted=*/false);
+        op.getSliceSizes(),
+        getUpdatedIndicesAreSorted(op.getIndicesAreSorted(),
+                                   startIndicesBatchingDims, newStartIndexMap));

     return success();
   }
@@ -160,10 +220,19 @@ class ScatterWithBatchingDimsExpander : public OpRewritePattern<ScatterOp> {
                                 PatternRewriter &rewriter) const override {
     ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers();
     ArrayRef<int64_t> inputBatchingDims = dimNumbers.getInputBatchingDims();
-    if (inputBatchingDims.empty())
+    ArrayRef<int64_t> scatterIndicesBatchingDims =
+        dimNumbers.getScatterIndicesBatchingDims();
+    if (inputBatchingDims.empty()) {
       return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
         diag << "scatter op has no batching dims";
       });
+    }
+
+    if (!op.getScatterIndices().getType().hasStaticShape()) {
+      return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
+        diag << "scatter op has scatter indices with dynamic shape, can't expand";
+      });
+    }

     SmallVector<int64_t> newInsertedWindowDims =
         mergeSortedDims(inputBatchingDims, dimNumbers.getInsertedWindowDims());
@@ -172,7 +241,7 @@ class ScatterWithBatchingDimsExpander : public OpRewritePattern<ScatterOp> {
             inputBatchingDims, dimNumbers.getScatterDimsToOperandDims()));
     Value newIndices = createConcatIndices(
         op.getScatterIndices(), dimNumbers.getIndexVectorDim(),
-        dimNumbers.getScatterIndicesBatchingDims(), rewriter);
+        scatterIndicesBatchingDims, rewriter);
     auto newScatterOp = rewriter.create<ScatterOp>(
         op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices,
         op.getUpdates(),
         ScatterDimensionNumbersAttr::get(
             op.getContext(), dimNumbers.getUpdateWindowDims(),
@@ -181,7 +250,10 @@
             newInsertedWindowDims, /*inputBatchingDims=*/{},
             /*scatterIndicesBatchingDims=*/{}, newScatterDimsToOperandDims,
             dimNumbers.getIndexVectorDim()),
-        /*indicesAreSorted=*/false, op.getUniqueIndices());
+        getUpdatedIndicesAreSorted(op.getIndicesAreSorted(),
+                                   scatterIndicesBatchingDims,
+                                   newScatterDimsToOperandDims),
+        op.getUniqueIndices());

     newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation());
     rewriter.replaceOp(op, newScatterOp.getResults());
diff --git a/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
deleted file mode 100644
index 6d5d19eb56..0000000000
--- a/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
+++ /dev/null
@@ -1,258 +0,0 @@
-/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-   http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include
-
-#include
-#include
-#include
-#include
-#include
-
-#include "llvm/ADT/APFloat.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Rewrite/FrozenRewritePatternSet.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "stablehlo/dialect/StablehloOps.h"
-#include "stablehlo/dialect/Version.h"
-#include "stablehlo/transforms/PassUtils.h"
-#include "stablehlo/transforms/Passes.h"
-
-namespace mlir {
-namespace stablehlo {
-#define GEN_PASS_DEF_STABLEHLOCOMPATIBILITYEXPANDERPASS
-#include "stablehlo/transforms/Passes.h.inc"
-
-namespace {
-
-//===----------------------------------------------------------------------===//
-// Helpers.
-//===----------------------------------------------------------------------===//
-
-// Check user-specified target version.
-vhlo::Version validateTargetVersion(llvm::StringRef versionRef) { - auto failOrVersion = vhlo::Version::fromString(versionRef); - if (failed(failOrVersion)) { - assert(!versionRef.empty() && - "No target version specified. Target version must be of the form " - "`#.#.#`."); - assert(versionRef.empty() && - "Invalid target version argument. Target version must be of the " - "form `#.#.#`."); - } - vhlo::Version targetVersion = *failOrVersion; - assert((vhlo::Version::getMinimumVersion() <= targetVersion) && - "target version is less than minimum supported."); - assert((targetVersion <= vhlo::Version::getCurrentVersion()) && - "target version is greater than current version."); - return targetVersion; -} - -SmallVector mergeSortedDims(ArrayRef dims1, - ArrayRef dims2) { - SmallVector result; - result.reserve(dims1.size() + dims2.size()); - std::merge(dims1.begin(), dims1.end(), dims2.begin(), dims2.end(), - std::back_inserter(result)); - return result; -} - -// Returns an updated indices tensor such that an `IotaOp` is prepended for each -// dim in `indicesBatchingDims` with a `ConcatenateOp`. -// -// If `indexVectorDim` is equal to the rank of `indices`, it is reshaped to have -// a trailing dimension of size 1 so it can be concatenated with the `IotaOp`s. -Value createConcatIndices(Value indices, int64_t indexVectorDim, - ArrayRef indicesBatchingDims, - PatternRewriter &rewriter) { - Location loc = indices.getLoc(); - auto indicesType = cast(indices.getType()); - bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank(); - - SmallVector iotaShape(indicesType.getShape()); - if (indexVectorDimOnLastDim) { - iotaShape.push_back(1); - } else { - iotaShape[indexVectorDim] = 1; - } - auto iotaType = - RankedTensorType::get(iotaShape, indicesType.getElementType()); - - SmallVector indicesToConcat; - indicesToConcat.reserve(indicesBatchingDims.size() + 1); - for (int64_t batchingDim : indicesBatchingDims) { - indicesToConcat.push_back( - rewriter.create(loc, iotaType, batchingDim)); - } - if (indexVectorDimOnLastDim) { - indicesToConcat.push_back( - rewriter.create(loc, iotaType, indices)); - } else { - indicesToConcat.push_back(indices); - } - return rewriter.create(loc, indicesToConcat, indexVectorDim); -} - -//===----------------------------------------------------------------------===// -// Patterns (non DRR) -//===----------------------------------------------------------------------===// - -// Converts a `GatherOp` with batching dims to a `GatherOp` without batching -// dims, such that each batching dim becomes a collapsed slice dim with a -// corresponding `IotaOp` concatenated to the start indices. 
-class GatherWithBatchingDimsExpander : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(GatherOp op, - PatternRewriter &rewriter) const override { - GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers(); - ArrayRef operandBatchingDims = dimNumbers.getOperandBatchingDims(); - if (operandBatchingDims.empty()) { - return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { - diag << "gather op has no batching dims"; - }); - } - - SmallVector newCollapsedSliceDims = mergeSortedDims( - operandBatchingDims, dimNumbers.getCollapsedSliceDims()); - SmallVector newStartIndexMap = - llvm::to_vector(llvm::concat( - operandBatchingDims, dimNumbers.getStartIndexMap())); - Value newIndices = createConcatIndices( - op.getStartIndices(), dimNumbers.getIndexVectorDim(), - dimNumbers.getStartIndicesBatchingDims(), rewriter); - rewriter.replaceOpWithNewOp( - op, op.getOperand(), newIndices, - GatherDimensionNumbersAttr::get( - op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims, - /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, - newStartIndexMap, dimNumbers.getIndexVectorDim()), - op.getSliceSizes(), /*indicesAreSorted=*/false); - - return success(); - } -}; - -// Converts a `ScatterOp` with batching dims to a `ScatterOp` without batching -// dims, such that each batching dim becomes an inserted window dim with a -// corresponding `IotaOp` concatenated to the scatter indices. -class ScatterWithBatchingDimsExpander : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ScatterOp op, - PatternRewriter &rewriter) const override { - ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers(); - ArrayRef inputBatchingDims = dimNumbers.getInputBatchingDims(); - if (inputBatchingDims.empty()) { - return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { - diag << "scatter op has no batching dims"; - }); - } - - SmallVector newInsertedWindowDims = - mergeSortedDims(inputBatchingDims, dimNumbers.getInsertedWindowDims()); - SmallVector newScatterDimsToOperandDims = - llvm::to_vector(llvm::concat( - inputBatchingDims, dimNumbers.getScatterDimsToOperandDims())); - Value newIndices = createConcatIndices( - op.getScatterIndices(), dimNumbers.getIndexVectorDim(), - dimNumbers.getScatterIndicesBatchingDims(), rewriter); - auto newScatterOp = rewriter.create( - op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices, - op.getUpdates(), - ScatterDimensionNumbersAttr::get( - op.getContext(), dimNumbers.getUpdateWindowDims(), - newInsertedWindowDims, - /*inputBatchingDims=*/{}, /*scatterIndicesBatchingDims=*/{}, - newScatterDimsToOperandDims, dimNumbers.getIndexVectorDim()), - /*indicesAreSorted=*/false, op.getUniqueIndices()); - - newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation()); - rewriter.replaceOp(op, newScatterOp.getResults()); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Pass -//===----------------------------------------------------------------------===// - -struct StablehloCompatibilityExpanderPass - : public impl::StablehloCompatibilityExpanderPassBase< - StablehloCompatibilityExpanderPass> { - StablehloCompatibilityExpanderPass() - : StablehloCompatibilityExpanderPassBase< - StablehloCompatibilityExpanderPass>() {} - StablehloCompatibilityExpanderPass( - const StablehloCompatibilityExpanderPassOptions &opts) - : StablehloCompatibilityExpanderPassBase< - 
StablehloCompatibilityExpanderPass>(opts) {} - - public: - LogicalResult initialize(MLIRContext *context) override { - auto targetVersion = validateTargetVersion(targetVersionOption); - - config.useTopDownTraversal = true; - RewritePatternSet patterns_(context); - populateStablehloCompatibilityExpanderPatterns(&patterns_, context, - targetVersion); - patterns = std::move(patterns_); - return success(); - } - - void runOnOperation() override { - auto func = getOperation(); - if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) { - func.emitError( - "Failed to converge StableHLOCompatibilityExpanderPass in ") - << config.maxIterations << " iterations"; - signalPassFailure(); - } - } - - private: - FrozenRewritePatternSet patterns; - GreedyRewriteConfig config; -}; - -#include "stablehlo/transforms/StablehloCompatibilityExpanderPatterns.h.inc" - -} // namespace - -void populateStablehloCompatibilityExpanderPatterns( - RewritePatternSet *patterns, MLIRContext *context, - vhlo::Version targetVersion) { - // StableHLO GatherOp/ScatterOp with batching dims is introduced in v1.1.0. - if (targetVersion < vhlo::Version(1, 1, 0)) { - patterns - ->add( - context); - } - // StableHLO TanOp is introduced in v1.4.0. - if (targetVersion < vhlo::Version(1, 4, 0)) { - patterns->add(context); - } -} - -} // namespace stablehlo -} // namespace mlir