[webgpu] Always use tile matmulnbits for block_size = 32 (#23140)
### Description

After the prefill-time optimization in #23102, it appears that always using the tiled MatMulNBits shader for block_size = 32 gives better performance even on discrete GPUs, at least for the Phi-3 model.

On my NVIDIA RTX 2000 GPU, Phi-3 improves from 32.82 to 42.64 tokens/sec in easy mode.
qjia7 authored Dec 20, 2024
1 parent b4a6a0d commit 7c782f6
Showing 2 changed files with 9 additions and 13 deletions.
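In essence, the commit drops the `is_intel` gate from the kernel-selection logic. Below is a minimal sketch of the resulting dispatch decision, simplified from the .cc diff that follows (branch bodies elided; `M`, `kMinMForTileOptimization`, `block_size`, and `tile_m` are the names used in the diff):

```cpp
// Simplified from MatMulNBits::ComputeInternal in the .cc diff below.
// Before this commit, the generation-time tiled shader additionally required
// an Intel gen-12lp adapter (is_intel); now block_size == 32 alone selects it.
if (M > kMinMForTileOptimization && block_size == 32) {
  // Prefill path: tiled shader with tile_m = 4 rows of A per workgroup.
} else if (block_size == 32) {  // previously: is_intel && block_size == 32
  // Generation path (M == 1): tiled shader, now used on all vendors.
} else {
  // Fallback: generic MatMulNBits shader.
}
```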
onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc (3 additions, 5 deletions)

```diff
@@ -60,7 +60,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform);
   const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
 
-  if ((is_intel_ || tile_m_ > 1) && block_size_ == 32) {
+  if (block_size_ == 32) {
     const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY();
     const uint32_t tile_size = WorkgroupSizeX() * components_b_ * 8;  // each uint32 has 8 data.
     const uint32_t a_length_per_tile = tile_size / a.NumComponents();
@@ -408,14 +408,12 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
   const uint32_t components_b = GetMaxComponents(blob_size_in_words);
   uint32_t components = GetMaxComponents(N);
 
-  const bool is_intel = context.AdapterInfo().vendor == std::string_view{"intel"} &&
-                        context.AdapterInfo().architecture == std::string_view{"gen-12lp"};
   const bool has_zero_points = zero_points != nullptr;
 
   // TODO: Support output_number > 1. Some cases are failed when output_number > 1.
   constexpr uint32_t output_number = 1;
   const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1;
-  MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow<int>(components_b), has_zero_points, is_intel};
+  MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow<int>(components_b), has_zero_points};
   if (M > kMinMForTileOptimization && block_size == 32) {
     components = 1;
     constexpr uint32_t workgroup_size = 64;
@@ -426,7 +424,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
                             (M + tile_m - 1) / tile_m,
                             batch_count);
     program.CacheHint("T_M" + std::to_string(tile_m));
-  } else if (is_intel && block_size == 32) {
+  } else if (block_size == 32) {
     components = 1;
     constexpr uint32_t workgroup_size = 128;
     const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4
```
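As an aside on the tile sizing in the first hunk: MatMulNBits packs eight 4-bit quantized weights into each uint32, which is where the `* 8` factor in `tile_size` comes from (the diff's "each uint32 has 8 data" comment). A quick worked example, assuming `WorkgroupSizeX() == 16` and `components_b_ == 4` (both values are assumptions for illustration, not fixed by this commit):

```cpp
// Illustrative arithmetic only; workgroup_size_x and components_b are
// assumed example values, not taken from this commit.
const uint32_t workgroup_size_x = 16;  // assumed WorkgroupSizeX()
const uint32_t components_b = 4;       // assumed: B loaded as vec4<u32>
// Each u32 packs 8 4-bit weights, so one tile covers:
const uint32_t tile_size = workgroup_size_x * components_b * 8;  // 16 * 4 * 8 = 512 weights
```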
onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h (6 additions, 8 deletions)

```diff
@@ -14,13 +14,12 @@ using namespace onnxruntime::webgpu;
 
 class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
  public:
-  MatMulNBitsProgram(uint32_t output_number, uint32_t block_size, uint32_t tile_m, int components_b, bool has_zero_points, bool is_intel) : Program{"MatMulNBits"},
-                                                                                                                                            output_number_{output_number},
-                                                                                                                                            block_size_{block_size},
-                                                                                                                                            tile_m_{tile_m},
-                                                                                                                                            components_b_{components_b},
-                                                                                                                                            has_zero_points_{has_zero_points},
-                                                                                                                                            is_intel_{is_intel} {
+  MatMulNBitsProgram(uint32_t output_number, uint32_t block_size, uint32_t tile_m, int components_b, bool has_zero_points) : Program{"MatMulNBits"},
+                                                                                                                             output_number_{output_number},
+                                                                                                                             block_size_{block_size},
+                                                                                                                             tile_m_{tile_m},
+                                                                                                                             components_b_{components_b},
+                                                                                                                             has_zero_points_{has_zero_points} {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -32,7 +31,6 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
   uint32_t tile_m_;
   int components_b_;
   bool has_zero_points_;
-  bool is_intel_;
 };
 
 class MatMulNBits final : public WebGpuKernel {
```
