-
Notifications
You must be signed in to change notification settings - Fork 403
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
2024-08-30 nightly release (ff4a736)
- Loading branch information
pytorchbot
committed
Aug 30, 2024
1 parent
cf55211
commit ac2248f
Showing
54 changed files
with
1,670 additions
and
277 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
c42ac54d9e817bf0a0366eb78e6c8beba4d5eff5 | ||
e4cd76cf8283c8ddbf95674b020fbfcff467cb4b |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#version 450 core | ||
|
||
#include "indexing_utils.h" | ||
|
||
#define PRECISION ${PRECISION} | ||
|
||
#define FLOAT_T ${buffer_scalar_type(DTYPE)} | ||
|
||
${define_active_storage_type(STORAGE)} | ||
|
||
${define_required_extensions(DTYPE)} | ||
${define_required_extensions("int8")} | ||
|
||
layout(std430) buffer; | ||
|
||
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} | ||
${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)} | ||
${layout_declare_tensor(2, "r", "t_mat2", "int8", STORAGE)} | ||
${layout_declare_tensor(3, "r", "t_scales_and_zeros", DTYPE, STORAGE)} | ||
|
||
${layout_declare_ubo(4, "ivec4", "out_sizes")} | ||
${layout_declare_ubo(5, "ivec4", "out_strides")} | ||
${layout_declare_ubo(6, "ivec4", "mat1_strides")} | ||
${layout_declare_ubo(7, "ivec4", "mat2_sizes")} | ||
${layout_declare_ubo(8, "ivec4", "mat2_strides")} | ||
${layout_declare_ubo(9, "ivec4", "scales_strides")} | ||
|
||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; | ||
|
||
layout(constant_id = 3) const int group_size = 1; | ||
|
||
void main() { | ||
|
||
const ivec4 out_pos = ivec4( | ||
gl_GlobalInvocationID.x, // n = 0..N-1 | ||
gl_GlobalInvocationID.y, // m = 0..M-1 | ||
gl_GlobalInvocationID.z % out_sizes.z, | ||
gl_GlobalInvocationID.z / out_sizes.z); | ||
|
||
if (any(greaterThanEqual(out_pos, out_sizes))) { | ||
return; | ||
} | ||
|
||
const uint K = mat2_sizes.x * 2; | ||
const uint N = mat2_sizes.y; | ||
const uint n = out_pos.x; | ||
const uint m = out_pos.y; | ||
const uint k_block = (K + group_size - 1) / group_size; | ||
const uint mask = uint(0x0f); | ||
ivec4 mat1_pos = ivec4(0, m, out_pos.z, out_pos.w); | ||
ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w); | ||
ivec4 scale_pos = ivec4(0, n, 0, out_pos.w); | ||
ivec4 zero_pos = ivec4(0, n, 1, out_pos.w); | ||
|
||
float rc = 0.0; | ||
int k = 0; | ||
|
||
for (int kb = 0; kb < k_block; kb++) { | ||
scale_pos.x = kb; | ||
const int scale_id = to_buffer_id(scale_pos, scales_strides); | ||
const float scale = float(t_scales_and_zeros[scale_id]); | ||
|
||
zero_pos.x = kb; | ||
const int zero_id = to_buffer_id(zero_pos, scales_strides); | ||
const float zero = float(t_scales_and_zeros[zero_id]) - scale * 8.0; | ||
|
||
for(uint idx = 0; idx < group_size && k < K; idx++, k++) { | ||
mat1_pos.x = k; | ||
const int mat1_id = to_buffer_id(mat1_pos, mat1_strides); | ||
const float mat1_val = float(t_mat1[mat1_id]); | ||
|
||
mat2_pos.x = k / 2; | ||
const int mat2_id = to_buffer_id(mat2_pos, mat2_strides); | ||
// Bitwise op treats sign bit from int8 as a value bit instead, | ||
// since there is no uint8_t datatype | ||
uint mat2_val = (t_mat2[mat2_id] & 0xFF); | ||
mat2_val = (k & 1) == 0 ? mat2_val & mask : (mat2_val >> 4); | ||
|
||
rc += mat1_val * (scale * float(mat2_val) + zero); | ||
} | ||
} | ||
|
||
const int out_id = to_buffer_id(out_pos, out_strides); | ||
t_out[out_id] = FLOAT_T(rc); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
q_4w_linear: | ||
parameter_names_with_default_values: | ||
DTYPE: float | ||
STORAGE: buffer | ||
generate_variant_forall: | ||
DTYPE: | ||
- VALUE: float | ||
- VALUE: half | ||
shader_variants: | ||
- NAME: q_4w_linear |
158 changes: 158 additions & 0 deletions
158
backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h> | ||
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h> | ||
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h> | ||
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h> | ||
|
||
namespace vkcompute { | ||
|
||
void check_q_matmul_args( | ||
ComputeGraph& graph, | ||
const ValueRef mat1, | ||
const ValueRef mat2_data, | ||
const ValueRef group_size_data, | ||
const ValueRef scales_and_zeros, | ||
const ValueRef out) { | ||
const std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1); | ||
const std::vector<int64_t> mat2_sizes = graph.sizes_of(mat2_data); | ||
const std::vector<int64_t> scales_and_zeros_sizes = | ||
graph.sizes_of(scales_and_zeros); | ||
|
||
const uint32_t group_size = graph.extract_scalar<uint32_t>(group_size_data); | ||
|
||
VK_CHECK_COND(mat1_sizes.size() == 2); | ||
VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size()); | ||
|
||
VK_CHECK_COND(graph.memory_layout_of(mat1) == graph.memory_layout_of(out)); | ||
|
||
const int mat1_K = utils::val_at(-1, mat1_sizes); | ||
const int mat2_K = utils::val_at(-1, mat2_sizes) * 2; | ||
const int N = utils::val_at(-2, mat2_sizes); | ||
|
||
VK_CHECK_COND(mat1_K == mat2_K); | ||
|
||
VK_CHECK_COND(mat2_K % group_size == 0); | ||
|
||
const uint32_t k_groups = mat2_K / group_size; | ||
|
||
VK_CHECK_COND(scales_and_zeros_sizes.size() == 3); | ||
VK_CHECK_COND(utils::val_at(-1, scales_and_zeros_sizes) == k_groups); | ||
VK_CHECK_COND(utils::val_at(-2, scales_and_zeros_sizes) == N); | ||
VK_CHECK_COND(utils::val_at(-3, scales_and_zeros_sizes) == 2); | ||
|
||
// Match https://fburl.com/code/6ostkknm | ||
std::vector<uint32_t> valid_group_sizes = {32, 64, 128, 256}; | ||
|
||
bool is_valid_group_size = false; | ||
for (auto valid_group_size : valid_group_sizes) { | ||
if (group_size == valid_group_size) { | ||
is_valid_group_size = true; | ||
break; | ||
} | ||
} | ||
|
||
VK_CHECK_COND(is_valid_group_size); | ||
} | ||
|
||
void resize_q_matmul_node( | ||
ComputeGraph* graph, | ||
const std::vector<ArgGroup>& args, | ||
const std::vector<ValueRef>& extra_args) { | ||
(void)extra_args; | ||
|
||
vTensorPtr out = graph->get_tensor(args[0].refs[0]); | ||
vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); | ||
vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); | ||
|
||
const int out_cols = utils::val_at(-2, mat1->sizes()); | ||
const int out_rows = utils::val_at(-2, mat2->sizes()); | ||
|
||
std::vector<int64_t> new_out_sizes(3); | ||
if (mat1->sizes().size() == 2) { | ||
new_out_sizes.resize(2); | ||
new_out_sizes.at(0) = out_cols; | ||
new_out_sizes.at(1) = out_rows; | ||
} else { | ||
new_out_sizes.at(0) = mat1->sizes().at(0); | ||
new_out_sizes.at(1) = out_cols; | ||
new_out_sizes.at(2) = out_rows; | ||
} | ||
|
||
out->virtual_resize(new_out_sizes); | ||
} | ||
|
||
void add_q_matmul_node( | ||
ComputeGraph& graph, | ||
const ValueRef mat1, | ||
const ValueRef mat2_data, | ||
const ValueRef group_size, | ||
const ValueRef scales_and_zeros_data, | ||
const ValueRef out) { | ||
ValueRef mat2 = | ||
prepack_buffer_if_tensor_ref(graph, mat2_data, utils::kWidthPacked); | ||
ValueRef scales_and_zeros = | ||
prepack_if_tensor_ref(graph, scales_and_zeros_data, utils::kWidthPacked); | ||
|
||
std::string kernel_name = "q_4w_linear"; | ||
|
||
add_dtype_suffix(kernel_name, graph.dtype_of(out)); | ||
|
||
const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size); | ||
|
||
vkapi::ParamsBindList ubos({}); | ||
ubos.append(graph.sizes_ubo(out)); | ||
ubos.append(graph.strides_ubo(out)); | ||
ubos.append(graph.strides_ubo(mat1)); | ||
ubos.append(graph.sizes_ubo(mat2)); | ||
ubos.append(graph.strides_ubo(mat2)); | ||
ubos.append(graph.strides_ubo(scales_and_zeros)); | ||
|
||
auto out_sizes = graph.sizes_of(out); | ||
uint32_t N = utils::val_at(-1, out_sizes); | ||
uint32_t M = utils::val_at(-2, out_sizes); | ||
|
||
utils::uvec3 global_wg_size = {N, M, 1}; | ||
|
||
utils::uvec3 local_wg_size = adaptive_work_group_size(global_wg_size); | ||
|
||
graph.execute_nodes().emplace_back(new ExecuteNode( | ||
graph, | ||
VK_KERNEL_FROM_STR(kernel_name), | ||
global_wg_size, | ||
local_wg_size, | ||
// Inputs and Outputs | ||
{{out, vkapi::MemoryAccessType::WRITE}, | ||
{{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}}, | ||
// Shader params buffers | ||
ubos, | ||
// Specialization Constants | ||
{SV(group_size_val)}, | ||
// Resizing Logic | ||
resize_q_matmul_node, | ||
{})); | ||
} | ||
|
||
void int4pack_mm(ComputeGraph& graph, const std::vector<ValueRef>& args) { | ||
check_q_matmul_args(graph, args[0], args[1], args[2], args[3], args[4]); | ||
return add_q_matmul_node( | ||
graph, | ||
args[0], // mat1 | ||
args[1], // mat2 | ||
args[2], // group_size | ||
args[3], // scales_and_zeros | ||
args[4] // out | ||
); | ||
} | ||
|
||
REGISTER_OPERATORS { | ||
VK_REGISTER_OP(aten._weight_int4pack_mm.default, int4pack_mm); | ||
} | ||
|
||
} // namespace vkcompute |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.