
onnx.Unique #647

Closed
Tracked by #215
Peefy opened this issue Apr 23, 2024 · 8 comments

Comments

@Peefy

Peefy commented Apr 23, 2024

Tracking Issue: #215

@Peefy
Author

Peefy commented Apr 23, 2024

Hello @renxida, could you please assign this issue to me? Also, I have some questions for you.

@renxida
Contributor

renxida commented Apr 24, 2024

Should we add a new AtenUniqueOp, or use AtenUniqueConsecutiveOp to compose the torch.unique and onnx.Unique ops?

That sounds like a plan! Feel free to reach me on Discord (@xida_ren) if you have questions or just want to chat. Also ping me on Discord if you need, e.g., a code review or a CI approval.
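A quick, torch-free illustration of why this composition works: torch.unique_consecutive only collapses *adjacent* equal elements, so sorting first turns it into a full unique. The helper below is a hypothetical pure-Python stand-in for the values output of torch.unique_consecutive, not torch-mlir code:

```python
def unique_consecutive(xs):
    """Collapse runs of adjacent equal elements,
    mimicking the values output of torch.unique_consecutive."""
    out = []
    for x in xs:
        if not out or out[-1] != x:
            out.append(x)
    return out

data = [2, 1, 1, 3, 1]
print(unique_consecutive(data))          # only adjacent duplicates merge: [2, 1, 3, 1]
print(unique_consecutive(sorted(data)))  # sort first -> full unique: [1, 2, 3]
```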

@Peefy
Author

Peefy commented Apr 24, 2024

Thank you for your kind reply. ❤️

@vivekkhandelwal1
Contributor

Hi @Peefy, are you still working on this op?

@Peefy
Author

Peefy commented May 14, 2024

Hi @Peefy, are you still working on this op?

Hello @vivekkhandelwal1.

Yes, here's my code. I have almost finished the case where the onnx.Unique sorted attribute is true.

But I ran into some other difficulties. When the sorted attribute of onnx.Unique is false, there seems to be no unsorted deduplication function in torch. So I am thinking about how to compose one from torch.unique_consecutive. Can you give me some guidance or tips?
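For reference, here is a torch-free, pure-Python model of what ONNX Unique with sorted=1 returns on a 1-D input; the function name is mine, and the expected values follow the example in the ONNX operator documentation:

```python
def onnx_unique_sorted(x):
    """Model ONNX Unique (sorted=1) on a 1-D list: returns
    (unique values, first-occurrence indices, inverse indices, counts)."""
    uniq = sorted(set(x))
    pos = {v: i for i, v in enumerate(uniq)}
    indices = [x.index(v) for v in uniq]   # first occurrence of each unique value
    inverse = [pos[v] for v in x]          # slot of each input element in uniq
    counts = [x.count(v) for v in uniq]
    return uniq, indices, inverse, counts

y, idx, inv, cnt = onnx_unique_sorted([2, 1, 1, 3, 4, 3])
# y = [1, 2, 3, 4], idx = [1, 0, 3, 4], inv = [1, 0, 0, 2, 3, 2], cnt = [2, 1, 2, 1]
```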

  patterns.onOp(
      "Unique", 11,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        // Here we use torch.unique_consecutive and other operators to compose the onnx.Unique
        // ```python
        // def onnx_unique(x, sorted=True, dim=0):
        //     unique, inverse, counts = torch_unique(x, dim=dim)
        //     _, ind_sorted = torch.sort(inverse, stable=True)
        //     cum_sum = counts.cumsum(0)
        //     cum_sum = torch.cat((torch.tensor([0]), cum_sum[:-1]))
        //     indices = ind_sorted[cum_sum]
        //     return unique, indices, inverse, counts
        //
        // def torch_unique(tensor, dim=0):
        //     sorted_tensor, sorted_indices = torch.sort(tensor, dim=dim)
        //     return torch.unique_consecutive(sorted_tensor, dim=dim,
        //         return_inverse=True, return_counts=True)
        // ```
        // Note that the situation where sorted is false has not been handled yet.
        //
        // Reference: onnx.Unique: https://onnx.ai/onnx/operators/onnx__Unique.html
        // Reference: torch.unique: https://pytorch.org/docs/stable/generated/torch.unique.html
        // Reference: torch.unique_consecutive: https://pytorch.org/docs/stable/generated/torch.unique_consecutive.html
        Torch::ValueTensorType outputType, indicesType, inverseIndicesType, countsType;
        Value input;
        // Note the axis can be negative.
        // Accepted range is [-r, r-1] where r = rank(input).
        int64_t axis;
        // The default value of sorted attribute is 1
        bool sorted;
        if (binder.tensorOperand(input) ||
            binder.s64BoolAttr(sorted, "sorted", true) ||
            binder.s64IntegerAttr(axis, "axis", 0) ||
            binder.tensorResultTypeAtIndex(outputType, 0) ||
            binder.tensorResultTypeAtIndex(indicesType, 1) ||
            binder.tensorResultTypeAtIndex(inverseIndicesType, 2) ||
            binder.tensorResultTypeAtIndex(countsType, 3))
          return failure();
        std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
        if (!maybeRank)
          return rewriter.notifyMatchFailure(binder.op, "Unimplemented: unranked tensor");
        if (!sorted)
          return rewriter.notifyMatchFailure(binder.op, "Unimplemented: torch.unique is not yet supported for situations where sorted is false");
        unsigned rank = *maybeRank;
        axis = Torch::toPositiveDim(axis, rank);
        auto loc = binder.getLoc();
        auto torchUnique = [&](Value tensor, Value sorted, Value dim) -> std::tuple<Value, Value, Value> {
          Value cstFalse =
            rewriter.create<Torch::ConstantBoolOp>(loc, false);
          Value zero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
          auto sortedResult = rewriter.create<Torch::AtenSortOp>(loc, tensor.getType(), indicesType, tensor, zero, cstFalse);
          // Index 0 is the sorted tensor and index 1 is the indices
          Value sortedTensor = sortedResult->getResult(0);
          Value cstReturnInverse =
            rewriter.create<Torch::ConstantBoolOp>(loc, true);
          Value cstReturnCounts =
            rewriter.create<Torch::ConstantBoolOp>(loc, true);
          auto uniqueConsecutiveResult = rewriter.create<Torch::AtenUniqueConsecutiveOp>(loc, outputType, indicesType, countsType, sortedTensor, cstReturnInverse, cstReturnCounts, dim);
          Value uniqueValues = uniqueConsecutiveResult->getResult(0);
          Value inverseIndices = uniqueConsecutiveResult->getResult(1);
          Value uniqueCounts = uniqueConsecutiveResult->getResult(2);
          return std::make_tuple(uniqueValues, inverseIndices, uniqueCounts);
        };

        auto onnxUnique = [&](Value unique, Value inverse, Value counts) -> ValueRange {
          auto cstFalse =
            rewriter.create<Torch::ConstantBoolOp>(loc, false);
          auto zero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
          auto one = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(1));
          auto intMinus1 = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1));
          auto sortedResult = rewriter.create<Torch::AtenSortOp>(loc, inverse.getType(), indicesType, inverse, zero, cstFalse);
          auto ind_sorted = sortedResult->getResult(1);

          // %int1 = torch.constant.int 1
          // %size = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
          // %none = torch.constant.none
          // %1 = torch.aten.zeros %size, %none, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1],f32>
          // %int0 = torch.constant.int 0
          // %none_0 = torch.constant.none
          // %2 = torch.aten.cumsum %counts, %int0, %none_0 : !torch.vtensor<[4],f32>, !torch.int, !torch.none -> !torch.vtensor<[4],f32>
          // %3 = torch.prim.ListConstruct %1, %2 : (!torch.vtensor<[1],f32>, !torch.vtensor<[4],f32>) -> !torch.list<vtensor>
          // %int0_1 = torch.constant.int 0
          // %4 = torch.aten.cat %3, %int0_1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[5],f32>
          // %int0_2 = torch.constant.int 0
          // %int0_3 = torch.constant.int 0
          // %int-1 = torch.constant.int -1
          // %int1_4 = torch.constant.int 1
          // %5 = torch.aten.slice.Tensor %4, %int0_2, %int0_3, %int-1, %int1_4 : !torch.vtensor<[5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4],f32>

          auto none = rewriter.create<Torch::ConstantNoneOp>(loc);
          auto cumSumResult = rewriter.create<Torch::AtenCumsumOp>(loc, counts.getType(), counts, zero, none);
          auto cumSumSliceResult = rewriter.create<Torch::AtenSliceTensorOp>(loc, indicesType, cumSumResult, zero, zero, intMinus1, one);

          auto size = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
          SmallVector<Value> sizeList;
          sizeList.push_back(size);
          auto sizeValue = rewriter.create<Torch::PrimListConstructOp>(
              loc,
              Torch::ListType::get(
                  Torch::IntType::get(binder.op->getContext())),
              sizeList);
          auto tensorZerosResult = rewriter.create<Torch::AtenZerosOp>(loc, counts.getType(), sizeValue, none, none, none, none);
          SmallVector<Value> valueList;
          valueList.push_back(tensorZerosResult);
          valueList.push_back(cumSumSliceResult);
          Type listElemType =
              tensorZerosResult
                  .getType()
                  .cast<Torch::BaseTensorType>()
                  .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
                                        /*optionalDtype=*/nullptr);
          Type listType = Torch::ListType::get(listElemType);
          Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
              loc, listType, valueList);

          auto catOpResult = rewriter.create<Torch::AtenCatOp>(loc, indicesType, tensorList, zero);

          auto select = [&](Value v, Value k) -> Value {
            auto ty = v.getType().cast<Torch::ValueTensorType>();
            auto sel = rewriter.create<Torch::AtenIndexSelectOp>(
                loc,
                Torch::ValueTensorType::get(ty.getContext(), ArrayRef<int64_t>{1},
                                            ty.getOptionalDtype()),
                v, zero, k);
            Value item = rewriter.create<Torch::AtenItemOp>(
                loc, rewriter.getType<Torch::IntType>(), sel);
            return item;
          };
          auto indices = select(ind_sorted, catOpResult);

          return ValueRange({unique, indices, inverse, counts});
        };

        rewriter.replaceOp(binder.op, std::apply(onnxUnique, torchUnique(
          input,
          rewriter.create<Torch::ConstantBoolOp>(loc, sorted),
          rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(axis)))));
        return success();
      });
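The sorted-is-false branch rejected above corresponds to first-occurrence ordering in the ONNX spec. As a pure-Python sketch (not torch-mlir code; the function name is mine), the unsorted variant keeps values in the order they are first seen:

```python
def onnx_unique_unsorted(x):
    """Model ONNX Unique (sorted=0) on a 1-D list:
    unique values appear in order of first occurrence."""
    uniq, indices, first = [], [], {}
    for i, v in enumerate(x):
        if v not in first:
            first[v] = len(uniq)  # slot this value will occupy in uniq
            uniq.append(v)
            indices.append(i)     # index of its first occurrence in x
    inverse = [first[v] for v in x]
    counts = [x.count(v) for v in uniq]
    return uniq, indices, inverse, counts

y, idx, inv, cnt = onnx_unique_unsorted([2, 1, 1, 3, 4, 3])
# y = [2, 1, 3, 4], idx = [0, 1, 3, 4], inv = [0, 1, 1, 2, 3, 2], cnt = [1, 2, 2, 1]
```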

@vivekkhandelwal1
Contributor

Hi @Peefy, for now you can add limited support for the op and extend it later. Also, it would be better if you create a WIP PR for this to be reviewed.

@Peefy
Author

Peefy commented Jul 2, 2024

Hello @vivekkhandelwal1, sorry, I may not have much time to complete this recently. Please un-assign me.

@vivekkhandelwal1
Contributor

Assigning this to @vinayakdsci
