innerproduct allow 1 height gemm (#4730)
nihui authored May 17, 2023
1 parent 1b4a8fd commit f893d24
Showing 11 changed files with 38 additions and 29 deletions.
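
Every backend gets the same relaxation. The 2D gemm branch used to require more than one row (h > 1 in the reference layer, h * elempack > 1 in the packed SIMD backends), so a 2D bottom blob with a single row fell through to the flattened single-vector path and came out 1D; since the gemm branch creates its top blob as (num_output, h), routing h == 1 through it lets a single-row matrix stay 2D. Below is a minimal standalone sketch of the dispatch predicate before and after this commit (Blob and the helper names are illustrative stand-ins for this note, not ncnn API):

```cpp
#include <cstdio>

// Illustrative stand-in for the shape fields of ncnn::Mat (not the real class).
struct Blob
{
    int dims;     // 1 = vector, 2 = matrix
    int w;        // width: input features per row
    int h;        // height: number of rows
    int elempack; // SIMD packing factor (1 on the reference path)
};

// Hypothetical helpers mirroring the dispatch condition before/after the commit.
static bool gemm_before(const Blob& b, int num_input)
{
    return b.dims == 2 && b.w == num_input && b.h * b.elempack > 1;
}

static bool gemm_after(const Blob& b, int num_input)
{
    return b.dims == 2 && b.w == num_input;
}

int main()
{
    const int num_input = 64;
    const Blob one_row = {2, 64, 1, 1}; // 2D input with a single row

    // before=0: fell through to the flattened 1D path (1D output)
    // after=1:  handled as a 1-row gemm (2D output)
    printf("before=%d after=%d\n",
           (int)gemm_before(one_row, num_input),
           (int)gemm_after(one_row, num_input));
    return 0;
}
```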
6 changes: 3 additions & 3 deletions src/layer/arm/innerproduct_arm.cpp
@@ -177,7 +177,7 @@ int InnerProduct_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const

const int num_input = weight_data_size / num_output;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
@@ -889,7 +889,7 @@ int InnerProduct_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
const int num_input = weight_data_size / num_output;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
@@ -1295,7 +1295,7 @@ int InnerProduct_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q);
}

- if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input && bottom_blob_int8.h * bottom_blob_int8.elempack > 1)
+ if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input)
{
// gemm
Mat bottom_blob_int8_unpacked;
2 changes: 1 addition & 1 deletion src/layer/arm/innerproduct_arm_asimdhp.cpp
@@ -53,7 +53,7 @@ int InnerProduct_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
const int num_input = weight_data_size / num_output;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
2 changes: 1 addition & 1 deletion src/layer/arm/innerproduct_arm_vfpv4.cpp
@@ -53,7 +53,7 @@ int InnerProduct_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
const int num_input = weight_data_size / num_output;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
4 changes: 2 additions & 2 deletions src/layer/innerproduct.cpp
@@ -115,7 +115,7 @@ int InnerProduct::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
size_t elemsize = bottom_blob.elemsize;
int size = w * h;

- if (bottom_blob.dims == 2 && w == num_input && h > 1)
+ if (bottom_blob.dims == 2 && w == num_input)
{
// gemm
top_blob.create(num_output, h, elemsize, opt.blob_allocator);
@@ -201,7 +201,7 @@ int InnerProduct::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_g);
}

- if (bottom_blob.dims == 2 && w == num_input && h > 1)
+ if (bottom_blob.dims == 2 && w == num_input)
{
// gemm
top_blob.create(num_output, h, 4u, opt.blob_allocator);
6 changes: 3 additions & 3 deletions src/layer/loongarch/innerproduct_loongarch.cpp
@@ -137,7 +137,7 @@ int InnerProduct_loongarch::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const

const int num_input = weight_data_size / num_output;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
@@ -667,7 +667,7 @@ int InnerProduct_loongarch::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
const int num_input = weight_data_size / num_output;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
@@ -1168,7 +1168,7 @@ int InnerProduct_loongarch::forward_int8_loongarch(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q);
}

- if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input && bottom_blob_int8.h * bottom_blob_int8.elempack > 1)
+ if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input)
{
// gemm
Mat bottom_blob_int8_unpacked;
6 changes: 3 additions & 3 deletions src/layer/mips/innerproduct_mips.cpp
@@ -137,7 +137,7 @@ int InnerProduct_mips::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const

const int num_input = weight_data_size / num_output;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
@@ -667,7 +667,7 @@ int InnerProduct_mips::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
const int num_input = weight_data_size / num_output;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
@@ -1168,7 +1168,7 @@ int InnerProduct_mips::forward_int8_mips(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q);
}

- if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input && bottom_blob_int8.h * bottom_blob_int8.elempack > 1)
+ if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input)
{
// gemm
Mat bottom_blob_int8_unpacked;
6 changes: 3 additions & 3 deletions src/layer/riscv/innerproduct_riscv.cpp
@@ -173,7 +173,7 @@ int InnerProduct_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const

const int num_input = weight_data_size / num_output;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
@@ -577,7 +577,7 @@ int InnerProduct_riscv::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const

const int num_input = weight_data_size / num_output;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
@@ -839,7 +839,7 @@ int InnerProduct_riscv::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const

const int num_input = weight_data_size / num_output;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
6 changes: 3 additions & 3 deletions src/layer/vulkan/innerproduct_vulkan.cpp
@@ -79,7 +79,7 @@ int InnerProduct_vulkan::create_pipeline(const Option& _opt)
convert_packing(bias_data, bias_data_packed, out_elempack, opt);
}

- if (shape.dims == 2 && shape.w == num_input && shape.h > 1)
+ if (shape.dims == 2 && shape.w == num_input)
{
// gemm
int elempack = opt.use_shader_pack8 && shape.h % 8 == 0 ? 8 : shape.h % 4 == 0 ? 4 : 1;
@@ -427,7 +427,7 @@ int InnerProduct_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, const Option& opt) const
int in_elempack = opt.use_shader_pack8 && num_input % 8 == 0 ? 8 : num_input % 4 == 0 ? 4 : 1;
int out_elempack = opt.use_shader_pack8 && num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
@@ -587,7 +587,7 @@ int InnerProduct_vulkan::forward(const VkImageMat& bottom_blob, VkImageMat& top_blob, VkCompute& cmd, const Option& opt) const
int in_elempack = opt.use_shader_pack8 && num_input % 8 == 0 ? 8 : num_input % 4 == 0 ? 4 : 1;
int out_elempack = opt.use_shader_pack8 && num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
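Note how this interacts with the packing choice in create_pipeline above: the gemm pipeline's elempack is 8, 4, or 1 depending on whether shape.h divides evenly, so the newly admitted shape.h == 1 case presumably resolves to elempack 1 and runs the unpacked gemm shader.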
6 changes: 3 additions & 3 deletions src/layer/x86/innerproduct_x86.cpp
@@ -118,7 +118,7 @@ int InnerProduct_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const

const int num_input = weight_data_size / num_output;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
@@ -190,7 +190,7 @@ int InnerProduct_x86::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
const int num_input = weight_data_size / num_output;

- if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
+ if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
{
// gemm
int h = bottom_blob.h;
@@ -309,7 +309,7 @@ int InnerProduct_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q);
}

- if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input && bottom_blob_int8.h * bottom_blob_int8.elempack > 1)
+ if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input)
{
// gemm
Mat bottom_blob_int8_unpacked;
2 changes: 2 additions & 0 deletions tests/test_innerproduct.cpp
@@ -179,6 +179,8 @@ static int test_innerproduct_gemm(const ncnn::Mat& a, int outch, int bias)
static int test_innerproduct_4()
{
return 0
+ || test_innerproduct_gemm(RandomMat(1, 1), 1, 1)
+ || test_innerproduct_gemm(RandomMat(48, 1), 11, 1)
|| test_innerproduct_gemm(RandomMat(1, 5), 1, 1)
|| test_innerproduct_gemm(RandomMat(3, 2), 2, 0)
|| test_innerproduct_gemm(RandomMat(9, 8), 7, 1)
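
The two added cases cover exactly what the old condition excluded: RandomMat(1, 1) and RandomMat(48, 1) both build 2D single-row inputs, giving the h == 1 gemm branch coverage, including the degenerate 1x1 blob. A small probe of the shapes involved (a sketch assuming an ncnn build on the include path; Mat(w, h) and fill() are real ncnn API, the rest is illustration):

```cpp
#include <cstdio>
#include "mat.h" // ncnn::Mat

int main()
{
    // RandomMat(48, 1) in the test constructs a 2D blob shaped like this:
    // w = 48 columns (must equal num_input), h = 1 row.
    ncnn::Mat a(48, 1);
    a.fill(0.5f);

    // dims == 2 && w == num_input now suffices for the gemm branch,
    // even though h == 1.
    printf("dims=%d w=%d h=%d\n", a.dims, a.w, a.h);
    return 0;
}
```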
21 changes: 14 additions & 7 deletions tools/pnnx/tests/ncnn/test_nn_Linear.py
@@ -23,7 +23,7 @@ def __init__(self):
self.linear_0 = nn.Linear(in_features=64, out_features=16, bias=False)
self.linear_1 = nn.Linear(in_features=16, out_features=3, bias=True)

- def forward(self, x, y, z):
+ def forward(self, x, y, z, w):
x = self.linear_0(x)
x = self.linear_1(x)

@@ -33,7 +33,10 @@ def forward(self, x, y, z):
z = self.linear_0(z)
z = self.linear_1(z)
z = F.relu(z)
- return x, y, z
+
+ w = self.linear_0(w)
+ w = self.linear_1(w)
+ return x, y, z, w

def test():
net = Model().half().float()
@@ -43,22 +46,26 @@ def test():
x = torch.rand(64)
y = torch.rand(12, 64)
z = torch.rand(1, 3, 12, 64)
+ w = torch.rand(1, 64)

- a0, a1, a2 = net(x, y, z)
+ a = net(x, y, z, w)

# export torchscript
- mod = torch.jit.trace(net, (x, y, z))
+ mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_nn_Linear.pt")

# torchscript to pnnx
import os
- os.system("../../src/pnnx test_nn_Linear.pt inputshape=[64],[12,64],[1,3,12,64]")
+ os.system("../../src/pnnx test_nn_Linear.pt inputshape=[64],[12,64],[1,3,12,64],[1,64]")

# ncnn inference
import test_nn_Linear_ncnn
- b0, b1, b2 = test_nn_Linear_ncnn.test_inference()
+ b = test_nn_Linear_ncnn.test_inference()

- return torch.allclose(a0, b0, 1e-4, 1e-4) and torch.allclose(a1, b1, 1e-4, 1e-4) and torch.allclose(a2, b2, 1e-4, 1e-4)
+ for a0, b0 in zip(a, b):
+     if not torch.allclose(a0, b0, 1e-4, 1e-4):
+         return False
+ return True

if __name__ == "__main__":
if test():
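
The pnnx test gains a fourth input, w = torch.rand(1, 64), which is the 2D batch-1 case this commit targets: nn.Linear keeps the leading dimension, so the traced model returns a (1, 3) tensor and the ncnn graph must now match it through the 1-height gemm path. Rewriting the output comparison as a zip loop keeps the check shape-agnostic as outputs are added.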
