[GPU/OpenCL] Initial version of Rotary Embedding Kernel for GPU and generalization via Attention Interface @ open sesame 10/18 11:30 #2721

Merged
16 changes: 16 additions & 0 deletions nntrainer/cl_context.cpp
@@ -15,6 +15,7 @@
*/

#include <addition_layer_cl.h>
#include <attention_kernel_strings.h>
#include <blas_kernel_strings.h>
#include <cl_context.h>
#include <concat_cl.h>
@@ -149,6 +150,21 @@ void ClContext::initBlasClKernels() {
blas_kernels_initialized = true;
}

void ClContext::initAttentionClKernels() {
if (attention_kernels_initialized) {
ml_logi("ClContext: Default attention kernels already registered and "
"initialized");
return;
}

registerClKernel(rotary_emb_cl_kernel_, "rotary_emb_cl");

#ifdef ENABLE_FP16
registerClKernel(rotary_emb_cl_kernel_fp16_, "rotary_emb_cl_fp16");
#endif
attention_kernels_initialized = true;
}

const ClContext::SharedPtrClKernel
ClContext::registerClKernel(std::string kernel_string,
std::string kernel_name) {
8 changes: 8 additions & 0 deletions nntrainer/cl_context.h
@@ -211,6 +211,11 @@ class ClContext {
*/
void initBlasClKernels();

/**
* @brief Initialize and register all attention OpenCl kernels
*/
void initAttentionClKernels();

/**
* @brief destructor to release opencl commandQueue
*/
@@ -229,6 +234,9 @@
// flag to check default blas kernels registered or not
bool blas_kernels_initialized = false;

// flag to check default attention kernels registered or not
bool attention_kernels_initialized = false;

FactoryMap<nntrainer::Layer> factory_map;

template <typename Args, typename T> struct isSupportedHelper;
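For orientation, here is a minimal usage sketch of the new registration path. The ClContext::Global() accessor and the lookup-by-reregistration pattern are assumptions for illustration; only initAttentionClKernels() and registerClKernel() are shown or introduced in this diff.

// Sketch only (not part of this PR): initialize the attention kernels once,
// then obtain the compiled rotary embedding kernel for use by a GPU layer.
auto &cl_context = nntrainer::ClContext::Global(); // assumed global accessor
cl_context.initAttentionClKernels();               // registers rotary_emb_cl (+ fp16 variant)

// Assumption: registerClKernel() returns the cached kernel when the same
// name was registered before, so it can double as a lookup here.
auto rotary_kernel = cl_context.registerClKernel(
  nntrainer::rotary_emb_cl_kernel_, "rotary_emb_cl");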
138 changes: 138 additions & 0 deletions nntrainer/tensor/cl_operations/attention_kernel_interface.cpp
@@ -0,0 +1,138 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 Yash Singh <[email protected]>
*
* @file attention_kernel_interface.cpp
* @date 28 August 2024
* @brief Interface for attention OpenCL kernels
* @see https://github.com/nnstreamer/nntrainer
* @author Yash Singh <[email protected]>
* @bug No known bugs except for NYI items
*
*/

#include <attention_kernel_interface.h>
#include <attention_kernels.h>

namespace nntrainer {
/**
* @brief compute frequency for rotary embedding
* @param[in] dim hidden dim size
* @param[in] seq_len sequence length
* @param[out] freqs_cos cosine of the frequencies
* @param[out] freqs_sin sine of the frequencies
* @param[out] freqs base frequencies array to be used in the future computation
* @param[in] theta rotary angle
*/
void precompute_freqs(unsigned int dim, unsigned int seq_len,
std::vector<std::vector<float>> &freqs_cos,
std::vector<std::vector<float>> &freqs_sin,
std::vector<float> &freqs, float theta = 10000.0) {
unsigned int half_ = dim / 2;
for (unsigned int i = 0; i < half_; ++i) {
freqs.push_back(1.0 / (std::pow(theta, (2 * i) / static_cast<float>(dim))));
}

auto cos_vec = std::vector<std::vector<float>>();
cos_vec.assign(seq_len, std::vector<float>(dim, 0));

auto sin_vec = std::vector<std::vector<float>>();
sin_vec.assign(seq_len, std::vector<float>(dim, 0));

for (unsigned int i = 0; i < seq_len; ++i) {
for (unsigned int j = 0; j < half_; ++j) {
float angle = i * freqs[j];
cos_vec[i][j] = std::cos(angle);
cos_vec[i][j + half_] = std::cos(angle); // repeated 2 times

sin_vec[i][j] = std::sin(angle);
sin_vec[i][j + half_] = std::sin(angle); // repeated 2 times
}
}
freqs_cos = cos_vec;
freqs_sin = sin_vec;
}

/**
* @brief apply rotary embedding
* @param[in] in input tensor
* @param[in] dim hidden dim size
* @param[in] from sequence order
* @param[in] max_timestep maximum timestep
*
* @todo Calling precompute_freqs in finalize to reduce code redundancy.
*/
Contributor:
missing doc?

Suggested change:
- */
+ * @param[in] context layer context to get the resource manager and queue id
+ */

Contributor Author:
Sure, I'll update this as well.

void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
unsigned int max_timestep) {
nntrainer::Tensor out(in.getDim());
float value = 0.0f;
float transformed_value = 0.0f;
unsigned int half_ = dim / 2;

std::vector<std::vector<float>> freqs_cos = {};
Contributor (@EunjuYang, Sep 4, 2024):
It appears that the function apply_rotary_emb_cl generates freqs_sin and freqs_cos every time it runs, which seems unintentional. As far as I understand, the value of dim is typically constant across transformer blocks (please correct me if I misunderstood), so the purpose of this code was to generate a set of frequencies that can be reused. In this particular implementation, however, it does not work that way and ends up performing repetitive calculations. What are your thoughts on this issue?

Contributor (@EunjuYang, Sep 4, 2024):
One possible way to update this is to take freqs_cos and freqs_sin as inputs and call the precompute only when they are null or empty.

Contributor Author:
Initially I also kept freqs_cos, freqs_sin, and freqs as global vectors and only updated them inside precompute_freqs if (freqs_cos.empty()), but after looking at custom_multi_head_attention_layer.cpp in the felice repo, dim is passed as

  apply_rotary_emb_tensor(projected_query_step, projected_query_dim_prop, _from);
  apply_rotary_emb_tensor(cache_key_step, projected_key_dim_prop, _from);

That's why, as per my understanding, I computed both vectors repeatedly based on dim.

Are you suggesting that we pass freqs_cos and freqs_sin as parameters to apply_rotary_emb_tensor_cl, or somewhere else? If we pass both vectors as parameters, then before calling precompute_freqs we can check whether they are empty, reducing the redundancy of the code.

Contributor:
You're correct that the dimension for the query/key might differ, but the block will be repeated (stacked several times!). Also, apply_rotary_emb_tensor_cl will be invoked whenever forwarding of the multi_head_attention layer is called, so it is worth considering reducing the unnecessary computation.

Contributor Author:
How about making them parameters of the apply_rotary_emb_tensor_cl function? I think that might be a good idea.

Apart from this, another solution is to call precompute_freqs in the finalize of multi_head_attention instead of inside apply_rotary_emb_tensor_cl: in finalize we can check if (freqs_cos.empty()) precompute_freqs();, and make freqs_cos and freqs_sin global vectors in attention_kernel_interface.cpp.

Contributor:
It sounds reasonable to call precompute_freqs() when finalize() is invoked! 👍

Contributor Author:
I've added a @todo in the latest commit for the above.

std::vector<std::vector<float>> freqs_sin = {};
std::vector<float> freqs;

precompute_freqs(dim, max_timestep, freqs_cos, freqs_sin, freqs);

std::vector<float> cos_;
std::vector<float> sin_;

if (from >= max_timestep) {
cos_.resize(dim);
sin_.resize(dim);

for (unsigned int i = 0; i < half_; ++i) {
float angle = from * freqs[i];
cos_[i] = std::cos(angle);
cos_[i + half_] = std::cos(angle); // repeated 2 times

sin_[i] = std::sin(angle);
sin_[i + half_] = std::sin(angle); // repeated 2 times
}
} else {
cos_.resize(max_timestep);
sin_.resize(max_timestep);
}

unsigned int input_batch_size, input_height, input_width, input_channels;
input_batch_size = in.batch();
input_height = in.height();
input_width = in.width();
input_channels = in.channel();

if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {

unsigned int in_size = in.size();
unsigned int out_size = out.size();
float *data = in.getData();
float *rdata = out.getData();

rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_,
input_batch_size, input_channels, input_height, input_width,
dim, from, max_timestep, in_size, out_size);

} else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16

unsigned int in_size = in.size();
unsigned int out_size = out.size();
_FP16 *data = in.getData<_FP16>();
_FP16 *rdata = out.getData<_FP16>();

rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_,
input_batch_size, input_channels, input_height, input_width,
dim, from, max_timestep, in_size, out_size);
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
}

if (from >= max_timestep) {
cos_.clear();
sin_.clear();
}

in.copy(out);
}
} // namespace nntrainer
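The review thread above settles on precomputing the frequency tables only once (e.g. from the layer's finalize()). A minimal sketch of that idea follows, assuming the cache lives in this translation unit; the cached_* names and ensure_freqs helper are illustrative and not part of this commit, which only records the intent as a @todo.

// Illustrative only: file-scope caches filled on first use, so
// apply_rotary_emb_cl() no longer recomputes the tables on every forward call.
static std::vector<std::vector<float>> cached_freqs_cos;
static std::vector<std::vector<float>> cached_freqs_sin;
static std::vector<float> cached_freqs;

static void ensure_freqs(unsigned int dim, unsigned int max_timestep) {
  if (cached_freqs_cos.empty()) {
    precompute_freqs(dim, max_timestep, cached_freqs_cos, cached_freqs_sin,
                     cached_freqs);
  }
}

A real implementation would likely key the cache by dim, since the query and key dimensions can differ, as noted in the discussion.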
33 changes: 33 additions & 0 deletions nntrainer/tensor/cl_operations/attention_kernel_interface.h
@@ -0,0 +1,33 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 Yash Singh <[email protected]>
*
* @file attention_kernel_interface.h
* @date 28 August 2024
* @brief Interface for attention OpenCL kernels
* @see https://github.com/nnstreamer/nntrainer
* @author Yash Singh <[email protected]>
* @bug No known bugs except for NYI items
*
*/

#ifndef __ATTENTION_KERNEL_INTERFACE_H__
#define __ATTENTION_KERNEL_INTERFACE_H__

#include <string>
#include <tensor.h>

namespace nntrainer {

/**
* @brief Rotary Embedding kernel
* @param[in] in input tensor
* @param[in] dim hidden dim size
* @param[in] from sequence order
* @param[in] max_timestep maximum timestep
*/
void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
unsigned int max_timestep);

} // namespace nntrainer
#endif /* __ATTENTION_KERNEL_INTERFACE_H__ */
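For completeness, a hedged call-site sketch of the new entry point, mirroring how the CPU-side apply_rotary_emb_tensor() is used per the discussion above; the tensor shape, constructor form, and constants are illustrative only.

// Hypothetical call site: apply rotary embedding to one projected query step.
unsigned int dim = 64;            // head dimension (assumed)
unsigned int from = 0;            // current sequence position
unsigned int max_timestep = 1024; // maximum cached timestep (assumed)

nntrainer::Tensor projected_query_step(1, 1, 1, dim); // shape assumed for illustration
nntrainer::apply_rotary_emb_cl(projected_query_step, dim, from, max_timestep);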
133 changes: 133 additions & 0 deletions nntrainer/tensor/cl_operations/attention_kernel_strings.h
@@ -0,0 +1,133 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 Yash Singh <[email protected]>
*
* @file attention_kernel_strings.h
* @date 8 October 2024
* @brief All attention OpenCL kernel strings
* @see https://github.com/nnstreamer/nntrainer
* @author Yash Singh <[email protected]>
* @bug No known bugs except for NYI items
*
*/

#ifndef __ATTENTION_KERNEL_STRINGS_H__
#define __ATTENTION_KERNEL_STRINGS_H__

#include <string>

namespace nntrainer {
static const std::string rotary_emb_cl_kernel_ = R"(

#pragma OPENCL EXTENSION cl_khr_fp16 : enable

__kernel void rotary_emb_cl(__global float *input,
__global float *output,
__global float *freqs_cos,
__global float *freqs_sin,
__global float *cos_,
__global float *sin_,
unsigned int batch,
unsigned int channel,
unsigned int height,
unsigned int width,
unsigned int dim,
unsigned int half_,
unsigned int max_timestep,
unsigned int from) {
__global float *cos_ptr = cos_;
__global float *sin_ptr = sin_;

float value = 0.0f;
float transformed_value = 0.0f;

unsigned int b = get_global_id(0);
unsigned int c = get_global_id(1);

if(b < batch && c < channel){
for (unsigned int h = 0; h < height; h++) {
if (from + h < max_timestep) {
unsigned idx = (from + h)*dim;
for(unsigned int i = idx; i < idx + dim; i++){
cos_ptr[i - idx] = freqs_cos[i];
sin_ptr[i - idx] = freqs_sin[i];
}
}

for (unsigned int w = 0; w < width; w = w + dim) {
for (unsigned int k = 0; k < dim; k++) {
unsigned int span = w + k;
value = input[b * channel * height * width + c * height * width + h * width + span];
if (k < half_) {
transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_];
} else {
transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_];
}
value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
output[b * channel * height * width + c * height * width + h * width + span] = value;
}
}
}
}
}
)";

#ifdef ENABLE_FP16
static const std::string rotary_emb_cl_kernel_fp16_ = R"(

#pragma OPENCL EXTENSION cl_khr_fp16 : enable

__kernel void rotary_emb_cl_fp16(__global half *input,
__global half *output,
__global float *freqs_cos,
__global float *freqs_sin,
__global float *cos_,
__global float *sin_,
unsigned int batch,
unsigned int channel,
unsigned int height,
unsigned int width,
unsigned int dim,
unsigned int half_,
unsigned int max_timestep,
unsigned int from) {
__global float *cos_ptr = cos_;
__global float *sin_ptr = sin_;

float value = 0.0f;
float transformed_value = 0.0f;

unsigned int b = get_global_id(0);
unsigned int c = get_global_id(1);

if(b < batch && c < channel){
for (unsigned int h = 0; h < height; h++) {
if (from + h < max_timestep) {
unsigned idx = (from + h)*dim;
for(int i = idx; i < idx + dim; i++ ){
cos_ptr[i - idx] = freqs_cos[i];
sin_ptr[i - idx] = freqs_sin[i];
}
}

for (unsigned int w = 0; w < width; w = w + dim) {
for (unsigned int k = 0; k < dim; k++) {
unsigned int span = w + k;
value = (float)input[b * channel * height * width + c * height * width + h * width + span];
if (k < half_) {
transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_];
} else {
transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_];
}
value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
output[b * channel * height * width + c * height * width + h * width + span] = (half)value;
}
}
}
}
}
)";

#endif
} // namespace nntrainer
#endif /* __ATTENTION_KERNEL_STRINGS_H__ */
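To show how the kernel's work decomposition is meant to be driven, here is a hedged host-side sketch using the raw OpenCL C API rather than nntrainer's own wrapper classes; buffer creation, kernel compilation, and error checks are omitted, and every handle and size below is an assumption for illustration.

#include <CL/cl.h>

// Hedged sketch: launch rotary_emb_cl with one work-item per (batch, channel)
// pair, matching the get_global_id(0)/get_global_id(1) usage in the kernel.
// All handles are assumed to have been created elsewhere.
void launch_rotary_emb(cl_command_queue queue, cl_kernel kernel,
                       cl_mem in_buf, cl_mem out_buf, cl_mem freqs_cos_buf,
                       cl_mem freqs_sin_buf, cl_mem cos_buf, cl_mem sin_buf) {
  // Illustrative shapes; a real caller passes the tensor's actual dimensions.
  cl_uint batch = 1, channel = 1, height = 4, width = 64, dim = 64,
          half_ = dim / 2, max_timestep = 1024, from = 0;

  cl_mem bufs[] = {in_buf, out_buf, freqs_cos_buf, freqs_sin_buf,
                   cos_buf, sin_buf};
  for (cl_uint i = 0; i < 6; ++i)
    clSetKernelArg(kernel, i, sizeof(cl_mem), &bufs[i]);

  cl_uint scalars[] = {batch, channel, height, width,
                       dim, half_, max_timestep, from};
  for (cl_uint i = 0; i < 8; ++i)
    clSetKernelArg(kernel, 6 + i, sizeof(cl_uint), &scalars[i]);

  size_t global[2] = {batch, channel}; // 2-D NDRange over batch x channel
  clEnqueueNDRangeKernel(queue, kernel, 2, nullptr, global, nullptr, 0,
                         nullptr, nullptr);
}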