From ab72f7358f385ba4d5237520a1dc731e7e03a869 Mon Sep 17 00:00:00 2001 From: Yash Singh Date: Wed, 28 Aug 2024 17:40:24 +0530 Subject: [PATCH 1/6] [GPU/OpenCL] Initial version of Rotary Embedding with OpenCL ops Added initial version of Rotary Embedding kernel for GPU. This includes both FP32 and FP16 implementation got GPU kernel. Signed-off-by: Yash Singh --- .../attention_kernel_interface.cpp | 136 +++++++++ .../attention_kernel_interface.h | 34 +++ .../cl_operations/attention_kernels.cpp | 273 +++++++++++++++++ .../tensor/cl_operations/attention_kernels.h | 96 ++++++ .../cl_operations/attention_kernels_fp16.cpp | 279 ++++++++++++++++++ nntrainer/tensor/cl_operations/meson.build | 4 + .../cl_operations/testing_rotarty_emb.cpp | 195 ++++++++++++ test/jni/Android.mk | 16 + .../unittest_attention_kernels_cl.cpp | 233 +++++++++++++++ 9 files changed, 1266 insertions(+) create mode 100644 nntrainer/tensor/cl_operations/attention_kernel_interface.cpp create mode 100644 nntrainer/tensor/cl_operations/attention_kernel_interface.h create mode 100644 nntrainer/tensor/cl_operations/attention_kernels.cpp create mode 100644 nntrainer/tensor/cl_operations/attention_kernels.h create mode 100644 nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp create mode 100644 nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp create mode 100644 test/unittest/unittest_attention_kernels_cl.cpp diff --git a/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp b/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp new file mode 100644 index 0000000000..cf28840176 --- /dev/null +++ b/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh + * + * @file attention_kernel_interface.cpp + * @date 28 August 2024 + * @brief Interface for attention OpenCL kernels + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#include +#include + +namespace nntrainer { +/** + * @brief compute frequency for rotary embedding + * @param[in] dim hidden dim size + * @param[in] seq_len sequency length + * @param[out] freqs_cos cosine of the frequencies + * @param[out] freqs_sin sine of the frequencies + * @param[out] freqs base frequencies array to be used in the future computation + * @param[in] theta rotary angle + */ +void precompute_freqs(int dim, unsigned int seq_len, + std::vector> &freqs_cos, + std::vector> &freqs_sin, + std::vector &freqs, float theta = 10000.0) { + unsigned int half_ = dim / 2; + for (unsigned int i = 0; i < half_; ++i) { + freqs.push_back(1.0 / (std::pow(theta, (2 * i) / static_cast(dim)))); + } + + auto cos = std::vector>(); + cos.assign(seq_len, std::vector(dim, 0)); + + auto sin = std::vector>(); + sin.assign(seq_len, std::vector(dim, 0)); + + for (unsigned int i = 0; i < seq_len; ++i) { + for (unsigned int j = 0; j < half_; ++j) { + float angle = i * freqs[j]; + cos[i][j] = std::cos(angle); + cos[i][j + half_] = std::cos(angle); // repeated 2 times + + sin[i][j] = std::sin(angle); + sin[i][j + half_] = std::sin(angle); // repeated 2 times + } + } + freqs_cos = cos; + freqs_sin = sin; +} + +/** + * @brief apply rotary embedding + * @param[in] in input tensor + * @param[in] dim hidden dim size + * @param[in] from sequence order + * @param[in] max_timestep maximum timestep + */ +void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from, + unsigned int max_timestep, RunLayerContext &context) { + nntrainer::Tensor out(in.getDim()); + float value = 0; + float transformed_value = 0.0; + unsigned int half_ = dim / 2; + + std::vector> freqs_cos = {}; + std::vector> freqs_sin = {}; + std::vector freqs; + + precompute_freqs(dim, max_timestep, freqs_cos, freqs_sin, freqs); + + std::vector cos_; + std::vector sin_; + + if (from >= max_timestep) { + cos_.resize(dim); + sin_.resize(dim); + + for (unsigned int i = 0; i < half_; ++i) { + float angle = from * freqs[i]; + cos_[i] = std::cos(angle); + cos_[i + half_] = std::cos(angle); // repeated 2 times + + sin_[i] = std::sin(angle); + sin_[i + half_] = std::sin(angle); // repeated 2 times + } + } else { + cos_.resize(max_timestep); + sin_.resize(max_timestep); + } + + unsigned int input_batch_size, input_height, input_width, input_channels; + input_batch_size = in.batch(); + input_height = in.height(); + input_width = in.width(); + input_channels = in.channel(); + + if (in.getDataType() == ml::train::TensorDim::DataType::FP32) { + + unsigned int in_size = in.size(); + unsigned int out_size = out.size(); + float *data = in.getData(); + float *rdata = out.getData(); + + rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_, + input_batch_size, input_channels, input_height, input_width, + dim, from, max_timestep, in_size, out_size, context); + + } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) { +#ifdef ENABLE_FP16 + + unsigned int in_size = in.size(); + unsigned int out_size = out.size(); + _FP16 *data = in.getData<_FP16>(); + _FP16 *rdata = out.getData<_FP16>(); + + rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_, + input_batch_size, input_channels, input_height, input_width, + dim, from, max_timestep, in_size, out_size, context); +#else + throw std::invalid_argument("Error: enable-fp16 is not enabled"); +#endif + } + + if (from >= max_timestep) { + cos_.clear(); + sin_.clear(); + } + + in.copy(out); +} +} // namespace nntrainer \ No newline at end of file diff --git a/nntrainer/tensor/cl_operations/attention_kernel_interface.h b/nntrainer/tensor/cl_operations/attention_kernel_interface.h new file mode 100644 index 0000000000..878561bdd9 --- /dev/null +++ b/nntrainer/tensor/cl_operations/attention_kernel_interface.h @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh + * + * @file blas_kernel_interface.h + * @date 28 August 2024 + * @brief Interface for attention OpenCL kernels + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#ifndef __ATTENTION_KERNEL_INTERFACE_H__ +#define __ATTENTION_KERNEL_INTERFACE_H__ + +#include +#include + +namespace nntrainer { + +/** + * @brief Rotary Embedding kernel + * @param[in] in input tensor + * @param[in] dim hidden dim size + * @param[in] from sequence order + * @param[in] max_timestep maximum timestep + * @param[in] context layer context to get the resource manager and queue id + */ +void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from, + unsigned int max_timestep, RunLayerContext &context); + +} // namespace nntrainer +#endif /* __ATTENTION_KERNEL_INTERFACE_H__ */ \ No newline at end of file diff --git a/nntrainer/tensor/cl_operations/attention_kernels.cpp b/nntrainer/tensor/cl_operations/attention_kernels.cpp new file mode 100644 index 0000000000..355bb8ec65 --- /dev/null +++ b/nntrainer/tensor/cl_operations/attention_kernels.cpp @@ -0,0 +1,273 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh + * + * @file attention_kernels.cpp + * @date 28 August 2024 + * @brief Common attention OpenCL kernels + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#include + +namespace nntrainer { +std::string rotary_emb_cl_kernel = R"( + #pragma OPENCL EXTENSION cl_khr_fp16 : enable +__kernel void rotary_emb_cl(__global float *input, + __global float *output, + __global float *freqs_cos, + __global float *freqs_sin, + __global float *cos_, + __global float *sin_, + unsigned int batch, + unsigned int channel, + unsigned int height, + unsigned int width, + unsigned int dim, + unsigned int half_, + unsigned int max_timestep, + unsigned int from) { + unsigned int gid = get_global_id(0); + unsigned int gws = get_global_size(0); + + __global float *cos_ptr = cos_; + __global float *sin_ptr = sin_; + + float value = 0.0f; + float transformed_value = 0.0f; + + for (unsigned int b = 0; b < batch; b++) { + for (unsigned int c = 0; c < channel; c++) { + for (unsigned int h = 0; h < height; h++) { + if (from + h < max_timestep) { + unsigned idx = (from + h)*dim; + for(unsigned int i = idx; i < idx + dim; i++){ + cos_ptr[i - idx] = freqs_cos[i]; + sin_ptr[i - idx] = freqs_sin[i]; + } + } + for (unsigned int w = 0; w < width; w = w + dim) { + for (unsigned int k = 0; k < dim; k++) { + unsigned int span = w + k; + value = input[b * channel * height * width + c * height * width + h * width + span]; + if (k < half_) { + transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_]; + } else { + transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_]; + } + value = value * cos_ptr[k] + transformed_value * sin_ptr[k]; + // printf("GPU Batch: %u, Height: %u, Channel: %u, Width: %u, K: %u, Span: %u, Value: %f, Transformed Value: %f, cos_ptr[k]: %f, sin_ptr[k]: %f\n", b, h, c, w, k, span, value, transformed_value, cos_ptr[k], sin_ptr[k]); + output[b * channel * height * width + c * height * width + h * width + span] = value; + } + } + } + } + } +} +)"; + +/** + * @brief defining global kernel objects + */ +opencl::Kernel kernel_rotary_emb; + +void rotary_emb_cl(float *in, float *out, + std::vector> freqs_cos, + std::vector> freqs_sin, + std::vector cos_, std::vector sin_, + unsigned int batch, unsigned int channel, + unsigned int height, unsigned int width, unsigned int dim, + unsigned int from, unsigned int max_timestep, + unsigned int in_size, unsigned int out_size, + RunLayerContext &context) { + bool result = false; + + do { + result = context.clCreateKernel( + rotary_emb_cl_kernel, context.LayerKernel::ROTARY_EMB, kernel_rotary_emb); + if (!result) { + printf("Failed to create kernel for rotary_emb_cl\n"); + break; + } + unsigned int cos_dim = cos_.size(); + unsigned int sin_dim = sin_.size(); + unsigned int freqs_cos_dim = freqs_cos.size(); + unsigned int freqs_sin_dim = freqs_sin.size(); + + size_t dim1_size = sizeof(float) * in_size; + size_t dim2_size = sizeof(float) * out_size; + size_t dim3_size = sizeof(float) * cos_dim; + size_t dim4_size = sizeof(float) * sin_dim; + size_t dim5_size = + sizeof(float) * freqs_cos_dim * dim; // max_timestep * dim + size_t dim6_size = sizeof(float) * freqs_sin_dim * dim; + + opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr); + + opencl::Buffer inOutRes(context.context_inst_, dim2_size, true, nullptr); + + opencl::Buffer cosBuf(context.context_inst_, dim3_size, true, nullptr); + + opencl::Buffer sinBuf(context.context_inst_, dim4_size, true, nullptr); + + opencl::Buffer freqs_cosBuf(context.context_inst_, dim5_size, true, + nullptr); + + opencl::Buffer freqs_sinBuf(context.context_inst_, dim6_size, true, + nullptr); + + std::vector freqs_cos_flat; + std::vector freqs_sin_flat; + for (const auto &row : freqs_cos) { + freqs_cos_flat.insert(freqs_cos_flat.end(), row.begin(), row.end()); + } + for (const auto &row : freqs_sin) { + freqs_sin_flat.insert(freqs_sin_flat.end(), row.begin(), row.end()); + } + + result = inputA.WriteData(context.command_queue_inst_, in); + if (!result) { + printf("Failed to write input data\n"); + break; + } + + result = inOutRes.WriteData(context.command_queue_inst_, out); + if (!result) { + printf("Failed to write output data\n"); + break; + } + + result = freqs_cosBuf.WriteData(context.command_queue_inst_, + freqs_cos_flat.data()); + if (!result) { + printf("Failed to write freqs cos data\n"); + break; + } + + result = freqs_sinBuf.WriteData(context.command_queue_inst_, + freqs_sin_flat.data()); + if (!result) { + printf("Failed to write freqs sin data\n"); + break; + } + + result = cosBuf.WriteData(context.command_queue_inst_, cos_.data()); + if (!result) { + printf("Failed to write cos data\n"); + break; + } + + result = sinBuf.WriteData(context.command_queue_inst_, sin_.data()); + if (!result) { + printf("Failed to write sin data\n"); + break; + } + + result = kernel_rotary_emb.SetKernelArguments(0, &inputA, sizeof(cl_mem)); + if (!result) { + printf("Failed to set inputA argument\n"); + break; + } + + result = kernel_rotary_emb.SetKernelArguments(1, &inOutRes, sizeof(cl_mem)); + if (!result) { + printf("Failed to set inOutRes argument\n"); + break; + } + + result = + kernel_rotary_emb.SetKernelArguments(2, &freqs_cosBuf, sizeof(cl_mem)); + if (!result) { + printf("Failed to set freqs_cosBuf argument\n"); + break; + } + + result = + kernel_rotary_emb.SetKernelArguments(3, &freqs_sinBuf, sizeof(cl_mem)); + if (!result) { + printf("Failed to set freqs_sinBuf argument\n"); + break; + } + + result = kernel_rotary_emb.SetKernelArguments(4, &cosBuf, sizeof(cl_mem)); + if (!result) { + printf("Failed to set cosBuf argument\n"); + break; + } + + result = kernel_rotary_emb.SetKernelArguments(5, &sinBuf, sizeof(cl_mem)); + if (!result) { + printf("Failed to set sinBuf argument\n"); + break; + } + + result = kernel_rotary_emb.SetKernelArguments(6, &batch, sizeof(int)); + if (!result) { + printf("Failed to set batch argument\n"); + break; + } + + result = kernel_rotary_emb.SetKernelArguments(7, &channel, sizeof(int)); + if (!result) { + printf("Failed to set channel argument\n"); + break; + } + + result = kernel_rotary_emb.SetKernelArguments(8, &height, sizeof(int)); + if (!result) { + printf("Failed to set height argument\n"); + break; + } + + result = kernel_rotary_emb.SetKernelArguments(9, &width, sizeof(int)); + if (!result) { + printf("Failed to set width argument\n"); + break; + } + + result = kernel_rotary_emb.SetKernelArguments(10, &dim, sizeof(int)); + if (!result) { + printf("Failed to set dim argument\n"); + break; + } + unsigned int half_ = dim / 2; + result = kernel_rotary_emb.SetKernelArguments(11, &half_, sizeof(int)); + if (!result) { + printf("Failed to set half argument\n"); + break; + } + + result = + kernel_rotary_emb.SetKernelArguments(12, &max_timestep, sizeof(int)); + if (!result) { + printf("Failed to set timestamp argument\n"); + break; + } + + result = kernel_rotary_emb.SetKernelArguments(13, &from, sizeof(int)); + if (!result) { + printf("Failed to set from argument\n"); + break; + } + + const int work_groups_count[3] = {1, 1, 1}; + const int work_group_size[3] = {32, 1, 1}; // test-value + result = context.command_queue_inst_.DispatchCommand( + kernel_rotary_emb, work_groups_count, work_group_size); + if (!result) { + printf("Failed to dispatch command\n"); + break; + } + + result = inOutRes.ReadData(context.command_queue_inst_, out); + if (!result) { + printf("Failed to read data\n"); + break; + } + + } while (false); +} +} // namespace nntrainer \ No newline at end of file diff --git a/nntrainer/tensor/cl_operations/attention_kernels.h b/nntrainer/tensor/cl_operations/attention_kernels.h new file mode 100644 index 0000000000..432b2322e1 --- /dev/null +++ b/nntrainer/tensor/cl_operations/attention_kernels.h @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh + * + * @file attention_kernels.h + * @date 28 August 2024 + * @brief Common attention OpenCL kernels + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#ifndef __ATTENTION_KERNELS_H__ +#define __ATTENTION_KERNELS_H__ + +#include +#include +#include +#include + +namespace nntrainer { + +/** + * @brief declaring global kernel objects + */ +extern opencl::Kernel kernel_rotary_emb; + +/** + * @brief Rotary Embedding process + * @param[in] in __fp16 * input + * @param[in] out __fp16 * output + * @param[out] freqs_cos cosine of the frequencies + * @param[out] freqs_sin sine of the frequencies + * @param[in] cos_ vector of cos values + * @param[in] sin_ vector of sin values + * @param[in] batch size of batch + * @param[in] channel channel of input + * @param[in] height height of input + * @param[in] width width of input + * @param[in] dim hidden dim size + * @param[in] from sequence order + * @param[in] max_timestep max timestep + * @param[in] in_size size of input + * @param[in] out_size size of output + * @param[in] context RunLayerContext reference + */ +void rotary_emb_cl(float *in, float *out, + std::vector> freqs_cos, + std::vector> freqs_sin, + std::vector cos_, std::vector sin_, + unsigned int batch, unsigned int channel, + unsigned int height, unsigned int width, unsigned int dim, + unsigned int from, unsigned int max_timestamp, + unsigned int in_size, unsigned int out_size, + RunLayerContext &context); + +#ifdef ENABLE_FP16 +/** + * @brief declaring global fp16 kernel objects + */ +extern opencl::Kernel kernel_rotary_emb_fp16; + +/** + * @brief Rotary Embedding process + * @param[in] in __fp16 * input + * @param[in] out __fp16 * output + * @param[out] freqs_cos cosine of the frequencies + * @param[out] freqs_sin sine of the frequencies + * @param[in] cos_ vector of cos values + * @param[in] sin_ vector of sin values + * @param[in] batch size of batch + * @param[in] channel channel of input + * @param[in] height height of input + * @param[in] width width of input + * @param[in] dim hidden dim size + * @param[in] from sequence order + * @param[in] max_timestep max timestep + * @param[in] in_size size of input + * @param[in] out_size size of output + * @param[in] context RunLayerContext reference + */ +void rotary_emb_cl(__fp16 *in, __fp16 *out, + std::vector> freqs_cos, + std::vector> freqs_sin, + std::vector cos_, std::vector sin_, + unsigned int batch, unsigned int channel, + unsigned int height, unsigned int width, unsigned int dim, + unsigned int from, unsigned int max_timestamp, + unsigned int in_size, unsigned int out_size, + RunLayerContext &context); + +#endif + +} // namespace nntrainer +#endif /* __ATTENTION_KERNELS_H__ */ \ No newline at end of file diff --git a/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp new file mode 100644 index 0000000000..b5d0ca5000 --- /dev/null +++ b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp @@ -0,0 +1,279 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh + * + * @file attention_kernels_fp16.cpp + * @date 28 August 2024 + * @brief Common attention OpenCL fp16 kernels + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#include + +namespace nntrainer { +std::string rotary_emb_cl_kernel_fp16 = R"( + #pragma OPENCL EXTENSION cl_khr_fp16 : enable +__kernel void rotary_emb_cl_fp16(__global half *input, + __global half *output, + __global float *freqs_cos, + __global float *freqs_sin, + __global float *cos_, + __global float *sin_, + unsigned int batch, + unsigned int channel, + unsigned int height, + unsigned int width, + unsigned int dim, + unsigned int half_, + unsigned int max_timestep, + unsigned int from) { + unsigned int gid = get_global_id(0); + unsigned int gws = get_global_size(0); + + __global float *cos_ptr = cos_; + __global float *sin_ptr = sin_; + + float value = 0.0f; + float transformed_value = 0.0f; + + for (unsigned int b = 0; b < batch; b++) { + for (unsigned int c = 0; c < channel; c++) { + for (unsigned int h = 0; h < height; h++) { + if (from + h < max_timestep) { + unsigned idx = (from + h)*dim; + for(int i = idx; i < idx + dim; i++ ){ + cos_ptr[i - idx] = freqs_cos[i]; + sin_ptr[i - idx] = freqs_sin[i]; + } + } + + for (unsigned int w = 0; w < width; w = w + dim) { + for (unsigned int k = 0; k < dim; k++) { + unsigned int span = w + k; + value = (float)input[b * channel * height * width + c * height * width + h * width + span]; + if (k < half_) { + transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_]; + } else { + transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_]; + } + value = value * cos_ptr[k] + transformed_value * sin_ptr[k]; + output[b * channel * height * width + c * height * width + h * width + span] = (half)value; + } + } + } + } + } +} +)"; + +/** + * @brief defining global kernel objects + */ +opencl::Kernel kernel_rotary_emb_fp16; + +void rotary_emb_cl(__fp16 *in, __fp16 *out, + std::vector> freqs_cos, + std::vector> freqs_sin, + std::vector cos_, std::vector sin_, + unsigned int batch, unsigned int channel, + unsigned int height, unsigned int width, unsigned int dim, + unsigned int from, unsigned int max_timestep, + unsigned int in_size, unsigned int out_size, + RunLayerContext &context) { + + bool result = false; + do { + result = context.clCreateKernel(rotary_emb_cl_kernel_fp16, + context.LayerKernel::ROTARY_EMB_FP16, + kernel_rotary_emb_fp16); + if (!result) { + printf("Failed to create kernel for rotary_emb_cl\n"); + break; + } + + unsigned int cos_dim = cos_.size(); + unsigned int sin_dim = sin_.size(); + unsigned int freqs_cos_dim = freqs_cos.size(); + unsigned int freqs_sin_dim = freqs_sin.size(); + + size_t dim1_size = sizeof(cl_half) * in_size; + size_t dim2_size = sizeof(cl_half) * out_size; + size_t dim3_size = sizeof(float) * cos_dim; + size_t dim4_size = sizeof(float) * sin_dim; + size_t dim5_size = sizeof(float) * freqs_cos_dim * dim; + size_t dim6_size = sizeof(float) * freqs_sin_dim * dim; + + opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr); + + opencl::Buffer inOutRes(context.context_inst_, dim2_size, true, nullptr); + + opencl::Buffer cosBuf(context.context_inst_, dim3_size, true, nullptr); + + opencl::Buffer sinBuf(context.context_inst_, dim4_size, true, nullptr); + + opencl::Buffer freqs_cosBuf(context.context_inst_, dim5_size, true, + nullptr); + + opencl::Buffer freqs_sinBuf(context.context_inst_, dim6_size, true, + nullptr); + + std::vector freqs_cos_flat; + std::vector freqs_sin_flat; + for (const auto &row : freqs_cos) { + freqs_cos_flat.insert(freqs_cos_flat.end(), row.begin(), row.end()); + } + for (const auto &row : freqs_sin) { + freqs_sin_flat.insert(freqs_sin_flat.end(), row.begin(), row.end()); + } + + result = inputA.WriteData(context.command_queue_inst_, in); + if (!result) { + printf("Failed to write input data\n"); + break; + } + + result = inOutRes.WriteData(context.command_queue_inst_, out); + if (!result) { + printf("Failed to write output data\n"); + break; + } + + result = freqs_cosBuf.WriteData(context.command_queue_inst_, + freqs_cos_flat.data()); + if (!result) { + printf("Failed to write cos data\n"); + break; + } + + result = freqs_sinBuf.WriteData(context.command_queue_inst_, + freqs_sin_flat.data()); + if (!result) { + printf("Failed to write sin data\n"); + break; + } + + result = cosBuf.WriteData(context.command_queue_inst_, cos_.data()); + if (!result) { + printf("Failed to write cos data\n"); + break; + } + + result = sinBuf.WriteData(context.command_queue_inst_, sin_.data()); + if (!result) { + printf("Failed to write sin data\n"); + break; + } + + result = + kernel_rotary_emb_fp16.SetKernelArguments(0, &inputA, sizeof(cl_mem)); + if (!result) { + printf("Failed to set inputA argument\n"); + break; + } + + result = + kernel_rotary_emb_fp16.SetKernelArguments(1, &inOutRes, sizeof(cl_mem)); + if (!result) { + printf("Failed to set inOutRes argument\n"); + break; + } + + result = kernel_rotary_emb_fp16.SetKernelArguments(2, &freqs_cosBuf, + sizeof(cl_mem)); + if (!result) { + printf("Failed to set freqs_cosBuf argument\n"); + break; + } + + result = kernel_rotary_emb_fp16.SetKernelArguments(3, &freqs_sinBuf, + sizeof(cl_mem)); + if (!result) { + printf("Failed to set freqs_sinBuf argument\n"); + break; + } + + result = + kernel_rotary_emb_fp16.SetKernelArguments(4, &cosBuf, sizeof(cl_mem)); + if (!result) { + printf("Failed to set cosBuf argument\n"); + break; + } + + result = + kernel_rotary_emb_fp16.SetKernelArguments(5, &sinBuf, sizeof(cl_mem)); + if (!result) { + printf("Failed to set sinBuf argument\n"); + break; + } + + result = kernel_rotary_emb_fp16.SetKernelArguments(6, &batch, sizeof(int)); + if (!result) { + printf("Failed to set batch argument\n"); + break; + } + + result = + kernel_rotary_emb_fp16.SetKernelArguments(7, &channel, sizeof(int)); + if (!result) { + printf("Failed to set channel argument\n"); + break; + } + + result = kernel_rotary_emb_fp16.SetKernelArguments(8, &height, sizeof(int)); + if (!result) { + printf("Failed to set height argument\n"); + break; + } + + result = kernel_rotary_emb_fp16.SetKernelArguments(9, &width, sizeof(int)); + if (!result) { + printf("Failed to set width argument\n"); + break; + } + + result = kernel_rotary_emb_fp16.SetKernelArguments(10, &dim, sizeof(int)); + if (!result) { + printf("Failed to set dim argument\n"); + break; + } + unsigned int half_ = dim / 2; + result = kernel_rotary_emb_fp16.SetKernelArguments(11, &half_, sizeof(int)); + if (!result) { + printf("Failed to set half argument\n"); + break; + } + + result = + kernel_rotary_emb_fp16.SetKernelArguments(12, &max_timestep, sizeof(int)); + if (!result) { + printf("Failed to set timestamp argument\n"); + break; + } + + result = kernel_rotary_emb_fp16.SetKernelArguments(13, &from, sizeof(int)); + if (!result) { + printf("Failed to set from argument\n"); + break; + } + + const int work_groups_count[3] = {1, 1, 1}; + const int work_group_size[3] = {32, 1, 1}; // test-value + result = context.command_queue_inst_.DispatchCommand( + kernel_rotary_emb_fp16, work_groups_count, work_group_size); + if (!result) { + printf("Failed to dispatch command\n"); + break; + } + + result = inOutRes.ReadData(context.command_queue_inst_, out); + if (!result) { + printf("Failed to read data\n"); + break; + } + + } while (false); +} +} // namespace nntrainer \ No newline at end of file diff --git a/nntrainer/tensor/cl_operations/meson.build b/nntrainer/tensor/cl_operations/meson.build index 43e95f7fe9..3f186ec645 100644 --- a/nntrainer/tensor/cl_operations/meson.build +++ b/nntrainer/tensor/cl_operations/meson.build @@ -1,15 +1,19 @@ cl_op_sources = [ 'blas_kernels.cpp', 'blas_kernel_interface.cpp', + 'attention_kernel_interface.cpp', + 'attention_kernels.cpp', ] cl_op_headers = [ 'blas_kernel_interface.h', 'blas_kernel_strings.h', + 'attention_kernel_interface.h', ] if get_option('enable-fp16') cl_op_sources += 'blas_kernels_fp16.cpp' + cl_op_sources += 'attention_kernels_fp16.cpp' endif foreach s : cl_op_sources diff --git a/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp b/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp new file mode 100644 index 0000000000..c13cbc0148 --- /dev/null +++ b/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh + * + * @file testing_rotary_emb.cpp + * @date 28 August 2024 + * @brief Rotary Embedding CPU code + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#include "tensor.h" +#include + +/** + * @brief compute frequency for rotary embedding + * @param[in] dim hidden dim size + * @param[in] seq_len sequency length + * @param[out] freqs_cos cosine of the frequencies + * @param[out] freqs_sin sine of the frequencies + * @param[out] freqs base frequencies array to be used in computation of cos and + * sin values for each position in sequence + * @param[in] theta rotary angle + */ +void precompute_freqs(int dim, unsigned int seq_len, + std::vector> &freqs_cos, + std::vector> &freqs_sin, + std::vector &freqs, float theta = 10000.0) { + if (freqs_cos.empty()) { + unsigned int half_ = dim / 2; + for (unsigned int i = 0; i < half_; ++i) { + freqs.push_back(1.0 / + (std::pow(theta, (2 * i) / static_cast(dim)))); + } + + auto cos = std::vector>(); + cos.assign(seq_len, std::vector(dim, 0)); + + auto sin = std::vector>(); + sin.assign(seq_len, std::vector(dim, 0)); + + for (unsigned int i = 0; i < seq_len; ++i) { +#ifdef USE_NEON + nntrainer::calc_trigonometric_vals_dup(half_, freqs.data(), cos[i].data(), + sin[i].data(), i); +#else + for (unsigned int j = 0; j < half_; ++j) { + float angle = i * freqs[j]; + cos[i][j] = std::cos(angle); + cos[i][j + half_] = std::cos(angle); // repeated 2 times + + sin[i][j] = std::sin(angle); + sin[i][j + half_] = std::sin(angle); // repeated 2 times + } +#endif + } + freqs_cos = cos; + freqs_sin = sin; + } +} + +/** + * @brief apply rotary embedding + * @param[in] in input tensor + * @param[in] dim hidden dim size + * @param[in] from sequence order + * @param[in] max_timestep maximum timestep + */ +void apply_rotary_emb_tensor(nntrainer::Tensor &in, unsigned int dim, + unsigned int from, unsigned int max_timestep) { + nntrainer::Tensor out(in.getDim()); + float value = 0; + float transformed_value = 0.0; + unsigned int half_ = dim / 2; + + std::vector> freqs_cos = {}; + std::vector> freqs_sin = {}; + std::vector freqs; + + precompute_freqs(dim, max_timestep, freqs_cos, freqs_sin, freqs); + + std::vector cos_; + std::vector sin_; + + if (from >= max_timestep) { + cos_ = std::vector(dim); + sin_ = std::vector(dim); +#ifdef USE_NEON + nntrainer::calc_trigonometric_vals_dup(half_, freqs.data(), cos_.data(), + sin_.data(), from); +#else + for (unsigned int i = 0; i < half_; ++i) { + float angle = from * freqs[i]; + cos_[i] = std::cos(angle); + cos_[i + half_] = std::cos(angle); // repeated 2 times + + sin_[i] = std::sin(angle); + sin_[i + half_] = std::sin(angle); // repeated 2 times + } +#endif + } else { + cos_.resize(max_timestep); + sin_.resize(max_timestep); + } + + if (in.getDataType() == ml::train::TensorDim::DataType::FP32) { + + unsigned int input_batch_size, input_height, input_width, input_channels; + input_batch_size = in.batch(); + input_height = in.height(); + input_width = in.width(); + input_channels = in.channel(); + + for (unsigned int b = 0; b < in.batch(); b++) { + for (unsigned int c = 0; c < in.channel(); c++) { + for (unsigned int h = 0; h < in.height(); h++) { + if (from + h < max_timestep) { + cos_ = freqs_cos[from + h]; + sin_ = freqs_sin[from + h]; + } + + for (unsigned int w = 0; w < in.width(); w = w + dim) { + for (unsigned int k = 0; k < dim; k++) { + unsigned int span = w + k; + if (span < in.width()) { + value = in.getValue(b, c, h, span); + if (k < half_) { + transformed_value = + -1.0 * in.getValue(b, c, h, span + half_); + } else { + transformed_value = in.getValue(b, c, h, span - half_); + } + value = value * cos_[k] + transformed_value * sin_[k]; + // printf("CPU Batch: %u, Channel: %u, Height: %u, Width: %u, K: + // %u, Span: %u, Value: %f, Transformed Value: %f, cos_ptr[k]: + // %f, sin_ptr[k]: %f\n ", b, c, h, w, k, span, value, + // transformed_value, cos_[k], sin_[k]); + out.setValue(b, c, h, span, value); + } + } + } + } + } + } + } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) { +#ifdef ENABLE_FP16 + for (unsigned int b = 0; b < in.batch(); b++) { + for (unsigned int c = 0; c < in.channel(); c++) { + for (unsigned int h = 0; h < in.height(); h++) { + if (from + h < max_timestep) { + cos_ = freqs_cos[from + h]; + sin_ = freqs_sin[from + h]; + } + for (unsigned int w = 0; w < in.width(); w = w + dim) { +#ifdef USE_NEON + nntrainer::compute_rotary_embedding_value( + dim, half_, w, in.getData<_FP16>() + in.getIndex(b, c, h, 0), + out.getData<_FP16>() + out.getIndex(b, c, h, 0), cos_.data(), + sin_.data()); +#else + for (unsigned int k = 0; k < dim; k++) { + unsigned int span = w + k; + value = static_cast(in.getValue<_FP16>(b, c, h, span)); + + if (k < half_) { + transformed_value = + -1.0 * + static_cast(in.getValue<_FP16>(b, c, h, half_ + span)); + } else { + transformed_value = + static_cast(in.getValue<_FP16>(b, c, h, span - half_)); + } + out.setValue(b, c, h, span, + static_cast<_FP16>(value * cos_[k] + + transformed_value * sin_[k])); + } +#endif + } + } + } + } +#else + throw std::invalid_argument("Error: enable-fp16 is not enabled"); +#endif + } + + if (from >= max_timestep) { + cos_.clear(); + sin_.clear(); + } + + in.copy(out); +} \ No newline at end of file diff --git a/test/jni/Android.mk b/test/jni/Android.mk index 153b4eb840..faaba46f45 100644 --- a/test/jni/Android.mk +++ b/test/jni/Android.mk @@ -499,6 +499,22 @@ LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer LOCAL_STATIC_LIBRARIES := googletest_main test_util include $(BUILD_EXECUTABLE) +include $(CLEAR_VARS) + +LOCAL_MODULE := unittest_attention_kernels_cl +LOCAL_CFLAGS := -Igoogletest/include -I../include -I../unittest/layers -I../../nntrainer/layers/loss -pthread -fexceptions -fopenmp -static-openmp -DMIN_CPP_VERSION=201703L -DNNTR_NUM_THREADS=1 -D__LOGGING__=1 -DENABLE_TEST=1 -DREDUCE_TOLERANCE=1 -march=armv8.2-a+fp16 -mfpu=neon-fp16 -mfloat-abi=softfp -O3 -frtti -DNDK_BUILD=1 -DENABLE_FP16=1 -DENABLE_OPENCL=1 +LOCAL_CXXFLAGS += -std=c++17 -frtti -fexceptions +LOCAL_LDLIBS := -llog -landroid -fopenmp -static-openmp + +LOCAL_SRC_FILES := \ + ../unittest/unittest_attention_kernels_cl.cpp + +LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES) + +LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer +LOCAL_STATIC_LIBRARIES := googletest_main test_util +include $(BUILD_EXECUTABLE) + # unittest_ccapi include $(CLEAR_VARS) diff --git a/test/unittest/unittest_attention_kernels_cl.cpp b/test/unittest/unittest_attention_kernels_cl.cpp new file mode 100644 index 0000000000..7a09e5cd54 --- /dev/null +++ b/test/unittest/unittest_attention_kernels_cl.cpp @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh + * + * @file unittest_attention_kernels_cl.cpp + * @date 28 August 2024 + * @brief Test setup for blas OpenCL kernels + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + */ + +#include +#include +#include + +#include "nntrainer_test_util.h" +#include "util_func.h" +#include +#include +#include +#include + +#include "testing_rotarty_emb.cpp" + +#define EXPECT_IN_RANGE(VAL, MIN, MAX) \ + EXPECT_GE((VAL), (MIN)); \ + EXPECT_LE((VAL), (MAX)) + +using namespace nntrainer; + +static RunLayerContext setUpGpuContext() { + + auto &ac = nntrainer::ClContext::Global(); + auto rc = RunLayerContext(); + + return rc; +} + +TEST(attention_kernels, rotary_emb_kernel_FP32) { + RunLayerContext rc = setUpGpuContext(); + + int batch = 1; + int channel = 1; + int height = 4; + int width = 4; + + unsigned int dim = 2; + unsigned int from = 4; + unsigned int max_timestep = 4; + + const float alpha = 1e-1; + const int MOD = 10; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor B_fp32(batch, channel, height, width, t_type_nchw_fp32); + + GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + + j * (batch * height) + k * (width) + l + 1) % + MOD) * + alpha); + + B_fp32.copy(A_fp32); + + // std::cout << "\nA_fp32 and B_fp32 before rotary embedding:" << std::endl; + // for (unsigned int i = 0; i < A_fp32.size(); ++i) { + // std::cout << "Element " << i << " -> " << *(A_fp32.getData() + i) + // <<"\t"<<*(B_fp32.getData() + i)<< std::endl; + // } + + apply_rotary_emb_cl(A_fp32, dim, from, max_timestep, rc); + apply_rotary_emb_tensor(B_fp32, dim, from, max_timestep); + + float mseErrorNeon_fp32 = + mse(A_fp32.getData(), B_fp32.getData(), A_fp32.size()); + + double cosSimNeon_fp32 = cosine_similarity( + A_fp32.getData(), B_fp32.getData(), A_fp32.size()); + + const float epsilon = 1e-3 * width; + + EXPECT_IN_RANGE(mseErrorNeon_fp32, 0, epsilon); + EXPECT_IN_RANGE((float)cosSimNeon_fp32, 0.99, 1); +} + +TEST(attention_kernels, rotary_emb_kernel_FP32_case2) { + RunLayerContext rc = setUpGpuContext(); + + int batch = 4; + int channel = 4; + int height = 8; + int width = 8; + + unsigned int dim = 2; + unsigned int from = 2; + unsigned int max_timestep = 4; + + const float alpha = 1e-1; + const int MOD = 10; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor B_fp32(batch, channel, height, width, t_type_nchw_fp32); + + GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + + j * (batch * height) + k * (width) + l + 1) % + MOD) * + alpha); + + B_fp32.copy(A_fp32); + + apply_rotary_emb_cl(A_fp32, dim, from, max_timestep, rc); + apply_rotary_emb_tensor(B_fp32, dim, from, max_timestep); + + float mseErrorNeon_fp32 = + mse(A_fp32.getData(), B_fp32.getData(), A_fp32.size()); + + double cosSimNeon_fp32 = cosine_similarity( + A_fp32.getData(), B_fp32.getData(), A_fp32.size()); + + const float epsilon = 1e-3 * width; + + EXPECT_IN_RANGE(mseErrorNeon_fp32, 0, epsilon); + EXPECT_IN_RANGE((float)cosSimNeon_fp32, 0.99, 1); +} + +TEST(attention_kernels, rotary_emb_kernel_FP16) { + RunLayerContext rc = setUpGpuContext(); + + int batch = 1; + int channel = 1; + int height = 4; + int width = 4; + + unsigned int dim = 2; + unsigned int from = 4; + unsigned int max_timestep = 4; + + const float alpha = 1e-1; + const int MOD = 10; + + nntrainer::TensorDim::TensorType t_type_nchw_fp16 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}; + + nntrainer::Tensor A_fp16(batch, channel, height, width, t_type_nchw_fp16); + nntrainer::Tensor B_fp16(batch, channel, height, width, t_type_nchw_fp16); + + GEN_TEST_INPUT(A_fp16, i * (batch * height * channel) * alpha + + j * (batch * height) * alpha + k * (width)*alpha + + l + 1); + + B_fp16.copy(A_fp16); + + apply_rotary_emb_cl(A_fp16, dim, from, max_timestep, rc); + apply_rotary_emb_tensor(B_fp16, dim, from, max_timestep); + + float mseErrorNeon_fp16 = mse<__fp16>( + A_fp16.getData<__fp16>(), B_fp16.getData<__fp16>(), A_fp16.size()); + + double cosSimNeon_fp16 = cosine_similarity<__fp16>( + A_fp16.getData<__fp16>(), B_fp16.getData<__fp16>(), A_fp16.size()); + + const float epsilon = 1e-3 * width; + + EXPECT_IN_RANGE(mseErrorNeon_fp16, 0, epsilon); + EXPECT_IN_RANGE((float)cosSimNeon_fp16, 0.99, 1); +} + +TEST(attention_kernels, rotary_emb_kernel_FP16_case2) { + RunLayerContext rc = setUpGpuContext(); + + int batch = 4; + int channel = 4; + int height = 8; + int width = 8; + + unsigned int dim = 4; + unsigned int from = 4; + unsigned int max_timestep = 8; + + const float alpha = 1e-1; + const int MOD = 10; + + nntrainer::TensorDim::TensorType t_type_nchw_fp16 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}; + + nntrainer::Tensor A_fp16(batch, channel, height, width, t_type_nchw_fp16); + nntrainer::Tensor B_fp16(batch, channel, height, width, t_type_nchw_fp16); + + GEN_TEST_INPUT(A_fp16, i * (batch * height * channel) * alpha + + j * (batch * height) * alpha + k * (width)*alpha + + l + 1); + + B_fp16.copy(A_fp16); + + apply_rotary_emb_cl(A_fp16, dim, from, max_timestep, rc); + apply_rotary_emb_tensor(B_fp16, dim, from, max_timestep); + + float mseErrorNeon_fp16 = mse<__fp16>( + A_fp16.getData<__fp16>(), B_fp16.getData<__fp16>(), A_fp16.size()); + + double cosSimNeon_fp16 = cosine_similarity<__fp16>( + A_fp16.getData<__fp16>(), B_fp16.getData<__fp16>(), A_fp16.size()); + + const float epsilon = 1e-3 * width; + + EXPECT_IN_RANGE(mseErrorNeon_fp16, 0, epsilon); + EXPECT_IN_RANGE((float)cosSimNeon_fp16, 0.99, 1); +} + +GTEST_API_ int main(int argc, char **argv) { + int result = -1; + + try { + testing::InitGoogleTest(&argc, argv); + } catch (...) { + std::cerr << "Error during InitGoogleTest" << std::endl; + return 0; + } + + try { + result = RUN_ALL_TESTS(); + } catch (...) { + std::cerr << "Error during RUN_ALL_TESTS()" << std::endl; + } + + return result; +} From a23e29dd43b59df7e9dfdb5e3293b8d3530caf5c Mon Sep 17 00:00:00 2001 From: Yash Singh Date: Wed, 28 Aug 2024 17:59:58 +0530 Subject: [PATCH 2/6] [Trivial] New line at the end Added newline at the end in new files. Signed-off-by: Yash Singh --- nntrainer/tensor/cl_operations/attention_kernel_interface.cpp | 2 +- nntrainer/tensor/cl_operations/attention_kernel_interface.h | 2 +- nntrainer/tensor/cl_operations/attention_kernels.cpp | 2 +- nntrainer/tensor/cl_operations/attention_kernels.h | 2 +- nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp | 2 +- nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp b/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp index cf28840176..155127f472 100644 --- a/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp +++ b/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp @@ -133,4 +133,4 @@ void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from, in.copy(out); } -} // namespace nntrainer \ No newline at end of file +} // namespace nntrainer diff --git a/nntrainer/tensor/cl_operations/attention_kernel_interface.h b/nntrainer/tensor/cl_operations/attention_kernel_interface.h index 878561bdd9..b287cb0a47 100644 --- a/nntrainer/tensor/cl_operations/attention_kernel_interface.h +++ b/nntrainer/tensor/cl_operations/attention_kernel_interface.h @@ -31,4 +31,4 @@ void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from, unsigned int max_timestep, RunLayerContext &context); } // namespace nntrainer -#endif /* __ATTENTION_KERNEL_INTERFACE_H__ */ \ No newline at end of file +#endif /* __ATTENTION_KERNEL_INTERFACE_H__ */ diff --git a/nntrainer/tensor/cl_operations/attention_kernels.cpp b/nntrainer/tensor/cl_operations/attention_kernels.cpp index 355bb8ec65..387ceb50a7 100644 --- a/nntrainer/tensor/cl_operations/attention_kernels.cpp +++ b/nntrainer/tensor/cl_operations/attention_kernels.cpp @@ -270,4 +270,4 @@ void rotary_emb_cl(float *in, float *out, } while (false); } -} // namespace nntrainer \ No newline at end of file +} // namespace nntrainer diff --git a/nntrainer/tensor/cl_operations/attention_kernels.h b/nntrainer/tensor/cl_operations/attention_kernels.h index 432b2322e1..97e2a98cea 100644 --- a/nntrainer/tensor/cl_operations/attention_kernels.h +++ b/nntrainer/tensor/cl_operations/attention_kernels.h @@ -93,4 +93,4 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out, #endif } // namespace nntrainer -#endif /* __ATTENTION_KERNELS_H__ */ \ No newline at end of file +#endif /* __ATTENTION_KERNELS_H__ */ diff --git a/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp index b5d0ca5000..16640821b7 100644 --- a/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp +++ b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp @@ -276,4 +276,4 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out, } while (false); } -} // namespace nntrainer \ No newline at end of file +} // namespace nntrainer diff --git a/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp b/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp index c13cbc0148..3bf8f8ab61 100644 --- a/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp +++ b/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp @@ -192,4 +192,4 @@ void apply_rotary_emb_tensor(nntrainer::Tensor &in, unsigned int dim, } in.copy(out); -} \ No newline at end of file +} From 473231ca1fcdaf543d54f16d26a7cebfa713c3fa Mon Sep 17 00:00:00 2001 From: Yash Singh Date: Thu, 29 Aug 2024 12:04:51 +0530 Subject: [PATCH 3/6] [Trivial] Unnecessary comments Removed Comments removed from the code. Signed-off-by: Yash Singh --- .../tensor/cl_operations/attention_kernels.cpp | 1 - .../cl_operations/attention_kernels_fp16.cpp | 4 ++-- .../tensor/cl_operations/testing_rotarty_emb.cpp | 14 -------------- test/unittest/unittest_attention_kernels_cl.cpp | 6 ------ 4 files changed, 2 insertions(+), 23 deletions(-) diff --git a/nntrainer/tensor/cl_operations/attention_kernels.cpp b/nntrainer/tensor/cl_operations/attention_kernels.cpp index 387ceb50a7..9b5cb7e699 100644 --- a/nntrainer/tensor/cl_operations/attention_kernels.cpp +++ b/nntrainer/tensor/cl_operations/attention_kernels.cpp @@ -59,7 +59,6 @@ __kernel void rotary_emb_cl(__global float *input, transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_]; } value = value * cos_ptr[k] + transformed_value * sin_ptr[k]; - // printf("GPU Batch: %u, Height: %u, Channel: %u, Width: %u, K: %u, Span: %u, Value: %f, Transformed Value: %f, cos_ptr[k]: %f, sin_ptr[k]: %f\n", b, h, c, w, k, span, value, transformed_value, cos_ptr[k], sin_ptr[k]); output[b * channel * height * width + c * height * width + h * width + span] = value; } } diff --git a/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp index 16640821b7..c6b1fbb263 100644 --- a/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp +++ b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp @@ -144,14 +144,14 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out, result = freqs_cosBuf.WriteData(context.command_queue_inst_, freqs_cos_flat.data()); if (!result) { - printf("Failed to write cos data\n"); + printf("Failed to write freqs cos data\n"); break; } result = freqs_sinBuf.WriteData(context.command_queue_inst_, freqs_sin_flat.data()); if (!result) { - printf("Failed to write sin data\n"); + printf("Failed to write freqs sin data\n"); break; } diff --git a/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp b/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp index 3bf8f8ab61..4ebde87332 100644 --- a/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp +++ b/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp @@ -42,10 +42,6 @@ void precompute_freqs(int dim, unsigned int seq_len, sin.assign(seq_len, std::vector(dim, 0)); for (unsigned int i = 0; i < seq_len; ++i) { -#ifdef USE_NEON - nntrainer::calc_trigonometric_vals_dup(half_, freqs.data(), cos[i].data(), - sin[i].data(), i); -#else for (unsigned int j = 0; j < half_; ++j) { float angle = i * freqs[j]; cos[i][j] = std::cos(angle); @@ -54,7 +50,6 @@ void precompute_freqs(int dim, unsigned int seq_len, sin[i][j] = std::sin(angle); sin[i][j + half_] = std::sin(angle); // repeated 2 times } -#endif } freqs_cos = cos; freqs_sin = sin; @@ -87,10 +82,6 @@ void apply_rotary_emb_tensor(nntrainer::Tensor &in, unsigned int dim, if (from >= max_timestep) { cos_ = std::vector(dim); sin_ = std::vector(dim); -#ifdef USE_NEON - nntrainer::calc_trigonometric_vals_dup(half_, freqs.data(), cos_.data(), - sin_.data(), from); -#else for (unsigned int i = 0; i < half_; ++i) { float angle = from * freqs[i]; cos_[i] = std::cos(angle); @@ -99,7 +90,6 @@ void apply_rotary_emb_tensor(nntrainer::Tensor &in, unsigned int dim, sin_[i] = std::sin(angle); sin_[i + half_] = std::sin(angle); // repeated 2 times } -#endif } else { cos_.resize(max_timestep); sin_.resize(max_timestep); @@ -133,10 +123,6 @@ void apply_rotary_emb_tensor(nntrainer::Tensor &in, unsigned int dim, transformed_value = in.getValue(b, c, h, span - half_); } value = value * cos_[k] + transformed_value * sin_[k]; - // printf("CPU Batch: %u, Channel: %u, Height: %u, Width: %u, K: - // %u, Span: %u, Value: %f, Transformed Value: %f, cos_ptr[k]: - // %f, sin_ptr[k]: %f\n ", b, c, h, w, k, span, value, - // transformed_value, cos_[k], sin_[k]); out.setValue(b, c, h, span, value); } } diff --git a/test/unittest/unittest_attention_kernels_cl.cpp b/test/unittest/unittest_attention_kernels_cl.cpp index 7a09e5cd54..d2a26cc9d3 100644 --- a/test/unittest/unittest_attention_kernels_cl.cpp +++ b/test/unittest/unittest_attention_kernels_cl.cpp @@ -65,12 +65,6 @@ TEST(attention_kernels, rotary_emb_kernel_FP32) { B_fp32.copy(A_fp32); - // std::cout << "\nA_fp32 and B_fp32 before rotary embedding:" << std::endl; - // for (unsigned int i = 0; i < A_fp32.size(); ++i) { - // std::cout << "Element " << i << " -> " << *(A_fp32.getData() + i) - // <<"\t"<<*(B_fp32.getData() + i)<< std::endl; - // } - apply_rotary_emb_cl(A_fp32, dim, from, max_timestep, rc); apply_rotary_emb_tensor(B_fp32, dim, from, max_timestep); From 46e17c1269444abe85e767680d10a771a3a36a09 Mon Sep 17 00:00:00 2001 From: Yash Singh Date: Tue, 3 Sep 2024 17:09:35 +0530 Subject: [PATCH 4/6] [GPU/OpenCl] Kernel optimization Kernel Optimized for GPU. Some trivial changes in code. Signed-off-by: Yash Singh --- .../attention_kernel_interface.cpp | 29 ++++++----- .../cl_operations/attention_kernels.cpp | 49 +++++++++---------- .../cl_operations/attention_kernels_fp16.cpp | 48 +++++++++--------- .../cl_operations/testing_rotarty_emb.cpp | 31 ++++++------ 4 files changed, 79 insertions(+), 78 deletions(-) diff --git a/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp b/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp index 155127f472..85c3331edd 100644 --- a/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp +++ b/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp @@ -24,7 +24,7 @@ namespace nntrainer { * @param[out] freqs base frequencies array to be used in the future computation * @param[in] theta rotary angle */ -void precompute_freqs(int dim, unsigned int seq_len, +void precompute_freqs(unsigned int dim, unsigned int seq_len, std::vector> &freqs_cos, std::vector> &freqs_sin, std::vector &freqs, float theta = 10000.0) { @@ -33,24 +33,24 @@ void precompute_freqs(int dim, unsigned int seq_len, freqs.push_back(1.0 / (std::pow(theta, (2 * i) / static_cast(dim)))); } - auto cos = std::vector>(); - cos.assign(seq_len, std::vector(dim, 0)); + auto cos_vec = std::vector>(); + cos_vec.assign(seq_len, std::vector(dim, 0)); - auto sin = std::vector>(); - sin.assign(seq_len, std::vector(dim, 0)); + auto sin_vec = std::vector>(); + sin_vec.assign(seq_len, std::vector(dim, 0)); for (unsigned int i = 0; i < seq_len; ++i) { for (unsigned int j = 0; j < half_; ++j) { float angle = i * freqs[j]; - cos[i][j] = std::cos(angle); - cos[i][j + half_] = std::cos(angle); // repeated 2 times + cos_vec[i][j] = std::cos(angle); + cos_vec[i][j + half_] = std::cos(angle); // repeated 2 times - sin[i][j] = std::sin(angle); - sin[i][j + half_] = std::sin(angle); // repeated 2 times + sin_vec[i][j] = std::sin(angle); + sin_vec[i][j + half_] = std::sin(angle); // repeated 2 times } } - freqs_cos = cos; - freqs_sin = sin; + freqs_cos = cos_vec; + freqs_sin = sin_vec; } /** @@ -59,12 +59,15 @@ void precompute_freqs(int dim, unsigned int seq_len, * @param[in] dim hidden dim size * @param[in] from sequence order * @param[in] max_timestep maximum timestep + * @param[in] context layer context to get the resource manager and queue id + * + * @todo Calling precompute_freqs in finalize to reduce code redundancy. */ void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from, unsigned int max_timestep, RunLayerContext &context) { nntrainer::Tensor out(in.getDim()); - float value = 0; - float transformed_value = 0.0; + float value = 0.0f; + float transformed_value = 0.0f; unsigned int half_ = dim / 2; std::vector> freqs_cos = {}; diff --git a/nntrainer/tensor/cl_operations/attention_kernels.cpp b/nntrainer/tensor/cl_operations/attention_kernels.cpp index 9b5cb7e699..5fd646b7c1 100644 --- a/nntrainer/tensor/cl_operations/attention_kernels.cpp +++ b/nntrainer/tensor/cl_operations/attention_kernels.cpp @@ -30,37 +30,36 @@ __kernel void rotary_emb_cl(__global float *input, unsigned int half_, unsigned int max_timestep, unsigned int from) { - unsigned int gid = get_global_id(0); - unsigned int gws = get_global_size(0); - __global float *cos_ptr = cos_; __global float *sin_ptr = sin_; float value = 0.0f; float transformed_value = 0.0f; - for (unsigned int b = 0; b < batch; b++) { - for (unsigned int c = 0; c < channel; c++) { - for (unsigned int h = 0; h < height; h++) { - if (from + h < max_timestep) { - unsigned idx = (from + h)*dim; - for(unsigned int i = idx; i < idx + dim; i++){ - cos_ptr[i - idx] = freqs_cos[i]; - sin_ptr[i - idx] = freqs_sin[i]; - } + unsigned int b = get_global_id(0); + unsigned int c = get_global_id(1); + + if(b < batch && c < channel){ + for (unsigned int h = 0; h < height; h++) { + if (from + h < max_timestep) { + unsigned idx = (from + h)*dim; + for(unsigned int i = idx; i < idx + dim; i++){ + cos_ptr[i - idx] = freqs_cos[i]; + sin_ptr[i - idx] = freqs_sin[i]; } - for (unsigned int w = 0; w < width; w = w + dim) { - for (unsigned int k = 0; k < dim; k++) { - unsigned int span = w + k; - value = input[b * channel * height * width + c * height * width + h * width + span]; - if (k < half_) { - transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_]; - } else { - transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_]; - } - value = value * cos_ptr[k] + transformed_value * sin_ptr[k]; - output[b * channel * height * width + c * height * width + h * width + span] = value; + } + + for (unsigned int w = 0; w < width; w = w + dim) { + for (unsigned int k = 0; k < dim; k++) { + unsigned int span = w + k; + value = input[b * channel * height * width + c * height * width + h * width + span]; + if (k < half_) { + transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_]; + } else { + transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_]; } + value = value * cos_ptr[k] + transformed_value * sin_ptr[k]; + output[b * channel * height * width + c * height * width + h * width + span] = value; } } } @@ -252,8 +251,8 @@ void rotary_emb_cl(float *in, float *out, break; } - const int work_groups_count[3] = {1, 1, 1}; - const int work_group_size[3] = {32, 1, 1}; // test-value + const int work_groups_count[3] = {(int)batch, (int)channel, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value result = context.command_queue_inst_.DispatchCommand( kernel_rotary_emb, work_groups_count, work_group_size); if (!result) { diff --git a/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp index c6b1fbb263..7c2c995020 100644 --- a/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp +++ b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp @@ -30,38 +30,36 @@ __kernel void rotary_emb_cl_fp16(__global half *input, unsigned int half_, unsigned int max_timestep, unsigned int from) { - unsigned int gid = get_global_id(0); - unsigned int gws = get_global_size(0); - __global float *cos_ptr = cos_; __global float *sin_ptr = sin_; float value = 0.0f; float transformed_value = 0.0f; - for (unsigned int b = 0; b < batch; b++) { - for (unsigned int c = 0; c < channel; c++) { - for (unsigned int h = 0; h < height; h++) { - if (from + h < max_timestep) { - unsigned idx = (from + h)*dim; - for(int i = idx; i < idx + dim; i++ ){ - cos_ptr[i - idx] = freqs_cos[i]; - sin_ptr[i - idx] = freqs_sin[i]; - } + unsigned int b = get_global_id(0); + unsigned int c = get_global_id(1); + + if(b < batch && c < channel){ + for (unsigned int h = 0; h < height; h++) { + if (from + h < max_timestep) { + unsigned idx = (from + h)*dim; + for(int i = idx; i < idx + dim; i++ ){ + cos_ptr[i - idx] = freqs_cos[i]; + sin_ptr[i - idx] = freqs_sin[i]; } + } - for (unsigned int w = 0; w < width; w = w + dim) { - for (unsigned int k = 0; k < dim; k++) { - unsigned int span = w + k; - value = (float)input[b * channel * height * width + c * height * width + h * width + span]; - if (k < half_) { - transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_]; - } else { - transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_]; - } - value = value * cos_ptr[k] + transformed_value * sin_ptr[k]; - output[b * channel * height * width + c * height * width + h * width + span] = (half)value; + for (unsigned int w = 0; w < width; w = w + dim) { + for (unsigned int k = 0; k < dim; k++) { + unsigned int span = w + k; + value = (float)input[b * channel * height * width + c * height * width + h * width + span]; + if (k < half_) { + transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_]; + } else { + transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_]; } + value = value * cos_ptr[k] + transformed_value * sin_ptr[k]; + output[b * channel * height * width + c * height * width + h * width + span] = (half)value; } } } @@ -259,8 +257,8 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out, break; } - const int work_groups_count[3] = {1, 1, 1}; - const int work_group_size[3] = {32, 1, 1}; // test-value + const int work_groups_count[3] = {(int)batch, (int)channel, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value result = context.command_queue_inst_.DispatchCommand( kernel_rotary_emb_fp16, work_groups_count, work_group_size); if (!result) { diff --git a/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp b/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp index 4ebde87332..d7bab6cc49 100644 --- a/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp +++ b/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp @@ -15,7 +15,8 @@ #include /** - * @brief compute frequency for rotary embedding + * @brief Testing code for CPU results and compute frequency for rotary + * embedding * @param[in] dim hidden dim size * @param[in] seq_len sequency length * @param[out] freqs_cos cosine of the frequencies @@ -24,7 +25,7 @@ * sin values for each position in sequence * @param[in] theta rotary angle */ -void precompute_freqs(int dim, unsigned int seq_len, +void precompute_freqs(unsigned int dim, unsigned int seq_len, std::vector> &freqs_cos, std::vector> &freqs_sin, std::vector &freqs, float theta = 10000.0) { @@ -35,29 +36,29 @@ void precompute_freqs(int dim, unsigned int seq_len, (std::pow(theta, (2 * i) / static_cast(dim)))); } - auto cos = std::vector>(); - cos.assign(seq_len, std::vector(dim, 0)); + auto cos_vec = std::vector>(); + cos_vec.assign(seq_len, std::vector(dim, 0)); - auto sin = std::vector>(); - sin.assign(seq_len, std::vector(dim, 0)); + auto sin_vec = std::vector>(); + sin_vec.assign(seq_len, std::vector(dim, 0)); for (unsigned int i = 0; i < seq_len; ++i) { for (unsigned int j = 0; j < half_; ++j) { float angle = i * freqs[j]; - cos[i][j] = std::cos(angle); - cos[i][j + half_] = std::cos(angle); // repeated 2 times + cos_vec[i][j] = std::cos(angle); + cos_vec[i][j + half_] = std::cos(angle); // repeated 2 times - sin[i][j] = std::sin(angle); - sin[i][j + half_] = std::sin(angle); // repeated 2 times + sin_vec[i][j] = std::sin(angle); + sin_vec[i][j + half_] = std::sin(angle); // repeated 2 times } } - freqs_cos = cos; - freqs_sin = sin; + freqs_cos = cos_vec; + freqs_sin = sin_vec; } } /** - * @brief apply rotary embedding + * @brief Testing code for CPU results and apply rotary embedding * @param[in] in input tensor * @param[in] dim hidden dim size * @param[in] from sequence order @@ -66,8 +67,8 @@ void precompute_freqs(int dim, unsigned int seq_len, void apply_rotary_emb_tensor(nntrainer::Tensor &in, unsigned int dim, unsigned int from, unsigned int max_timestep) { nntrainer::Tensor out(in.getDim()); - float value = 0; - float transformed_value = 0.0; + float value = 0.0f; + float transformed_value = 0.0f; unsigned int half_ = dim / 2; std::vector> freqs_cos = {}; From cc86c88445348986618c9d958004718b0c5762b2 Mon Sep 17 00:00:00 2001 From: Yash Singh Date: Tue, 8 Oct 2024 12:43:17 +0530 Subject: [PATCH 5/6] [GPU/Enhance] Registering Attention kernels and removind cl_context dependency Added registerCLKernel function to register custom OpenCL kernels as well as in-house kernels. Modified attention kernels to remove cl_context related dependencies. Added initAttentionCLKernels function to register default attention kernels. Modified unittest to remove layer_context dependency attention_kernel_strings.h added to handle attention kernels at one place. Rebased the PR with current log. Signed-off-by: Yash Singh --- nntrainer/cl_context.cpp | 16 ++ nntrainer/cl_context.h | 8 + .../attention_kernel_interface.cpp | 7 +- .../attention_kernel_interface.h | 5 +- .../cl_operations/attention_kernel_strings.h | 133 ++++++++++++++++ .../cl_operations/attention_kernels.cpp | 139 ++++++----------- .../tensor/cl_operations/attention_kernels.h | 20 +-- .../cl_operations/attention_kernels_fp16.cpp | 146 ++++++------------ nntrainer/tensor/cl_operations/meson.build | 1 + .../unittest_attention_kernels_cl.cpp | 23 ++- 10 files changed, 272 insertions(+), 226 deletions(-) create mode 100644 nntrainer/tensor/cl_operations/attention_kernel_strings.h diff --git a/nntrainer/cl_context.cpp b/nntrainer/cl_context.cpp index 5ecf80f838..10e3ecdbb7 100644 --- a/nntrainer/cl_context.cpp +++ b/nntrainer/cl_context.cpp @@ -15,6 +15,7 @@ */ #include +#include #include #include #include @@ -149,6 +150,21 @@ void ClContext::initBlasClKernels() { blas_kernels_initialized = true; } +void ClContext::initAttentionClKernels() { + if (attention_kernels_initialized) { + ml_logi("ClContext: Default attention kernels already registered and " + "initialized"); + return; + } + + registerClKernel(rotary_emb_cl_kernel_, "rotary_emb_cl"); + +#ifdef ENABLE_FP16 + registerClKernel(rotary_emb_cl_kernel_fp16_, "rotary_emb_cl_fp16"); +#endif + attention_kernels_initialized = true; +} + const ClContext::SharedPtrClKernel ClContext::registerClKernel(std::string kernel_string, std::string kernel_name) { diff --git a/nntrainer/cl_context.h b/nntrainer/cl_context.h index 7683453221..025365546b 100644 --- a/nntrainer/cl_context.h +++ b/nntrainer/cl_context.h @@ -211,6 +211,11 @@ class ClContext { */ void initBlasClKernels(); + /** + * @brief Initialize and register all attention OpenCl kernels + */ + void initAttentionClKernels(); + /** * @brief destructor to release opencl commandQueue */ @@ -229,6 +234,9 @@ class ClContext { // flag to check default blas kernels registered or not bool blas_kernels_initialized = false; + // flag to check default attention kernels registered or not + bool attention_kernels_initialized = false; + FactoryMap factory_map; template struct isSupportedHelper; diff --git a/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp b/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp index 85c3331edd..658e2a3d91 100644 --- a/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp +++ b/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp @@ -59,12 +59,11 @@ void precompute_freqs(unsigned int dim, unsigned int seq_len, * @param[in] dim hidden dim size * @param[in] from sequence order * @param[in] max_timestep maximum timestep - * @param[in] context layer context to get the resource manager and queue id * * @todo Calling precompute_freqs in finalize to reduce code redundancy. */ void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from, - unsigned int max_timestep, RunLayerContext &context) { + unsigned int max_timestep) { nntrainer::Tensor out(in.getDim()); float value = 0.0f; float transformed_value = 0.0f; @@ -111,7 +110,7 @@ void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from, rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_, input_batch_size, input_channels, input_height, input_width, - dim, from, max_timestep, in_size, out_size, context); + dim, from, max_timestep, in_size, out_size); } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) { #ifdef ENABLE_FP16 @@ -123,7 +122,7 @@ void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from, rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_, input_batch_size, input_channels, input_height, input_width, - dim, from, max_timestep, in_size, out_size, context); + dim, from, max_timestep, in_size, out_size); #else throw std::invalid_argument("Error: enable-fp16 is not enabled"); #endif diff --git a/nntrainer/tensor/cl_operations/attention_kernel_interface.h b/nntrainer/tensor/cl_operations/attention_kernel_interface.h index b287cb0a47..fe9c0f8b0c 100644 --- a/nntrainer/tensor/cl_operations/attention_kernel_interface.h +++ b/nntrainer/tensor/cl_operations/attention_kernel_interface.h @@ -14,8 +14,8 @@ #ifndef __ATTENTION_KERNEL_INTERFACE_H__ #define __ATTENTION_KERNEL_INTERFACE_H__ -#include #include +#include namespace nntrainer { @@ -25,10 +25,9 @@ namespace nntrainer { * @param[in] dim hidden dim size * @param[in] from sequence order * @param[in] max_timestep maximum timestep - * @param[in] context layer context to get the resource manager and queue id */ void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from, - unsigned int max_timestep, RunLayerContext &context); + unsigned int max_timestep); } // namespace nntrainer #endif /* __ATTENTION_KERNEL_INTERFACE_H__ */ diff --git a/nntrainer/tensor/cl_operations/attention_kernel_strings.h b/nntrainer/tensor/cl_operations/attention_kernel_strings.h new file mode 100644 index 0000000000..d58fd75035 --- /dev/null +++ b/nntrainer/tensor/cl_operations/attention_kernel_strings.h @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh + * + * @file attention_kernel_strings.h + * @date 8 October 2024 + * @brief All attention OpenCL kernel strings + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#ifndef __ATTENTION_KERNEL_STRINGS_H__ +#define __ATTENTION_KERNEL_STRINGS_H__ + +#include + +namespace nntrainer { +static const std::string rotary_emb_cl_kernel_ = R"( + + #pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void rotary_emb_cl(__global float *input, + __global float *output, + __global float *freqs_cos, + __global float *freqs_sin, + __global float *cos_, + __global float *sin_, + unsigned int batch, + unsigned int channel, + unsigned int height, + unsigned int width, + unsigned int dim, + unsigned int half_, + unsigned int max_timestep, + unsigned int from) { + __global float *cos_ptr = cos_; + __global float *sin_ptr = sin_; + + float value = 0.0f; + float transformed_value = 0.0f; + + unsigned int b = get_global_id(0); + unsigned int c = get_global_id(1); + + if(b < batch && c < channel){ + for (unsigned int h = 0; h < height; h++) { + if (from + h < max_timestep) { + unsigned idx = (from + h)*dim; + for(unsigned int i = idx; i < idx + dim; i++){ + cos_ptr[i - idx] = freqs_cos[i]; + sin_ptr[i - idx] = freqs_sin[i]; + } + } + + for (unsigned int w = 0; w < width; w = w + dim) { + for (unsigned int k = 0; k < dim; k++) { + unsigned int span = w + k; + value = input[b * channel * height * width + c * height * width + h * width + span]; + if (k < half_) { + transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_]; + } else { + transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_]; + } + value = value * cos_ptr[k] + transformed_value * sin_ptr[k]; + output[b * channel * height * width + c * height * width + h * width + span] = value; + } + } + } + } +} +)"; + +#ifdef ENABLE_FP16 +static const std::string rotary_emb_cl_kernel_fp16_ = R"( + + #pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void rotary_emb_cl_fp16(__global half *input, + __global half *output, + __global float *freqs_cos, + __global float *freqs_sin, + __global float *cos_, + __global float *sin_, + unsigned int batch, + unsigned int channel, + unsigned int height, + unsigned int width, + unsigned int dim, + unsigned int half_, + unsigned int max_timestep, + unsigned int from) { + __global float *cos_ptr = cos_; + __global float *sin_ptr = sin_; + + float value = 0.0f; + float transformed_value = 0.0f; + + unsigned int b = get_global_id(0); + unsigned int c = get_global_id(1); + + if(b < batch && c < channel){ + for (unsigned int h = 0; h < height; h++) { + if (from + h < max_timestep) { + unsigned idx = (from + h)*dim; + for(int i = idx; i < idx + dim; i++ ){ + cos_ptr[i - idx] = freqs_cos[i]; + sin_ptr[i - idx] = freqs_sin[i]; + } + } + + for (unsigned int w = 0; w < width; w = w + dim) { + for (unsigned int k = 0; k < dim; k++) { + unsigned int span = w + k; + value = (float)input[b * channel * height * width + c * height * width + h * width + span]; + if (k < half_) { + transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_]; + } else { + transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_]; + } + value = value * cos_ptr[k] + transformed_value * sin_ptr[k]; + output[b * channel * height * width + c * height * width + h * width + span] = (half)value; + } + } + } + } +} +)"; + +#endif +} // namespace nntrainer +#endif /* __ATTENTION_KERNEL_INTERFACE_H__ */ diff --git a/nntrainer/tensor/cl_operations/attention_kernels.cpp b/nntrainer/tensor/cl_operations/attention_kernels.cpp index 5fd646b7c1..388cc0805f 100644 --- a/nntrainer/tensor/cl_operations/attention_kernels.cpp +++ b/nntrainer/tensor/cl_operations/attention_kernels.cpp @@ -11,66 +11,10 @@ * */ +#include #include namespace nntrainer { -std::string rotary_emb_cl_kernel = R"( - #pragma OPENCL EXTENSION cl_khr_fp16 : enable -__kernel void rotary_emb_cl(__global float *input, - __global float *output, - __global float *freqs_cos, - __global float *freqs_sin, - __global float *cos_, - __global float *sin_, - unsigned int batch, - unsigned int channel, - unsigned int height, - unsigned int width, - unsigned int dim, - unsigned int half_, - unsigned int max_timestep, - unsigned int from) { - __global float *cos_ptr = cos_; - __global float *sin_ptr = sin_; - - float value = 0.0f; - float transformed_value = 0.0f; - - unsigned int b = get_global_id(0); - unsigned int c = get_global_id(1); - - if(b < batch && c < channel){ - for (unsigned int h = 0; h < height; h++) { - if (from + h < max_timestep) { - unsigned idx = (from + h)*dim; - for(unsigned int i = idx; i < idx + dim; i++){ - cos_ptr[i - idx] = freqs_cos[i]; - sin_ptr[i - idx] = freqs_sin[i]; - } - } - - for (unsigned int w = 0; w < width; w = w + dim) { - for (unsigned int k = 0; k < dim; k++) { - unsigned int span = w + k; - value = input[b * channel * height * width + c * height * width + h * width + span]; - if (k < half_) { - transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_]; - } else { - transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_]; - } - value = value * cos_ptr[k] + transformed_value * sin_ptr[k]; - output[b * channel * height * width + c * height * width + h * width + span] = value; - } - } - } - } -} -)"; - -/** - * @brief defining global kernel objects - */ -opencl::Kernel kernel_rotary_emb; void rotary_emb_cl(float *in, float *out, std::vector> freqs_cos, @@ -79,17 +23,16 @@ void rotary_emb_cl(float *in, float *out, unsigned int batch, unsigned int channel, unsigned int height, unsigned int width, unsigned int dim, unsigned int from, unsigned int max_timestep, - unsigned int in_size, unsigned int out_size, - RunLayerContext &context) { + unsigned int in_size, unsigned int out_size) { bool result = false; do { - result = context.clCreateKernel( - rotary_emb_cl_kernel, context.LayerKernel::ROTARY_EMB, kernel_rotary_emb); - if (!result) { - printf("Failed to create kernel for rotary_emb_cl\n"); + ClContext::SharedPtrClKernel kernel_rotaryEmb_ptr = + cl_context_ref.registerClKernel(rotary_emb_cl_kernel_, "rotary_emb_cl"); + if (!kernel_rotaryEmb_ptr) { break; } + unsigned int cos_dim = cos_.size(); unsigned int sin_dim = sin_.size(); unsigned int freqs_cos_dim = freqs_cos.size(); @@ -103,18 +46,22 @@ void rotary_emb_cl(float *in, float *out, sizeof(float) * freqs_cos_dim * dim; // max_timestep * dim size_t dim6_size = sizeof(float) * freqs_sin_dim * dim; - opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr); + opencl::Buffer inputA(cl_context_ref.context_inst_, dim1_size, true, + nullptr); - opencl::Buffer inOutRes(context.context_inst_, dim2_size, true, nullptr); + opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim2_size, true, + nullptr); - opencl::Buffer cosBuf(context.context_inst_, dim3_size, true, nullptr); + opencl::Buffer cosBuf(cl_context_ref.context_inst_, dim3_size, true, + nullptr); - opencl::Buffer sinBuf(context.context_inst_, dim4_size, true, nullptr); + opencl::Buffer sinBuf(cl_context_ref.context_inst_, dim4_size, true, + nullptr); - opencl::Buffer freqs_cosBuf(context.context_inst_, dim5_size, true, + opencl::Buffer freqs_cosBuf(cl_context_ref.context_inst_, dim5_size, true, nullptr); - opencl::Buffer freqs_sinBuf(context.context_inst_, dim6_size, true, + opencl::Buffer freqs_sinBuf(cl_context_ref.context_inst_, dim6_size, true, nullptr); std::vector freqs_cos_flat; @@ -126,126 +73,130 @@ void rotary_emb_cl(float *in, float *out, freqs_sin_flat.insert(freqs_sin_flat.end(), row.begin(), row.end()); } - result = inputA.WriteData(context.command_queue_inst_, in); + result = inputA.WriteData(cl_context_ref.command_queue_inst_, in); if (!result) { printf("Failed to write input data\n"); break; } - result = inOutRes.WriteData(context.command_queue_inst_, out); + result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, out); if (!result) { printf("Failed to write output data\n"); break; } - result = freqs_cosBuf.WriteData(context.command_queue_inst_, + result = freqs_cosBuf.WriteData(cl_context_ref.command_queue_inst_, freqs_cos_flat.data()); if (!result) { printf("Failed to write freqs cos data\n"); break; } - result = freqs_sinBuf.WriteData(context.command_queue_inst_, + result = freqs_sinBuf.WriteData(cl_context_ref.command_queue_inst_, freqs_sin_flat.data()); if (!result) { printf("Failed to write freqs sin data\n"); break; } - result = cosBuf.WriteData(context.command_queue_inst_, cos_.data()); + result = cosBuf.WriteData(cl_context_ref.command_queue_inst_, cos_.data()); if (!result) { printf("Failed to write cos data\n"); break; } - result = sinBuf.WriteData(context.command_queue_inst_, sin_.data()); + result = sinBuf.WriteData(cl_context_ref.command_queue_inst_, sin_.data()); if (!result) { printf("Failed to write sin data\n"); break; } - result = kernel_rotary_emb.SetKernelArguments(0, &inputA, sizeof(cl_mem)); + result = + kernel_rotaryEmb_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem)); if (!result) { printf("Failed to set inputA argument\n"); break; } - result = kernel_rotary_emb.SetKernelArguments(1, &inOutRes, sizeof(cl_mem)); + result = + kernel_rotaryEmb_ptr->SetKernelArguments(1, &inOutRes, sizeof(cl_mem)); if (!result) { printf("Failed to set inOutRes argument\n"); break; } - result = - kernel_rotary_emb.SetKernelArguments(2, &freqs_cosBuf, sizeof(cl_mem)); + result = kernel_rotaryEmb_ptr->SetKernelArguments(2, &freqs_cosBuf, + sizeof(cl_mem)); if (!result) { printf("Failed to set freqs_cosBuf argument\n"); break; } - result = - kernel_rotary_emb.SetKernelArguments(3, &freqs_sinBuf, sizeof(cl_mem)); + result = kernel_rotaryEmb_ptr->SetKernelArguments(3, &freqs_sinBuf, + sizeof(cl_mem)); if (!result) { printf("Failed to set freqs_sinBuf argument\n"); break; } - result = kernel_rotary_emb.SetKernelArguments(4, &cosBuf, sizeof(cl_mem)); + result = + kernel_rotaryEmb_ptr->SetKernelArguments(4, &cosBuf, sizeof(cl_mem)); if (!result) { printf("Failed to set cosBuf argument\n"); break; } - result = kernel_rotary_emb.SetKernelArguments(5, &sinBuf, sizeof(cl_mem)); + result = + kernel_rotaryEmb_ptr->SetKernelArguments(5, &sinBuf, sizeof(cl_mem)); if (!result) { printf("Failed to set sinBuf argument\n"); break; } - result = kernel_rotary_emb.SetKernelArguments(6, &batch, sizeof(int)); + result = kernel_rotaryEmb_ptr->SetKernelArguments(6, &batch, sizeof(int)); if (!result) { printf("Failed to set batch argument\n"); break; } - result = kernel_rotary_emb.SetKernelArguments(7, &channel, sizeof(int)); + result = kernel_rotaryEmb_ptr->SetKernelArguments(7, &channel, sizeof(int)); if (!result) { printf("Failed to set channel argument\n"); break; } - result = kernel_rotary_emb.SetKernelArguments(8, &height, sizeof(int)); + result = kernel_rotaryEmb_ptr->SetKernelArguments(8, &height, sizeof(int)); if (!result) { printf("Failed to set height argument\n"); break; } - result = kernel_rotary_emb.SetKernelArguments(9, &width, sizeof(int)); + result = kernel_rotaryEmb_ptr->SetKernelArguments(9, &width, sizeof(int)); if (!result) { printf("Failed to set width argument\n"); break; } - result = kernel_rotary_emb.SetKernelArguments(10, &dim, sizeof(int)); + result = kernel_rotaryEmb_ptr->SetKernelArguments(10, &dim, sizeof(int)); if (!result) { printf("Failed to set dim argument\n"); break; } unsigned int half_ = dim / 2; - result = kernel_rotary_emb.SetKernelArguments(11, &half_, sizeof(int)); + result = kernel_rotaryEmb_ptr->SetKernelArguments(11, &half_, sizeof(int)); if (!result) { printf("Failed to set half argument\n"); break; } result = - kernel_rotary_emb.SetKernelArguments(12, &max_timestep, sizeof(int)); + kernel_rotaryEmb_ptr->SetKernelArguments(12, &max_timestep, sizeof(int)); if (!result) { printf("Failed to set timestamp argument\n"); break; } - result = kernel_rotary_emb.SetKernelArguments(13, &from, sizeof(int)); + result = kernel_rotaryEmb_ptr->SetKernelArguments(13, &from, sizeof(int)); if (!result) { printf("Failed to set from argument\n"); break; @@ -253,14 +204,14 @@ void rotary_emb_cl(float *in, float *out, const int work_groups_count[3] = {(int)batch, (int)channel, 1}; const int work_group_size[3] = {32, 32, 1}; // test-value - result = context.command_queue_inst_.DispatchCommand( - kernel_rotary_emb, work_groups_count, work_group_size); + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_rotaryEmb_ptr, work_groups_count, work_group_size); if (!result) { printf("Failed to dispatch command\n"); break; } - result = inOutRes.ReadData(context.command_queue_inst_, out); + result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, out); if (!result) { printf("Failed to read data\n"); break; diff --git a/nntrainer/tensor/cl_operations/attention_kernels.h b/nntrainer/tensor/cl_operations/attention_kernels.h index 97e2a98cea..37a3a4428a 100644 --- a/nntrainer/tensor/cl_operations/attention_kernels.h +++ b/nntrainer/tensor/cl_operations/attention_kernels.h @@ -14,17 +14,15 @@ #ifndef __ATTENTION_KERNELS_H__ #define __ATTENTION_KERNELS_H__ -#include +#include #include #include #include namespace nntrainer { -/** - * @brief declaring global kernel objects - */ -extern opencl::Kernel kernel_rotary_emb; +// get global cl_context to use in kernels +static ClContext cl_context_ref; /** * @brief Rotary Embedding process @@ -43,7 +41,6 @@ extern opencl::Kernel kernel_rotary_emb; * @param[in] max_timestep max timestep * @param[in] in_size size of input * @param[in] out_size size of output - * @param[in] context RunLayerContext reference */ void rotary_emb_cl(float *in, float *out, std::vector> freqs_cos, @@ -52,14 +49,9 @@ void rotary_emb_cl(float *in, float *out, unsigned int batch, unsigned int channel, unsigned int height, unsigned int width, unsigned int dim, unsigned int from, unsigned int max_timestamp, - unsigned int in_size, unsigned int out_size, - RunLayerContext &context); + unsigned int in_size, unsigned int out_size); #ifdef ENABLE_FP16 -/** - * @brief declaring global fp16 kernel objects - */ -extern opencl::Kernel kernel_rotary_emb_fp16; /** * @brief Rotary Embedding process @@ -78,7 +70,6 @@ extern opencl::Kernel kernel_rotary_emb_fp16; * @param[in] max_timestep max timestep * @param[in] in_size size of input * @param[in] out_size size of output - * @param[in] context RunLayerContext reference */ void rotary_emb_cl(__fp16 *in, __fp16 *out, std::vector> freqs_cos, @@ -87,8 +78,7 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out, unsigned int batch, unsigned int channel, unsigned int height, unsigned int width, unsigned int dim, unsigned int from, unsigned int max_timestamp, - unsigned int in_size, unsigned int out_size, - RunLayerContext &context); + unsigned int in_size, unsigned int out_size); #endif diff --git a/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp index 7c2c995020..c1284b0a9c 100644 --- a/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp +++ b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp @@ -11,66 +11,10 @@ * */ +#include #include namespace nntrainer { -std::string rotary_emb_cl_kernel_fp16 = R"( - #pragma OPENCL EXTENSION cl_khr_fp16 : enable -__kernel void rotary_emb_cl_fp16(__global half *input, - __global half *output, - __global float *freqs_cos, - __global float *freqs_sin, - __global float *cos_, - __global float *sin_, - unsigned int batch, - unsigned int channel, - unsigned int height, - unsigned int width, - unsigned int dim, - unsigned int half_, - unsigned int max_timestep, - unsigned int from) { - __global float *cos_ptr = cos_; - __global float *sin_ptr = sin_; - - float value = 0.0f; - float transformed_value = 0.0f; - - unsigned int b = get_global_id(0); - unsigned int c = get_global_id(1); - - if(b < batch && c < channel){ - for (unsigned int h = 0; h < height; h++) { - if (from + h < max_timestep) { - unsigned idx = (from + h)*dim; - for(int i = idx; i < idx + dim; i++ ){ - cos_ptr[i - idx] = freqs_cos[i]; - sin_ptr[i - idx] = freqs_sin[i]; - } - } - - for (unsigned int w = 0; w < width; w = w + dim) { - for (unsigned int k = 0; k < dim; k++) { - unsigned int span = w + k; - value = (float)input[b * channel * height * width + c * height * width + h * width + span]; - if (k < half_) { - transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_]; - } else { - transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_]; - } - value = value * cos_ptr[k] + transformed_value * sin_ptr[k]; - output[b * channel * height * width + c * height * width + h * width + span] = (half)value; - } - } - } - } -} -)"; - -/** - * @brief defining global kernel objects - */ -opencl::Kernel kernel_rotary_emb_fp16; void rotary_emb_cl(__fp16 *in, __fp16 *out, std::vector> freqs_cos, @@ -79,16 +23,14 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out, unsigned int batch, unsigned int channel, unsigned int height, unsigned int width, unsigned int dim, unsigned int from, unsigned int max_timestep, - unsigned int in_size, unsigned int out_size, - RunLayerContext &context) { + unsigned int in_size, unsigned int out_size) { bool result = false; do { - result = context.clCreateKernel(rotary_emb_cl_kernel_fp16, - context.LayerKernel::ROTARY_EMB_FP16, - kernel_rotary_emb_fp16); - if (!result) { - printf("Failed to create kernel for rotary_emb_cl\n"); + ClContext::SharedPtrClKernel kernel_rotaryEmb_fp16_ptr = + cl_context_ref.registerClKernel(rotary_emb_cl_kernel_fp16_, + "rotary_emb_cl_fp16"); + if (!kernel_rotaryEmb_fp16_ptr) { break; } @@ -104,18 +46,22 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out, size_t dim5_size = sizeof(float) * freqs_cos_dim * dim; size_t dim6_size = sizeof(float) * freqs_sin_dim * dim; - opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr); + opencl::Buffer inputA(cl_context_ref.context_inst_, dim1_size, true, + nullptr); - opencl::Buffer inOutRes(context.context_inst_, dim2_size, true, nullptr); + opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim2_size, true, + nullptr); - opencl::Buffer cosBuf(context.context_inst_, dim3_size, true, nullptr); + opencl::Buffer cosBuf(cl_context_ref.context_inst_, dim3_size, true, + nullptr); - opencl::Buffer sinBuf(context.context_inst_, dim4_size, true, nullptr); + opencl::Buffer sinBuf(cl_context_ref.context_inst_, dim4_size, true, + nullptr); - opencl::Buffer freqs_cosBuf(context.context_inst_, dim5_size, true, + opencl::Buffer freqs_cosBuf(cl_context_ref.context_inst_, dim5_size, true, nullptr); - opencl::Buffer freqs_sinBuf(context.context_inst_, dim6_size, true, + opencl::Buffer freqs_sinBuf(cl_context_ref.context_inst_, dim6_size, true, nullptr); std::vector freqs_cos_flat; @@ -127,131 +73,137 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out, freqs_sin_flat.insert(freqs_sin_flat.end(), row.begin(), row.end()); } - result = inputA.WriteData(context.command_queue_inst_, in); + result = inputA.WriteData(cl_context_ref.command_queue_inst_, in); if (!result) { printf("Failed to write input data\n"); break; } - result = inOutRes.WriteData(context.command_queue_inst_, out); + result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, out); if (!result) { printf("Failed to write output data\n"); break; } - result = freqs_cosBuf.WriteData(context.command_queue_inst_, + result = freqs_cosBuf.WriteData(cl_context_ref.command_queue_inst_, freqs_cos_flat.data()); if (!result) { printf("Failed to write freqs cos data\n"); break; } - result = freqs_sinBuf.WriteData(context.command_queue_inst_, + result = freqs_sinBuf.WriteData(cl_context_ref.command_queue_inst_, freqs_sin_flat.data()); if (!result) { printf("Failed to write freqs sin data\n"); break; } - result = cosBuf.WriteData(context.command_queue_inst_, cos_.data()); + result = cosBuf.WriteData(cl_context_ref.command_queue_inst_, cos_.data()); if (!result) { printf("Failed to write cos data\n"); break; } - result = sinBuf.WriteData(context.command_queue_inst_, sin_.data()); + result = sinBuf.WriteData(cl_context_ref.command_queue_inst_, sin_.data()); if (!result) { printf("Failed to write sin data\n"); break; } result = - kernel_rotary_emb_fp16.SetKernelArguments(0, &inputA, sizeof(cl_mem)); + kernel_rotaryEmb_fp16_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem)); if (!result) { printf("Failed to set inputA argument\n"); break; } - result = - kernel_rotary_emb_fp16.SetKernelArguments(1, &inOutRes, sizeof(cl_mem)); + result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(1, &inOutRes, + sizeof(cl_mem)); if (!result) { printf("Failed to set inOutRes argument\n"); break; } - result = kernel_rotary_emb_fp16.SetKernelArguments(2, &freqs_cosBuf, - sizeof(cl_mem)); + result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(2, &freqs_cosBuf, + sizeof(cl_mem)); if (!result) { printf("Failed to set freqs_cosBuf argument\n"); break; } - result = kernel_rotary_emb_fp16.SetKernelArguments(3, &freqs_sinBuf, - sizeof(cl_mem)); + result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(3, &freqs_sinBuf, + sizeof(cl_mem)); if (!result) { printf("Failed to set freqs_sinBuf argument\n"); break; } result = - kernel_rotary_emb_fp16.SetKernelArguments(4, &cosBuf, sizeof(cl_mem)); + kernel_rotaryEmb_fp16_ptr->SetKernelArguments(4, &cosBuf, sizeof(cl_mem)); if (!result) { printf("Failed to set cosBuf argument\n"); break; } result = - kernel_rotary_emb_fp16.SetKernelArguments(5, &sinBuf, sizeof(cl_mem)); + kernel_rotaryEmb_fp16_ptr->SetKernelArguments(5, &sinBuf, sizeof(cl_mem)); if (!result) { printf("Failed to set sinBuf argument\n"); break; } - result = kernel_rotary_emb_fp16.SetKernelArguments(6, &batch, sizeof(int)); + result = + kernel_rotaryEmb_fp16_ptr->SetKernelArguments(6, &batch, sizeof(int)); if (!result) { printf("Failed to set batch argument\n"); break; } result = - kernel_rotary_emb_fp16.SetKernelArguments(7, &channel, sizeof(int)); + kernel_rotaryEmb_fp16_ptr->SetKernelArguments(7, &channel, sizeof(int)); if (!result) { printf("Failed to set channel argument\n"); break; } - result = kernel_rotary_emb_fp16.SetKernelArguments(8, &height, sizeof(int)); + result = + kernel_rotaryEmb_fp16_ptr->SetKernelArguments(8, &height, sizeof(int)); if (!result) { printf("Failed to set height argument\n"); break; } - result = kernel_rotary_emb_fp16.SetKernelArguments(9, &width, sizeof(int)); + result = + kernel_rotaryEmb_fp16_ptr->SetKernelArguments(9, &width, sizeof(int)); if (!result) { printf("Failed to set width argument\n"); break; } - result = kernel_rotary_emb_fp16.SetKernelArguments(10, &dim, sizeof(int)); + result = + kernel_rotaryEmb_fp16_ptr->SetKernelArguments(10, &dim, sizeof(int)); if (!result) { printf("Failed to set dim argument\n"); break; } unsigned int half_ = dim / 2; - result = kernel_rotary_emb_fp16.SetKernelArguments(11, &half_, sizeof(int)); + result = + kernel_rotaryEmb_fp16_ptr->SetKernelArguments(11, &half_, sizeof(int)); if (!result) { printf("Failed to set half argument\n"); break; } - result = - kernel_rotary_emb_fp16.SetKernelArguments(12, &max_timestep, sizeof(int)); + result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(12, &max_timestep, + sizeof(int)); if (!result) { printf("Failed to set timestamp argument\n"); break; } - result = kernel_rotary_emb_fp16.SetKernelArguments(13, &from, sizeof(int)); + result = + kernel_rotaryEmb_fp16_ptr->SetKernelArguments(13, &from, sizeof(int)); if (!result) { printf("Failed to set from argument\n"); break; @@ -259,14 +211,14 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out, const int work_groups_count[3] = {(int)batch, (int)channel, 1}; const int work_group_size[3] = {32, 32, 1}; // test-value - result = context.command_queue_inst_.DispatchCommand( - kernel_rotary_emb_fp16, work_groups_count, work_group_size); + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_rotaryEmb_fp16_ptr, work_groups_count, work_group_size); if (!result) { printf("Failed to dispatch command\n"); break; } - result = inOutRes.ReadData(context.command_queue_inst_, out); + result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, out); if (!result) { printf("Failed to read data\n"); break; diff --git a/nntrainer/tensor/cl_operations/meson.build b/nntrainer/tensor/cl_operations/meson.build index 3f186ec645..a1b9b795bb 100644 --- a/nntrainer/tensor/cl_operations/meson.build +++ b/nntrainer/tensor/cl_operations/meson.build @@ -9,6 +9,7 @@ cl_op_headers = [ 'blas_kernel_interface.h', 'blas_kernel_strings.h', 'attention_kernel_interface.h', + 'attention_kernel_strings.h', ] if get_option('enable-fp16') diff --git a/test/unittest/unittest_attention_kernels_cl.cpp b/test/unittest/unittest_attention_kernels_cl.cpp index d2a26cc9d3..a95937446d 100644 --- a/test/unittest/unittest_attention_kernels_cl.cpp +++ b/test/unittest/unittest_attention_kernels_cl.cpp @@ -29,16 +29,13 @@ using namespace nntrainer; -static RunLayerContext setUpGpuContext() { - +static void setUpGpuContext() { auto &ac = nntrainer::ClContext::Global(); - auto rc = RunLayerContext(); - - return rc; + ac.initAttentionClKernels(); } TEST(attention_kernels, rotary_emb_kernel_FP32) { - RunLayerContext rc = setUpGpuContext(); + setUpGpuContext(); int batch = 1; int channel = 1; @@ -65,7 +62,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP32) { B_fp32.copy(A_fp32); - apply_rotary_emb_cl(A_fp32, dim, from, max_timestep, rc); + apply_rotary_emb_cl(A_fp32, dim, from, max_timestep); apply_rotary_emb_tensor(B_fp32, dim, from, max_timestep); float mseErrorNeon_fp32 = @@ -81,7 +78,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP32) { } TEST(attention_kernels, rotary_emb_kernel_FP32_case2) { - RunLayerContext rc = setUpGpuContext(); + setUpGpuContext(); int batch = 4; int channel = 4; @@ -108,7 +105,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP32_case2) { B_fp32.copy(A_fp32); - apply_rotary_emb_cl(A_fp32, dim, from, max_timestep, rc); + apply_rotary_emb_cl(A_fp32, dim, from, max_timestep); apply_rotary_emb_tensor(B_fp32, dim, from, max_timestep); float mseErrorNeon_fp32 = @@ -124,7 +121,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP32_case2) { } TEST(attention_kernels, rotary_emb_kernel_FP16) { - RunLayerContext rc = setUpGpuContext(); + setUpGpuContext(); int batch = 1; int channel = 1; @@ -150,7 +147,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP16) { B_fp16.copy(A_fp16); - apply_rotary_emb_cl(A_fp16, dim, from, max_timestep, rc); + apply_rotary_emb_cl(A_fp16, dim, from, max_timestep); apply_rotary_emb_tensor(B_fp16, dim, from, max_timestep); float mseErrorNeon_fp16 = mse<__fp16>( @@ -166,7 +163,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP16) { } TEST(attention_kernels, rotary_emb_kernel_FP16_case2) { - RunLayerContext rc = setUpGpuContext(); + setUpGpuContext(); int batch = 4; int channel = 4; @@ -192,7 +189,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP16_case2) { B_fp16.copy(A_fp16); - apply_rotary_emb_cl(A_fp16, dim, from, max_timestep, rc); + apply_rotary_emb_cl(A_fp16, dim, from, max_timestep); apply_rotary_emb_tensor(B_fp16, dim, from, max_timestep); float mseErrorNeon_fp16 = mse<__fp16>( From 2007b91a51fb7868858b3c37b05adcadff4b8d9c Mon Sep 17 00:00:00 2001 From: Yash Singh Date: Wed, 16 Oct 2024 12:43:01 +0530 Subject: [PATCH 6/6] [Trivial] Moved Testing Rotary Embedding to Unittest Dir Moved testing_rotary_emb.cpp to unittest Directory. Signed-off-by: Yash Singh --- .../unittest/testing_rotary_emb.cpp | 0 test/unittest/unittest_attention_kernels_cl.cpp | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp => test/unittest/testing_rotary_emb.cpp (100%) diff --git a/nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp b/test/unittest/testing_rotary_emb.cpp similarity index 100% rename from nntrainer/tensor/cl_operations/testing_rotarty_emb.cpp rename to test/unittest/testing_rotary_emb.cpp diff --git a/test/unittest/unittest_attention_kernels_cl.cpp b/test/unittest/unittest_attention_kernels_cl.cpp index a95937446d..dc3c337721 100644 --- a/test/unittest/unittest_attention_kernels_cl.cpp +++ b/test/unittest/unittest_attention_kernels_cl.cpp @@ -21,7 +21,7 @@ #include #include -#include "testing_rotarty_emb.cpp" +#include "testing_rotary_emb.cpp" #define EXPECT_IN_RANGE(VAL, MIN, MAX) \ EXPECT_GE((VAL), (MIN)); \