From 00a994ca71c2b9440c90166089eb86f3d4ae51ce Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Tue, 17 Dec 2024 11:39:00 +0000 Subject: [PATCH] fix slice op sym infer with data input and symbol starts or ends --- .../infer_sym_slice_utils.h | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.h index c584b8306b854..80c532953facc 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.h @@ -236,19 +236,9 @@ inline ShapeOrData SliceRawInferSymbolicShape( // Currently, we DO NOT support the case that any element in `axes` `starts` // or `ends` is a Symbol. auto vec_int64 = details::VecExpr2Int64(starts); - PADDLE_ENFORCE_EQ( - vec_int64.has_value(), - true, - common::errors::InvalidArgument( - "for slice op, all the elements in `starts` must be int64_t")); std::vector starts_int = vec_int64.value(); vec_int64 = details::VecExpr2Int64(ends); - PADDLE_ENFORCE_EQ( - vec_int64.has_value(), - true, - common::errors::InvalidArgument( - "for slice op, all the elements in `ends` must be int64_t")); std::vector ends_int = vec_int64.value(); const int64_t start = @@ -274,10 +264,18 @@ inline ShapeOrData SliceRawInferSymbolicShape( return symbol::ShapeOrDataDimExprs{ symbol::TensorShapeOrDataDimExprs(shape, out_data)}; }; + bool starts_ends_all_int = + std::all_of(starts_expr.begin(), + starts_expr.end(), + [](const symbol::DimExpr &e) { return e.isa(); }) && + std::all_of(ends_expr.begin(), + ends_expr.end(), + [](const symbol::DimExpr &e) { return e.isa(); }); - const auto &out_shape = in_shapeordata.data().has_value() - ? GetDataDimExprs() - : GetShapeDimExprs(); + const auto &out_shape = + in_shapeordata.data().has_value() && starts_ends_all_int + ? GetDataDimExprs() + : GetShapeDimExprs(); if (out_shape.data().has_value() && out_shape.shape().empty()) { // 0D tensor const paddle::dialect::DenseTensorType &tensor_type = out.type().dyn_cast();