-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[kernel] added half2 specialization for layernorm kernel #139
Open
dongxianzhe
wants to merge
5
commits into
vectorch-ai:main
Choose a base branch
from
dongxianzhe:op/layernorm_kernel
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
bd5b91f
[op] optimize layernorm kernel for half2 type
14a99f7
[ut] add layernorm kernel unitest
6ff1738
use gtest library rewrite layernorm kernel unitest
bc9f7e2
added layernorm kernel half2 unit test using gtest library
faaec07
[refactor] use torch::tensor to allocate memory in layernorm kernel u…
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
#include <torch/torch.h> | ||
|
||
#include "dispatch.h" | ||
#include "layernorm_kernels.h" | ||
#include "reduce_kernel_utils.cuh" | ||
|
||
namespace llm::kernel { | ||
|
@@ -173,6 +174,61 @@ __global__ void layer_norm_kernel(T* __restrict__ out, | |
} | ||
} | ||
|
||
// equation: x -> (x - E[x]) / sqrt(Var[x] + eps) * w + b | ||
// The mean and standard-deviation are calculated over the last dimension | ||
template <> | ||
__global__ void layer_norm_kernel<half2>(half2* __restrict__ out, | ||
const half2* __restrict__ input, | ||
const half2* __restrict__ weight, | ||
const half2* __restrict__ bias, | ||
const float epsilon, | ||
int64_t n) { | ||
const int tidx = threadIdx.x; | ||
const int bidx = blockIdx.x; | ||
|
||
__shared__ half s_mean; | ||
__shared__ half s_variance; | ||
half2 mean = make_half2(__float2half(0.0f), __float2half(0.0f)); | ||
half2 variance = make_half2(__float2half(0.0f), __float2half(0.0f)); | ||
|
||
// calculate mean of the input. | ||
for (int i = tidx; i < n; i += blockDim.x) { | ||
const int idx = bidx * n + i; | ||
mean = __hadd2(mean, __ldg(&input[idx])); | ||
} | ||
mean = block_reduce_sum<half2>(mean); | ||
if (tidx == 0) { | ||
s_mean = __hdiv(__hadd(mean.x, mean.y), __float2half((float)n * 2)); | ||
} | ||
__syncthreads(); | ||
|
||
// calculate variance of the input. | ||
for (int i = tidx; i < n; i += blockDim.x) { | ||
const half2 x = __hsub2(input[bidx * n + i], make_half2(s_mean, s_mean)); | ||
variance = __hadd2(variance, __hmul2(x, x)); | ||
} | ||
variance = block_reduce_sum<half2>(variance); | ||
if (tidx == 0) { | ||
s_variance = __hadd(variance.x, variance.y); | ||
s_variance = __hdiv(s_variance, __float2half((float)n * 2)); | ||
s_variance = __hadd(s_variance, __float2half(epsilon)); | ||
s_variance = hrsqrt(s_variance); | ||
} | ||
__syncthreads(); | ||
|
||
for (int i = tidx; i < n; i += blockDim.x) { | ||
const int idx = bidx * n + i; | ||
half2 local_out = __ldg(&input[idx]); | ||
local_out = __hsub2(local_out, make_half2(s_mean, s_mean)); | ||
local_out = __hmul2(local_out, make_half2(s_variance, s_variance)); | ||
local_out = __hmul2(local_out, __ldg(&weight[i])); | ||
if (bias != nullptr) { | ||
local_out = __hadd2(local_out, __ldg(&bias[i])); | ||
} | ||
out[idx] = local_out; | ||
} | ||
} | ||
|
||
void layer_norm(torch::Tensor& out, | ||
torch::Tensor input, | ||
torch::Tensor weight, | ||
|
@@ -197,4 +253,54 @@ void layer_norm(torch::Tensor& out, | |
}); | ||
} | ||
|
||
} // namespace llm::kernel | ||
template <typename T> | ||
void invoke_layernorm_kernel(T* out, | ||
const T* input, | ||
const T* weight, | ||
const T* bias, | ||
const float epsilon, | ||
int m, | ||
int n) { | ||
layer_norm_kernel<T><<<m, n>>>(out, input, weight, bias, epsilon, n); | ||
} | ||
|
||
template <> | ||
void invoke_layernorm_kernel<half2>(half2* out, | ||
const half2* input, | ||
const half2* weight, | ||
const half2* bias, | ||
const float epsilon, | ||
int m, | ||
int n) { | ||
layer_norm_kernel<half2><<<m, n>>>(out, input, weight, bias, epsilon, n); | ||
} | ||
template <> | ||
void invoke_layernorm_kernel<float>(float* out, | ||
const float* input, | ||
const float* weight, | ||
const float* bias, | ||
const float epsilon, | ||
int m, | ||
int n) { | ||
layer_norm_kernel<float><<<m, n>>>(out, input, weight, bias, epsilon, n); | ||
} | ||
|
||
template <> | ||
void invoke_layernorm_kernel<half>(half* out, | ||
const half* input, | ||
const half* weight, | ||
const half* bias, | ||
const float epsilon, | ||
int m, | ||
int n) { | ||
int half_n = n / 2; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what if n % 2 != 0? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sounds you didn't cover this in unittest. |
||
half2* out_ptr = (half2*)out; | ||
const half2* input_ptr = (const half2*)input; | ||
const half2* weight_ptr = (const half2*)weight; | ||
const half2* bias_ptr = (const half2*)bias; | ||
|
||
dim3 block(std::min(half_n, 1024)); | ||
layer_norm_kernel<half2> | ||
<<<m, block>>>(out_ptr, input_ptr, weight_ptr, bias_ptr, epsilon, half_n); | ||
} | ||
} // namespace llm::kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#include <cuda_fp16.h> | ||
#include <gtest/gtest.h> | ||
#include <torch/nn/functional.h> | ||
|
||
#include <cstdio> | ||
|
||
#include "layernorm_kernels.h" | ||
|
||
TEST(NormalizationKernelTest, LayernormFloatTest) { | ||
float epsilon = 1e-6; | ||
int m = 32; | ||
int n = 512; | ||
|
||
auto out = torch::zeros({m, n}, torch::TensorOptions().device(torch::kCUDA)); | ||
auto input = | ||
torch::randn({m, n}, torch::TensorOptions().device(torch::kCUDA)); | ||
auto weight = torch::randn({n}, torch::TensorOptions().device(torch::kCUDA)); | ||
auto bias = torch::randn({n}, torch::TensorOptions().device(torch::kCUDA)); | ||
auto desired_out = torch::nn::functional::layer_norm( | ||
input, | ||
torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias( | ||
bias)); | ||
|
||
llm::kernel::layer_norm(out, input, weight, bias, epsilon); | ||
|
||
EXPECT_TRUE(torch::allclose(out, desired_out, 1e-3, 1e-5)); | ||
} | ||
|
||
TEST(NormalizationKernelTest, LayernormHalfTest) { | ||
float epsilon = 1e-6; | ||
int m = 4; | ||
int n = 512; | ||
|
||
auto out = torch::zeros( | ||
{m, n}, | ||
torch::TensorOptions().dtype(at::ScalarType::Half).device(torch::kCUDA)); | ||
auto input = torch::randn( | ||
{m, n}, | ||
torch::TensorOptions().dtype(at::ScalarType::Half).device(torch::kCUDA)); | ||
auto weight = torch::randn( | ||
{n}, | ||
torch::TensorOptions().dtype(at::ScalarType::Half).device(torch::kCUDA)); | ||
auto bias = torch::randn( | ||
{n}, | ||
torch::TensorOptions().dtype(at::ScalarType::Half).device(torch::kCUDA)); | ||
auto desired_out = torch::nn::functional::layer_norm( | ||
input, | ||
torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias( | ||
bias)); | ||
|
||
llm::kernel::layer_norm(out, input, weight, bias, epsilon); | ||
|
||
EXPECT_TRUE(torch::allclose(out, desired_out, 0.05, 1e-3)); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds this template specializations are optional since they are covered by the general template. no?