diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 1ea047cad1f8..aa560402877f 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -829,7 +829,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
           op, "lhs and rhs of convolution must either be both int or fp");
     }
 
-    if (inputZp && weightZp && !isa<Torch::NoneType>(bias.getType())) {
+    if (inputZp && !isa<Torch::NoneType>(bias.getType())) {
       auto biasDTy = cast<RankedTensorType>(bias.getType()).getElementType();
       if (!biasDTy.isInteger(32)) {
         return rewriter.notifyMatchFailure(
@@ -1123,7 +1123,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     // - grouped 1d-3d
     // - grouped 1d-3d (quantized)
     // - ungrouped 1d-3d
-    if (groupSize == 1 && !inputZp && !weightZp) {
+    if (groupSize == 1 && !inputZp) {
       switch (numSpatialDims) {
       case 1:
         conv = rewriter
@@ -1164,7 +1164,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       return success();
     }
 
-    if (groupSize == 1 && inputZp && weightZp) {
+    if (groupSize == 1 && inputZp) {
       // The quantized version uses a different channel ordering so we need to
       // permute the tensors in order to use the existing path. We should
       // eventually directly support this channel ordering.
@@ -1224,10 +1224,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       return success();
     }
 
-    if (inputZp || weightZp)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: quantized grouped convolutions");
-
     if (numSpatialDims != 2)
       return rewriter.notifyMatchFailure(
           op, "unimplemented: only 2D grouped convolution supported");
@@ -1238,7 +1234,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     auto weightShape = makeShapeTorchCompatible(
         cast<RankedTensorType>(weight.getType()).getShape());
     if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
-        weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) {
+        weightShape[0] % inShape[1] == 0 && weightShape[1] == 1 && !inputZp) {
       // Collapse weight shape
       SmallVector<ReassociationIndices> collapsedDims = {{0, 1}, {2}, {3}};
       SmallVector<int64_t> collapsedShape{
@@ -1325,13 +1321,22 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     auto expandOutputTensor = expandGroups(outputTensor, 1);
 
     // TODO: add 1D and 3D case
-    conv = rewriter
-               .create<linalg::Conv2DNgchwGfchwOp>(
-                   loc, expandOutputTensor.getResultType(),
-                   ValueRange{paddedInputExpanded, weightExpanded},
-                   expandOutputTensor.getResult(), stridesAttr, dilationAttr)
-               .getResult(0);
-
+    if (!inputZp) {
+      conv = rewriter
+                 .create<linalg::Conv2DNgchwGfchwOp>(
+                     loc, expandOutputTensor.getResultType(),
+                     ValueRange{paddedInputExpanded, weightExpanded},
+                     expandOutputTensor.getResult(), stridesAttr, dilationAttr)
+                 .getResult(0);
+    } else {
+      conv = rewriter
+                 .create<linalg::Conv2DNgchwGfchwQOp>(
+                     loc, expandOutputTensor.getResultType(),
+                     ValueRange{paddedInputExpanded, weightExpanded, inputZp,
+                                weightZp},
+                     expandOutputTensor.getResult(), stridesAttr, dilationAttr)
+                 .getResult(0);
+    }
     conv = rewriter.create<tensor::CollapseShapeOp>(
         loc, outputTensor.getType(), conv,
         expandOutputTensor.getReassociationIndices());
diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
index 38bc4d275bf1..5925dd07e185 100644
--- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
+++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -378,7 +378,7 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
         QuantizeOperandsPastCommutingOps,
         QuantizeOperandsPastCommutingOps,
        QuantizeOperandsPastCommutingOps,
-        QuantizeOperandsPastCommutingOps,
+        QuantizeOperandsPastCommutingOps, QuantizeAccumulator,
         QuantizeAccumulator, QuantizeResultLikeOperand, QuantizeBias>(
             context);
 
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 643843821c13..eee37d6fcce2 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -277,6 +277,7 @@
     "QuantizedReluInt8_basic",
     "QuantizedReluUint8_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "ConvTranspose2DQInt8_basic",
     # Dynamo not supporting conv_tbc
     "ConvTbcModule_basic",
@@ -373,6 +374,7 @@
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
@@ -543,6 +545,7 @@
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
     "ConvolutionBackwardModule2DPadded_basic",
@@ -2147,6 +2150,7 @@
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "ConvTranspose2DQInt8_basic",
 }
 
@@ -2298,6 +2302,7 @@
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideModule_basic",
     "Conv2dWithPaddingModule_basic",
     "Conv3dModule_basic",
@@ -2851,6 +2856,7 @@
     "ContainsIntList_True",
     "Conv1dModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
     "Conv3dModule_basic",
@@ -3637,6 +3643,7 @@
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
index b157f91efc11..af8bea091d08 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
@@ -1157,7 +1157,8 @@ def ConvTbcModule_basic(module, tu: TestUtils):
 
 
 class Conv2dQInt8Module(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, groups=1):
+        self.groups = groups
         super().__init__()
 
     @export
@@ -1186,7 +1187,7 @@ def forward(self, inputVec, weight, bias):
             stride=[1, 1],
             padding=[0, 0],
             dilation=[1, 1],
-            groups=1,
+            groups=self.groups,
         )
 
 
@@ -1198,13 +1199,12 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
     module.forward(inputVec, weight, bias)
 
 
-N = 10
-Cin = 5
-Cout = 7
-Hin = 10
-Win = 8
-Hker = 3
-Wker = 2
+@register_test_case(module_factory=lambda: Conv2dQInt8Module(groups=2))
+def Conv2dQInt8Module_grouped(module, tu: TestUtils):
+    inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
+    weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
+    bias = torch.rand(6)
+    module.forward(inputVec, weight, bias)
 
 
 class ConvTranspose2DQInt8Module(torch.nn.Module):
@@ -1244,6 +1244,13 @@ def forward(self, input, weight, bias):
 
 @register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
 def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
+    N = 10
+    Cin = 5
+    Cout = 7
+    Hin = 10
+    Win = 8
+    Hker = 3
+    Wker = 2
     module.forward(
         tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
         tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),