Skip to content

Commit

Permalink
[TIR] Support Vector Reinterpret Calls (apache#16673)
Browse files Browse the repository at this point in the history
This PR adds support for vector reinterpret calls in TIR.
  • Loading branch information
Hzfengsy authored and thaisacs committed Apr 3, 2024
1 parent 43ab73f commit fb506fd
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
14 changes: 13 additions & 1 deletion src/tir/transforms/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
#include <tvm/tir/transform.h>

#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace tvm {
Expand Down Expand Up @@ -319,6 +318,17 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
}
}
// Reinterpret expr
// Rewrites a builtin::reinterpret() call after visiting its single operand.
// If visiting leaves the operand untouched, the original call node is returned
// as-is; otherwise a new call is built whose result dtype carries the same
// lane count as the mutated operand.
PrimExpr MutateReinterpretExpr_(const CallNode* op) {
  ICHECK(op->op.same_as(builtin::reinterpret()));
  PrimExpr mutated_arg = this->VisitExpr(op->args[0]);
  if (!mutated_arg.same_as(op->args[0])) {
    // The operand changed: widen the result type to the operand's lane count.
    const int operand_lanes = mutated_arg.dtype().lanes();
    return Call(op->dtype.with_lanes(operand_lanes), op->op, {mutated_arg});
  }
  // Nothing changed underneath, so reuse the existing node.
  return GetRef<PrimExpr>(op);
}
// Call
PrimExpr VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::if_then_else())) {
Expand All @@ -337,6 +347,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
Array<PrimExpr> mutated_value = MutateArray(value, &lane);
Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2], mutated_value[0]};
return Call(op->dtype.with_lanes(lane), op->op, new_args);
} else if (op->op.same_as(builtin::reinterpret())) {
return MutateReinterpretExpr_(op);
}
auto optional_op = op->op.as<Op>();
bool vectorizable = optional_op && op_vectorizable_.get(optional_op.value(), false);
Expand Down
31 changes: 22 additions & 9 deletions tests/python/tir-transform/test_tir_transform_vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import te
from tvm.script import ir as I
from tvm.script import tir as T


def test_vectorize_loop():
Expand Down Expand Up @@ -226,13 +229,23 @@ def test_vectorize_dtype_mismatch():
tvm.lower(s, [A], "llvm", simple_mode=True)


def test_vectorize_with_reinterpret():
    """Check VectorizeLoop on a vectorized loop whose body is a scalar reinterpret.

    The input module reinterprets one int32 element per iteration as float32;
    the expected module performs a single "float32x16" reinterpret over A[0:16].
    """

    @I.ir_module
    class Before:
        @T.prim_func
        def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
            for i in T.vectorized(0, 16):
                B[i] = T.reinterpret("float32", A[i])

    @I.ir_module
    class After:
        @T.prim_func
        def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
            B[0:16] = T.reinterpret("float32x16", A[0:16])

    # Build the pass first, then apply it, so each step is named.
    vectorize_pass = tvm.tir.transform.VectorizeLoop()
    transformed = vectorize_pass(Before)
    tvm.ir.assert_structural_equal(transformed, After)


# Script entry point: calls the listed test functions one by one, then invokes
# tvm.testing.main().
# NOTE(review): this listing comes from a rendered diff and appears to merge the
# removed per-test calls with the added tvm.testing.main() line — confirm which
# lines exist in the committed file before relying on it.
if __name__ == "__main__":
test_vectorize_vector()
test_vectorize_with_if()
test_vectorize_loop()
test_vectorize_if_then_else()
test_vectorize_with_le_cond()
test_vectorize_with_ge_cond()
test_vectorize_let()
test_vectorize_while_fail()
test_vectorize_dtype_mismatch()
tvm.testing.main()

0 comments on commit fb506fd

Please sign in to comment.