Skip to content

Commit

Permalink
[TorchToLinalg] add support for quantized group conv (#3341)
Browse files Browse the repository at this point in the history
This addresses 7 of the model failures I'm seeing in the test suite. See
[Shark-Turbine issue
#566](nod-ai/SHARK-ModelDev#566).

Need the op ```linalg.conv_2d_ngchw_gfchw_q``` to be added upstream
before merging this. See [llvm-project PR #92136
](llvm/llvm-project#92136).

A small additional expansion to operand quantization is included in this
patch to address a model failure that occurs when unblocking the
quantized group convolutions in one of these onnx models.
  • Loading branch information
zjgarvey authored Jun 3, 2024
1 parent 6382dbb commit 8995c90
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 25 deletions.
35 changes: 20 additions & 15 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
op, "lhs and rhs of convolution must either be both int or fp");
}

if (inputZp && weightZp && !isa<Torch::NoneType>(bias.getType())) {
if (inputZp && !isa<Torch::NoneType>(bias.getType())) {
auto biasDTy = cast<RankedTensorType>(bias.getType()).getElementType();
if (!biasDTy.isInteger(32)) {
return rewriter.notifyMatchFailure(
Expand Down Expand Up @@ -1123,7 +1123,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
// - grouped 1d-3d
// - grouped 1d-3d (quantized)
// - ungrouped 1d-3d
if (groupSize == 1 && !inputZp && !weightZp) {
if (groupSize == 1 && !inputZp) {
switch (numSpatialDims) {
case 1:
conv = rewriter
Expand Down Expand Up @@ -1164,7 +1164,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return success();
}

if (groupSize == 1 && inputZp && weightZp) {
if (groupSize == 1 && inputZp) {
// The quantized version uses a different channel ordering so we need to
// permute the tensors in order to use the existing path. We should
// eventually directly support this channel ordering.
Expand Down Expand Up @@ -1224,10 +1224,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return success();
}

if (inputZp || weightZp)
return rewriter.notifyMatchFailure(
op, "unimplemented: quantized grouped convolutions");

if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");
Expand All @@ -1238,7 +1234,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
auto weightShape = makeShapeTorchCompatible(
cast<RankedTensorType>(weight.getType()).getShape());
if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) {
weightShape[0] % inShape[1] == 0 && weightShape[1] == 1 && !inputZp) {
// Collapse weight shape
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
SmallVector<int64_t> collapsedShape{
Expand Down Expand Up @@ -1325,13 +1321,22 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
auto expandOutputTensor = expandGroups(outputTensor, 1);

// TODO: add 1D and 3D case
conv = rewriter
.create<linalg::Conv2DNgchwGfchwOp>(
loc, expandOutputTensor.getResultType(),
ValueRange{paddedInputExpanded, weightExpanded},
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
.getResult(0);

if (!inputZp) {
conv = rewriter
.create<linalg::Conv2DNgchwGfchwOp>(
loc, expandOutputTensor.getResultType(),
ValueRange{paddedInputExpanded, weightExpanded},
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
.getResult(0);
} else {
conv = rewriter
.create<linalg::Conv2DNgchwGfchwQOp>(
loc, expandOutputTensor.getResultType(),
ValueRange{paddedInputExpanded, weightExpanded, inputZp,
weightZp},
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
.getResult(0);
}
conv = rewriter.create<tensor::CollapseShapeOp>(
loc, outputTensor.getType(), conv,
expandOutputTensor.getReassociationIndices());
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 5>,
QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
QuantizeOperandsPastCommutingOps<AtenMmOp, 2>,
QuantizeOperandsPastCommutingOps<AtenMmOp, 4>,
QuantizeAccumulator<AtenMmOp>, QuantizeAccumulator<AtenMatmulOp>,
QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
context);
Expand Down
7 changes: 7 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@
"QuantizedReluInt8_basic",
"QuantizedReluUint8_basic",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_grouped",
"ConvTranspose2DQInt8_basic",
# Dynamo not supporting conv_tbc
"ConvTbcModule_basic",
Expand Down Expand Up @@ -373,6 +374,7 @@
"ContainsIntList_False",
"ContainsIntList_True",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
Expand Down Expand Up @@ -543,6 +545,7 @@
"ContainsIntList_False",
"ContainsIntList_True",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_grouped",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"ConvolutionBackwardModule2DPadded_basic",
Expand Down Expand Up @@ -2147,6 +2150,7 @@
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_grouped",
"ConvTranspose2DQInt8_basic",
}

Expand Down Expand Up @@ -2298,6 +2302,7 @@
"Conv2dModule_basic",
"Conv2dNoPaddingModule_basic",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_grouped",
"Conv2dWithPaddingDilationStrideModule_basic",
"Conv2dWithPaddingModule_basic",
"Conv3dModule_basic",
Expand Down Expand Up @@ -2851,6 +2856,7 @@
"ContainsIntList_True",
"Conv1dModule_basic",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"Conv3dModule_basic",
Expand Down Expand Up @@ -3637,6 +3643,7 @@
"Conv2dModule_basic",
"Conv2dNoPaddingModule_basic",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_grouped",
"Conv2dWithPaddingDilationStrideModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
Expand Down
25 changes: 16 additions & 9 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,7 +1157,8 @@ def ConvTbcModule_basic(module, tu: TestUtils):


class Conv2dQInt8Module(torch.nn.Module):
def __init__(self):
def __init__(self, groups=1):
self.groups = groups
super().__init__()

@export
Expand Down Expand Up @@ -1186,7 +1187,7 @@ def forward(self, inputVec, weight, bias):
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
groups=1,
groups=self.groups,
)


Expand All @@ -1198,13 +1199,12 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
module.forward(inputVec, weight, bias)


N = 10
Cin = 5
Cout = 7
Hin = 10
Win = 8
Hker = 3
Wker = 2
@register_test_case(module_factory=lambda: Conv2dQInt8Module(groups=2))
def Conv2dQInt8Module_grouped(module, tu: TestUtils):
inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
bias = torch.rand(6)
module.forward(inputVec, weight, bias)


class ConvTranspose2DQInt8Module(torch.nn.Module):
Expand Down Expand Up @@ -1244,6 +1244,13 @@ def forward(self, input, weight, bias):

@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
N = 10
Cin = 5
Cout = 7
Hin = 10
Win = 8
Hker = 3
Wker = 2
module.forward(
tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
Expand Down

0 comments on commit 8995c90

Please sign in to comment.