From f893d2440d2a4efc72d88a2e4bab885aebc64d68 Mon Sep 17 00:00:00 2001
From: nihui
Date: Wed, 17 May 2023 10:44:26 +0800
Subject: [PATCH] innerproduct allow 1 height gemm (#4730)

---
 src/layer/arm/innerproduct_arm.cpp         |  6 +++---
 src/layer/arm/innerproduct_arm_asimdhp.cpp |  2 +-
 src/layer/arm/innerproduct_arm_vfpv4.cpp   |  2 +-
 src/layer/innerproduct.cpp                 |  4 ++--
 .../loongarch/innerproduct_loongarch.cpp   |  6 +++---
 src/layer/mips/innerproduct_mips.cpp       |  6 +++---
 src/layer/riscv/innerproduct_riscv.cpp     |  6 +++---
 src/layer/vulkan/innerproduct_vulkan.cpp   |  6 +++---
 src/layer/x86/innerproduct_x86.cpp         |  6 +++---
 tests/test_innerproduct.cpp                |  2 ++
 tools/pnnx/tests/ncnn/test_nn_Linear.py    | 21 ++++++++++++-------
 11 files changed, 38 insertions(+), 29 deletions(-)

diff --git a/src/layer/arm/innerproduct_arm.cpp b/src/layer/arm/innerproduct_arm.cpp
index ea3d874d8f31..6a63b2093df6 100644
--- a/src/layer/arm/innerproduct_arm.cpp
+++ b/src/layer/arm/innerproduct_arm.cpp
@@ -177,7 +177,7 @@ int InnerProduct_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Optio
 
     const int num_input = weight_data_size / num_output;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
     {
         // gemm
         int h = bottom_blob.h;
@@ -889,7 +889,7 @@ int InnerProduct_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const
 {
     const int num_input = weight_data_size / num_output;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
     {
         // gemm
         int h = bottom_blob.h;
@@ -1295,7 +1295,7 @@ int InnerProduct_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, co
         quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q);
     }
 
-    if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input && bottom_blob_int8.h * bottom_blob_int8.elempack > 1)
+    if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input)
     {
         // gemm
         Mat bottom_blob_int8_unpacked;
diff --git a/src/layer/arm/innerproduct_arm_asimdhp.cpp b/src/layer/arm/innerproduct_arm_asimdhp.cpp
index 90fb39aa8160..1dc5c990b4ea 100644
--- a/src/layer/arm/innerproduct_arm_asimdhp.cpp
+++ b/src/layer/arm/innerproduct_arm_asimdhp.cpp
@@ -53,7 +53,7 @@ int InnerProduct_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, cons
 {
     const int num_input = weight_data_size / num_output;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
     {
         // gemm
         int h = bottom_blob.h;
diff --git a/src/layer/arm/innerproduct_arm_vfpv4.cpp b/src/layer/arm/innerproduct_arm_vfpv4.cpp
index 6aaad4781a12..435fb883e508 100644
--- a/src/layer/arm/innerproduct_arm_vfpv4.cpp
+++ b/src/layer/arm/innerproduct_arm_vfpv4.cpp
@@ -53,7 +53,7 @@ int InnerProduct_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const
 {
     const int num_input = weight_data_size / num_output;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
     {
         // gemm
         int h = bottom_blob.h;
diff --git a/src/layer/innerproduct.cpp b/src/layer/innerproduct.cpp
index 26b28f352f8b..4cc22981c344 100644
--- a/src/layer/innerproduct.cpp
+++ b/src/layer/innerproduct.cpp
@@ -115,7 +115,7 @@ int InnerProduct::forward(const Mat& bottom_blob, Mat& top_blob, const Option& o
     size_t elemsize = bottom_blob.elemsize;
     int size = w * h;
 
-    if (bottom_blob.dims == 2 && w == num_input && h > 1)
+    if (bottom_blob.dims == 2 && w == num_input)
     {
         // gemm
         top_blob.create(num_output, h, elemsize, opt.blob_allocator);
@@ -201,7 +201,7 @@ int InnerProduct::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Opti
         quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_g);
     }
 
-    if (bottom_blob.dims == 2 && w == num_input && h > 1)
+    if (bottom_blob.dims == 2 && w == num_input)
     {
         // gemm
         top_blob.create(num_output, h, 4u, opt.blob_allocator);
diff --git a/src/layer/loongarch/innerproduct_loongarch.cpp b/src/layer/loongarch/innerproduct_loongarch.cpp
index 3dd6ff35e232..34e908fc11ad 100644
--- a/src/layer/loongarch/innerproduct_loongarch.cpp
+++ b/src/layer/loongarch/innerproduct_loongarch.cpp
@@ -137,7 +137,7 @@ int InnerProduct_loongarch::forward(const Mat& bottom_blob, Mat& top_blob, const
 
     const int num_input = weight_data_size / num_output;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
     {
         // gemm
         int h = bottom_blob.h;
@@ -667,7 +667,7 @@ int InnerProduct_loongarch::forward_fp16s(const Mat& bottom_blob, Mat& top_blob,
 {
     const int num_input = weight_data_size / num_output;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
     {
         // gemm
         int h = bottom_blob.h;
@@ -1168,7 +1168,7 @@ int InnerProduct_loongarch::forward_int8_loongarch(const Mat& bottom_blob, Mat&
         quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q);
     }
 
-    if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input && bottom_blob_int8.h * bottom_blob_int8.elempack > 1)
+    if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input)
     {
         // gemm
         Mat bottom_blob_int8_unpacked;
diff --git a/src/layer/mips/innerproduct_mips.cpp b/src/layer/mips/innerproduct_mips.cpp
index 6a4181686b5d..b064a20e522d 100644
--- a/src/layer/mips/innerproduct_mips.cpp
+++ b/src/layer/mips/innerproduct_mips.cpp
@@ -137,7 +137,7 @@ int InnerProduct_mips::forward(const Mat& bottom_blob, Mat& top_blob, const Opti
 
     const int num_input = weight_data_size / num_output;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
     {
         // gemm
         int h = bottom_blob.h;
@@ -667,7 +667,7 @@ int InnerProduct_mips::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, cons
 {
     const int num_input = weight_data_size / num_output;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
     {
         // gemm
         int h = bottom_blob.h;
@@ -1168,7 +1168,7 @@ int InnerProduct_mips::forward_int8_mips(const Mat& bottom_blob, Mat& top_blob,
         quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q);
     }
 
-    if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input && bottom_blob_int8.h * bottom_blob_int8.elempack > 1)
+    if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input)
     {
         // gemm
         Mat bottom_blob_int8_unpacked;
diff --git a/src/layer/riscv/innerproduct_riscv.cpp b/src/layer/riscv/innerproduct_riscv.cpp
index 30dd74287776..ac7b3169708b 100644
--- a/src/layer/riscv/innerproduct_riscv.cpp
+++ b/src/layer/riscv/innerproduct_riscv.cpp
@@ -173,7 +173,7 @@ int InnerProduct_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Opt
 
     const int num_input = weight_data_size / num_output;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
     {
         // gemm
         int h = bottom_blob.h;
@@ -577,7 +577,7 @@ int InnerProduct_riscv::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, con
 
     const int num_input = weight_data_size / num_output;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
     {
         // gemm
         int h = bottom_blob.h;
@@ -839,7 +839,7 @@ int InnerProduct_riscv::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, co
 
     const int num_input = weight_data_size / num_output;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
     {
         // gemm
         int h = bottom_blob.h;
diff --git a/src/layer/vulkan/innerproduct_vulkan.cpp b/src/layer/vulkan/innerproduct_vulkan.cpp
index 9cb57e5d8e3b..06bf7b569430 100644
--- a/src/layer/vulkan/innerproduct_vulkan.cpp
+++ b/src/layer/vulkan/innerproduct_vulkan.cpp
@@ -79,7 +79,7 @@ int InnerProduct_vulkan::create_pipeline(const Option& _opt)
         convert_packing(bias_data, bias_data_packed, out_elempack, opt);
     }
 
-    if (shape.dims == 2 && shape.w == num_input && shape.h > 1)
+    if (shape.dims == 2 && shape.w == num_input)
     {
         // gemm
         int elempack = opt.use_shader_pack8 && shape.h % 8 == 0 ? 8 : shape.h % 4 == 0 ? 4 : 1;
@@ -427,7 +427,7 @@ int InnerProduct_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCo
     int in_elempack = opt.use_shader_pack8 && num_input % 8 == 0 ? 8 : num_input % 4 == 0 ? 4 : 1;
     int out_elempack = opt.use_shader_pack8 && num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
     {
         // gemm
         int h = bottom_blob.h;
@@ -587,7 +587,7 @@ int InnerProduct_vulkan::forward(const VkImageMat& bottom_blob, VkImageMat& top_
     int in_elempack = opt.use_shader_pack8 && num_input % 8 == 0 ? 8 : num_input % 4 == 0 ? 4 : 1;
     int out_elempack = opt.use_shader_pack8 && num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
    {
         // gemm
         int h = bottom_blob.h;
diff --git a/src/layer/x86/innerproduct_x86.cpp b/src/layer/x86/innerproduct_x86.cpp
index 6dcc484e2991..67bf0cca5480 100644
--- a/src/layer/x86/innerproduct_x86.cpp
+++ b/src/layer/x86/innerproduct_x86.cpp
@@ -118,7 +118,7 @@ int InnerProduct_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Optio
 
     const int num_input = weight_data_size / num_output;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
     {
         // gemm
         int h = bottom_blob.h;
@@ -190,7 +190,7 @@ int InnerProduct_x86::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const
 {
     const int num_input = weight_data_size / num_output;
 
-    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
     {
         // gemm
         int h = bottom_blob.h;
@@ -309,7 +309,7 @@ int InnerProduct_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, co
         quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q);
     }
 
-    if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input && bottom_blob_int8.h * bottom_blob_int8.elempack > 1)
+    if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input)
     {
         // gemm
         Mat bottom_blob_int8_unpacked;
diff --git a/tests/test_innerproduct.cpp b/tests/test_innerproduct.cpp
index a46328e44e32..a9ec260db68f 100644
--- a/tests/test_innerproduct.cpp
+++ b/tests/test_innerproduct.cpp
@@ -179,6 +179,8 @@ static int test_innerproduct_gemm(const ncnn::Mat& a, int outch, int bias)
 static int test_innerproduct_4()
 {
     return 0
+           || test_innerproduct_gemm(RandomMat(1, 1), 1, 1)
+           || test_innerproduct_gemm(RandomMat(48, 1), 11, 1)
            || test_innerproduct_gemm(RandomMat(1, 5), 1, 1)
            || test_innerproduct_gemm(RandomMat(3, 2), 2, 0)
            || test_innerproduct_gemm(RandomMat(9, 8), 7, 1)
diff --git a/tools/pnnx/tests/ncnn/test_nn_Linear.py b/tools/pnnx/tests/ncnn/test_nn_Linear.py
index b8106dccaa14..ea79c50db180 100644
--- a/tools/pnnx/tests/ncnn/test_nn_Linear.py
+++ b/tools/pnnx/tests/ncnn/test_nn_Linear.py
@@ -23,7 +23,7 @@ def __init__(self):
         self.linear_0 = nn.Linear(in_features=64, out_features=16, bias=False)
         self.linear_1 = nn.Linear(in_features=16, out_features=3, bias=True)
 
-    def forward(self, x, y, z):
+    def forward(self, x, y, z, w):
         x = self.linear_0(x)
         x = self.linear_1(x)
 
@@ -33,7 +33,10 @@ def forward(self, x, y, z):
         z = self.linear_0(z)
         z = self.linear_1(z)
         z = F.relu(z)
-        return x, y, z
+
+        w = self.linear_0(w)
+        w = self.linear_1(w)
+        return x, y, z, w
 
 def test():
     net = Model().half().float()
@@ -43,22 +46,26 @@ def test():
     x = torch.rand(64)
     y = torch.rand(12, 64)
     z = torch.rand(1, 3, 12, 64)
+    w = torch.rand(1, 64)
 
-    a0, a1, a2 = net(x, y, z)
+    a = net(x, y, z, w)
 
     # export torchscript
-    mod = torch.jit.trace(net, (x, y, z))
+    mod = torch.jit.trace(net, (x, y, z, w))
     mod.save("test_nn_Linear.pt")
 
     # torchscript to pnnx
     import os
-    os.system("../../src/pnnx test_nn_Linear.pt inputshape=[64],[12,64],[1,3,12,64]")
+    os.system("../../src/pnnx test_nn_Linear.pt inputshape=[64],[12,64],[1,3,12,64],[1,64]")
 
     # ncnn inference
     import test_nn_Linear_ncnn
-    b0, b1, b2 = test_nn_Linear_ncnn.test_inference()
+    b = test_nn_Linear_ncnn.test_inference()
 
-    return torch.allclose(a0, b0, 1e-4, 1e-4) and torch.allclose(a1, b1, 1e-4, 1e-4) and torch.allclose(a2, b2, 1e-4, 1e-4)
+    for a0, b0 in zip(a, b):
+        if not torch.allclose(a0, b0, 1e-4, 1e-4):
+            return False
+    return True
 
 if __name__ == "__main__":
     if test():