[kernel] added half2 specialization for layernorm kernel #139
base: main
Conversation
}

template <>
void invoke_layernorm_kernel<half2>(half2* out,
It sounds like these template specializations are optional since they are covered by the general template, no?
    const float epsilon,
    int m,
    int n) {
  int half_n = n / 2;
what if n % 2 != 0?
It sounds like you didn't cover this case in the unit tests.
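To make the concern concrete, here is a hypothetical dispatcher sketch (the parameter list of invoke_layernorm_kernel is inferred from the diff and may not match the real header): the half2 path only works when n is even, so an odd n has to fall back to the scalar half kernel or handle the trailing column separately. A unit test with an odd size such as n = 2049 would exercise exactly that branch.

#include <cuda_fp16.h>

// Assumed general-template signature, inferred from the diff; check the
// real header before reusing this sketch.
template <typename T>
void invoke_layernorm_kernel(T* out, const T* input, const T* weight,
                             const T* bias, float epsilon, int m, int n);

// Hypothetical dispatcher: only reinterpret pairs of half as half2 when
// every row splits evenly; otherwise take the scalar path.
void layer_norm_half_dispatch(half* out, const half* input, const half* weight,
                              const half* bias, float epsilon, int m, int n) {
  if (n % 2 == 0) {
    invoke_layernorm_kernel<half2>(reinterpret_cast<half2*>(out),
                                   reinterpret_cast<const half2*>(input),
                                   reinterpret_cast<const half2*>(weight),
                                   reinterpret_cast<const half2*>(bias),
                                   epsilon, m, n);
  } else {
    invoke_layernorm_kernel<half>(out, input, weight, bias, epsilon, m, n);
  }
}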
float* dinput;
float* dweight;
float* dbias;
cudaMalloc((void**)&dout, sizeof(float) * m * n);
Use a torch::Tensor to allocate the memory instead.
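A sketch of that suggestion, assuming the test links against libtorch (tensor names and shapes are made up for illustration): letting torch::Tensor own the device buffers removes the cudaMalloc/cudaMemcpy/cudaFree bookkeeping, and raw pointers for a hand-written kernel are still available through data_ptr.

#include <torch/torch.h>

void allocate_with_torch(int64_t m, int64_t n) {
  auto options = torch::dtype(torch::kHalf).device(torch::kCUDA);
  torch::Tensor input  = torch::randn({m, n}, options);
  torch::Tensor weight = torch::randn({n}, options);
  torch::Tensor bias   = torch::randn({n}, options);
  torch::Tensor out    = torch::empty({m, n}, options);

  // Raw device pointers for the kernel under test; the tensors free their
  // memory automatically when they go out of scope.
  at::Half* out_ptr         = out.data_ptr<at::Half>();
  const at::Half* input_ptr = input.data_ptr<at::Half>();
  (void)out_ptr;
  (void)input_ptr;
}

at::Half has the same 16-bit layout as CUDA's half, so tests typically reinterpret_cast the pointer when calling the kernel.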
    torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias(
        bias));

half* hout = (half*)malloc(m * n * sizeof(half));
same here.
cudaMemcpy(dweight, hweight, sizeof(half) * n, cudaMemcpyHostToDevice);
cudaMemcpy(dbias, hbias, sizeof(half) * n, cudaMemcpyHostToDevice);

llm::kernel::invoke_layernorm_kernel<half>(
Just test llm::kernel::layer_norm instead, but pass in inputs of different lengths to trigger the different kernels.
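A sketch of that test structure, assuming libtorch in the test (the call into the kernel under test is left as a commented placeholder since its exact signature is not shown in this diff): loop over several values of n, odd sizes included, and compare against torch's reference layer_norm.

#include <torch/torch.h>

void test_layer_norm_over_sizes() {
  namespace F = torch::nn::functional;
  const int64_t m = 32;
  const double epsilon = 1e-5;
  auto options = torch::dtype(torch::kHalf).device(torch::kCUDA);

  for (int64_t n : {127, 128, 1024, 2049}) {  // odd sizes on purpose
    auto input  = torch::randn({m, n}, options);
    auto weight = torch::randn({n}, options);
    auto bias   = torch::randn({n}, options);
    auto out    = torch::empty_like(input);

    // Placeholder: call the public entry point here, e.g. something like
    //   llm::kernel::layer_norm(out, input, weight, bias, epsilon);
    // so the value of n decides which kernel is launched internally.

    auto ref = F::layer_norm(
        input,
        F::LayerNormFuncOptions({n}).weight(weight).bias(bias).eps(epsilon));

    // Enable once the kernel call above is wired up:
    // TORCH_CHECK(torch::allclose(out.to(torch::kFloat), ref.to(torch::kFloat),
    //                             /*rtol=*/1e-2, /*atol=*/1e-2));
    (void)out;
    (void)ref;
  }
}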
Thanks for adding the optimization. Could you also add a benchmark to show the improvements? Thanks.
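For the benchmark, a small CUDA-event harness like the sketch below would be enough; the harness itself is generic, and the scalar half and half2 layernorm launches would be passed in as the callables on identical inputs.

#include <cuda_runtime.h>
#include <functional>

// Returns the average latency in milliseconds of `launch` over `iters` runs.
float time_kernel_ms(const std::function<void()>& launch, int iters = 100) {
  launch();                   // warm-up launch, excluded from the timing
  cudaDeviceSynchronize();

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start);
  for (int i = 0; i < iters; ++i) {
    launch();
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);

  float total_ms = 0.0f;
  cudaEventElapsedTime(&total_ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return total_ms / iters;
}

Reporting both averages for a few (m, n) shapes in the PR description would make the speedup easy to verify.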
e18e337 to bc9f7e2
…nitest and just test llm::kernel::layer_norm
optimize layernorm kernel using half2 type
test layernorm kernel