gpu: intel: Optimize reusable layer normalization using work-group based reductions #1990

Open · wants to merge 6 commits into base: main
119 changes: 67 additions & 52 deletions src/gpu/intel/ocl/reusable_vectorized_lnorm.cl
@@ -32,80 +32,95 @@ VEC_SUM_DEFINE(float2)
VEC_SUM_DEFINE(float4)
VEC_SUM_DEFINE(float8)

__attribute__((reqd_work_group_size(WG_SIZE, 1, 1)))
__attribute__((intel_reqd_sub_group_size(SG_SIZE))) __kernel void
lnorm_reusable_vectorized(__global SRC_DATA_T *src, __global float *mean,
__global float *variance, dim_t reduce_size, __global DST_DATA_T *dst,
__global WEI_DATA_T *scale, __global WEI_DATA_T *shift, float eps,
__global float *src_scale, __global float *dst_scale, int greads,
float rrs, dispatch_gws_rt_params_t gws_params) {
src = (GWS_GET_BUFFER_POS(SRC, gws_params, src)) - get_sub_group_local_id();
__global float *src_scale, __global float *dst_scale, float rrs,
dispatch_gws_rt_params_t gws_params) {
int sg_offset
= -(get_local_id(0)) + get_sub_group_id() * SG_SIZE * VECT_DT_N;
src = GWS_GET_BUFFER_POS(SRC, gws_params, src) + sg_offset;

mean = GWS_GET_BUFFER_POS(STAT, gws_params, mean);
variance = GWS_GET_BUFFER_POS(STAT, gws_params, variance);
FLT_ACC_DATA_T local_variance = CALCULATE_STATS ? 0.f : *variance;
FLT_ACC_DATA_T local_mean = CALCULATE_STATS ? 0.f : *mean;

#if PVT_MEM_SIZE > 1
VECT_FLOAT_T val[PVT_MEM_SIZE];
unroll_for_by(N_UNROLL)(int sg_idx = 0, i = 0; i < PVT_MEM_SIZE;
sg_idx += GROUP_STRIDE, i++) {
val[i] = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T(VECT_BLOCK_READ(
(const __global BLOCK_DATA_T *)(&src[sg_idx]))));
}
#else
VECT_FLOAT_T val = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T(
VECT_BLOCK_READ((const __global BLOCK_DATA_T *)(src))));
#endif
Comment on lines +51 to +61
Contributor

I think the compiler should be able to optimize this incantation. Give it a shot and let me know.

Suggested change
#if PVT_MEM_SIZE > 1
VECT_FLOAT_T val[PVT_MEM_SIZE];
unroll_for_by(N_UNROLL)(int sg_idx = 0, i = 0; i < PVT_MEM_SIZE;
sg_idx += GROUP_STRIDE, i++) {
val[i] = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T(VECT_BLOCK_READ(
(const __global BLOCK_DATA_T *)(&src[sg_idx]))));
}
#else
VECT_FLOAT_T val = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T(
VECT_BLOCK_READ((const __global BLOCK_DATA_T *)(src))));
#endif
VECT_FLOAT_T val[PVT_MEM_SIZE];
int sg_idx = 0;
for (int i = 0; i < PVT_MEM_SIZE; i++) {
val[i] = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T(VECT_BLOCK_READ(
(const __global BLOCK_DATA_T *)(&src[sg_idx]))));
sg_idx += GROUP_STRIDE;
}


FLT_ACC_DATA_T local_variance = 0.f;
FLT_ACC_DATA_T local_mean = 0.f;
if (CALCULATE_STATS) {
/// Read global memory and compute mean and variance
FLT_ACC_DATA_T sum = 0;
unroll_for_by(N_UNROLL)(int sg_idx = 0; sg_idx < reduce_size;
sg_idx += SG_STRIDE) {
VECT_FLOAT_T val
= CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T(VECT_BLOCK_READ(
(const __global BLOCK_DATA_T *)(&src[sg_idx]))));
sum += vec_sum(val);
VECT_FLOAT_T sum = 0;
#if PVT_MEM_SIZE > 1
unroll_for_by(N_UNROLL)(int i = 0; i < PVT_MEM_SIZE; i++) {
sum += val[i];
}
#else
sum = val;
#endif
local_mean = GROUP_ADD(vec_sum(sum)) * rrs;

local_mean = sub_group_reduce_add(sum) * rrs;
FLT_ACC_DATA_T sumsq = 0;
unroll_for_by(N_UNROLL)(int i = 0; i < greads; i++) {
VECT_FLOAT_T val;
val = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T(VECT_BLOCK_READ((
const __global BLOCK_DATA_T *)(&src[i * SG_STRIDE]))))
- local_mean;
val *= val;
sumsq += vec_sum(val);
#if PVT_MEM_SIZE > 1
sum = 0;
unroll_for_by(N_UNROLL)(int i = 0; i < PVT_MEM_SIZE; i++) {
VECT_FLOAT_T var_val;
var_val = val[i] - local_mean;
var_val *= var_val;
sum += var_val;
}
local_variance = sub_group_reduce_add(sumsq) * rrs;
} else {
mean = GWS_GET_BUFFER_POS(STAT, gws_params, mean);
variance = GWS_GET_BUFFER_POS(STAT, gws_params, variance);
local_mean = *mean;
local_variance = *variance;
#else
sum = val - local_mean;
sum *= sum;
#endif
local_variance = GROUP_ADD(vec_sum(sum)) * rrs;
}

if (USE_SCALE)
scale = GWS_GET_BUFFER_POS(SS, gws_params, scale)
+ ((greads - 1) * SG_STRIDE);
scale = GWS_GET_BUFFER_POS(SS, gws_params, scale) + sg_offset;
if (USE_SHIFT)
shift = GWS_GET_BUFFER_POS(SS, gws_params, shift)
+ ((greads - 1) * SG_STRIDE);
shift = GWS_GET_BUFFER_POS(SS, gws_params, shift) + sg_offset;

/// Normalize layer
FLT_ACC_DATA_T sqrt_variance = rsqrt(local_variance + eps);
__global DST_DATA_T *dst_vect = (GWS_GET_BUFFER_POS(DST, gws_params, dst))
- get_sub_group_local_id() + ((greads - 1) * SG_STRIDE);
__global DST_DATA_T *dst_vect
= (GWS_GET_BUFFER_POS(DST, gws_params, dst)) + sg_offset;

float src_scale_val = src_scale ? *src_scale : 1.f;
float dst_scale_val = dst_scale ? native_recip(*dst_scale) : 1.f;

unroll_for_by(N_UNROLL)(int i = greads - 1; i >= 0; i--) {
VECT_FLOAT_T res
= CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T(VECT_BLOCK_READ((
const __global BLOCK_DATA_T *)(&src[i * SG_STRIDE]))))
- local_mean;
res *= sqrt_variance;
if (USE_SCALE) res *= LOAD_VECT_WEI(scale);
if (USE_SHIFT) res += LOAD_VECT_WEI(shift);

res *= src_scale_val;
res *= dst_scale_val;
#if PVT_MEM_SIZE > 1
unroll_for_by(N_UNROLL)(int i = 0; i < PVT_MEM_SIZE; i++) {
#endif
VECT_FLOAT_T res;
#if PVT_MEM_SIZE > 1
res = val[i] - local_mean;
#else
res = val - local_mean;
#endif
VECT_FLOAT_T sc = (USE_SCALE) ? LOAD_VECT_WEI(scale) : 1.f;
VECT_FLOAT_T sh = (USE_SHIFT) ? LOAD_VECT_WEI(shift) : 0.f;
VECT_FLOAT_T out = (sc * res * sqrt_variance + sh) * src_scale_val
* dst_scale_val;

VECT_DST_BLOCK_WRITE(dst_vect, CONVERT_VECTOR_DST_DATA_T(res));
dst_vect -= SG_STRIDE;
if (USE_SCALE) scale -= SG_STRIDE;
if (USE_SHIFT) shift -= SG_STRIDE;
VECT_DST_BLOCK_WRITE(dst_vect, CONVERT_VECTOR_DST_DATA_T(out));
#if PVT_MEM_SIZE > 1
dst_vect += GROUP_STRIDE;
if (USE_SCALE) scale += GROUP_STRIDE;
if (USE_SHIFT) shift += GROUP_STRIDE;
}
if (SAVE_STATS && get_sub_group_local_id() == 0) {
mean = GWS_GET_BUFFER_POS(STAT, gws_params, mean);
variance = GWS_GET_BUFFER_POS(STAT, gws_params, variance);
#endif
if (SAVE_STATS && get_local_id(0) == 0) {
*mean = local_mean;
*variance = local_variance;
}
117 changes: 98 additions & 19 deletions src/gpu/intel/ocl/reusable_vectorized_lnorm.cpp
@@ -89,6 +89,9 @@ static status_t init_conf_common(const layer_normalization_pd_t *pd,
const compute::named_buffer_t &output_buf,
const compute::named_buffer_t &stat_buf,
const compute::named_buffer_t &ss_buf) {

int max_unroll = gpu_utils::dev_getenv("lnorm_max_unroll", 12);

conf->use_scale = pd->use_scale();
conf->use_shift = pd->use_shift();
conf->input_dt = input_buf.data_type;
@@ -125,23 +128,22 @@ static status_t init_conf_common(const layer_normalization_pd_t *pd,
return status::unimplemented;
}

const auto *gpu_attr = utils::downcast<gpu_primitive_attr_t *>(
pd->attr()->gpu_attr_.get());

const auto *compute_engine
= utils::downcast<const compute::compute_engine_t *>(engine);
const auto *gpu_attr = utils::downcast<gpu_primitive_attr_t *>(
pd->attr()->gpu_attr_.get());

conf->sg_size = 0;
conf->vector_size = 0;
std::unique_ptr<lws_strategy_t> lws_strategy;
bool found_compatible_sg_and_vector_size = false;

for (int sg_size : {32, 16}) {
for (int vector_size : {8, 4, 2, 1}) {
bool sg_and_vector_size_ok = is_sg_and_vector_size_compatible(
compute_engine, sg_size, vector_size);
bool sg_stride_ok = is_sg_stride_compatible(
bool group_stride_ok = is_sg_stride_compatible(
pd->norm_axis(), sg_size * vector_size);

if (sg_and_vector_size_ok && sg_stride_ok) {
if (sg_and_vector_size_ok && group_stride_ok) {
conf->sg_size = sg_size;
conf->vector_size = vector_size;
found_compatible_sg_and_vector_size = true;
@@ -159,12 +161,78 @@ static status_t init_conf_common(const layer_normalization_pd_t *pd,
return status::unimplemented;
}

conf->unroll = std::min<int>(
4, (int)pd->norm_axis() / (conf->sg_size * conf->vector_size));

// Norm dispatch: all dimensions
auto lws_strategy = single_subgroup_lws_strategy_t(
compute_engine, gpu_attr, conf->sg_size);
auto di = compute_engine->device_info();

int num_sg_per_wg
= (int)pd->norm_axis() / (conf->sg_size * conf->vector_size);

float threads_launched_for_wg_kernel = num_sg_per_wg * pd->across_axis();
float threads_launched_for_sg_kernel = pd->across_axis();

float sg_waves_small_grf
= threads_launched_for_sg_kernel / di->hw_threads(false);
float wg_waves = threads_launched_for_wg_kernel / di->hw_threads(false);

size_t lnorm_bytes
= pd->norm_axis() * types::data_type_size(input_buf.data_type);
size_t tensor_size
= src_mdw.nelems() * types::data_type_size(input_buf.data_type);
size_t cache_size = compute_engine->device_info()->l3_cache_size();

bool sg_kernel_utilization_low = sg_waves_small_grf < 0.50f;
bool wg_has_higher_eu_utilization = wg_waves > sg_waves_small_grf;
bool wg_barrier_overhead_overcomes_unroll
= num_sg_per_wg > 3 || num_sg_per_wg > max_unroll;
bool wg_kernel_launch_configuration_within_device_bounds
= static_cast<size_t>(pd->norm_axis() / conf->vector_size)
< compute_engine->device_info()->max_wg_size(false);
bool src_tensor_doesnt_fit_in_cache
= (static_cast<float>(tensor_size) / cache_size) >= 0.75;
bool lnorm_axis_is_large = lnorm_bytes > 1280;

conf->select_work_group_kernel
= wg_kernel_launch_configuration_within_device_bounds
&& ((sg_kernel_utilization_low && wg_has_higher_eu_utilization
&& wg_barrier_overhead_overcomes_unroll)
|| (src_tensor_doesnt_fit_in_cache && lnorm_axis_is_large));

conf->select_work_group_kernel = gpu_utils::dev_getenv(
"lnorm_select_wg_kernel", conf->select_work_group_kernel);
if (conf->select_work_group_kernel) {
conf->unroll = 1;
conf->private_mem_size = 0;
conf->large_grf = false;
conf->wg_size = pd->norm_axis() / conf->vector_size;
lws_strategy.reset(
new default_lws_strategy_t(compute_engine, gpu_attr));
} else {
conf->unroll = std::min<int>(max_unroll,
(int)pd->norm_axis() / (conf->sg_size * conf->vector_size));

conf->wg_size = conf->sg_size;
conf->large_grf = conf->unroll > 5;

conf->private_mem_size
= pd->norm_axis() / (conf->sg_size * conf->vector_size);
lws_strategy.reset(new single_subgroup_lws_strategy_t(
compute_engine, gpu_attr, conf->sg_size));
}
conf->unroll = gpu_utils::dev_getenv("lnorm_unroll", conf->unroll);
conf->large_grf = gpu_utils::dev_getenv("lnorm_large_grf", conf->large_grf);

VDEBUGINFO(15, primitive, lnorm,
"%s: wg_kernel_launch_configuration_within_device_bounds(%d) && "
"((sg_kernel_utilization_low(%d) && "
"wg_has_higher_eu_utilization(%d) && "
"wg_barrier_overhead_overcomes_unroll(%d)) || "
"(src_tensor_doesnt_fit_in_cache(%d) && lnorm_axis_is_large(%d)) "
"):",
conf->select_work_group_kernel ? "work_group kernel"
: "sub_group_kernel",
wg_kernel_launch_configuration_within_device_bounds,
sg_kernel_utilization_low, wg_has_higher_eu_utilization,
wg_barrier_overhead_overcomes_unroll,
src_tensor_doesnt_fit_in_cache, lnorm_axis_is_large);

compute::reusable_dispatch_config_t dispatch_config(compute_engine, dims);
CHECK(dispatch_config.register_buffer(input_buf));
@@ -173,7 +241,7 @@ static status_t init_conf_common(const layer_normalization_pd_t *pd,
CHECK(dispatch_config.register_buffer(ss_buf));

compute::reusable_dispatch_t dispatch;
CHECK(dispatch_config.generate(dispatch, lws_strategy));
CHECK(dispatch_config.generate(dispatch, *lws_strategy));
conf->gws_params = dispatch.get_compile_params();
rt_conf->gws_params = dispatch.get_runtime_params();

@@ -207,6 +275,10 @@ compute::kernel_ctx_t
reusable_vectorized_lnorm_params_t::get_kernel_ctx() const {
compute::kernel_ctx_t kernel_ctx;
kernel_ctx.set_data_type(input_dt);
kernel_ctx.add_option("-cl-std=CL3.0");

if (large_grf) { kernel_ctx.add_option("-cl-intel-256-GRF-per-thread"); }

def_data_type(kernel_ctx, input_dt, "SRC");
def_data_type(kernel_ctx, ss_dt, "WEI");
def_data_type(kernel_ctx, output_dt, "DST");
@@ -219,8 +291,17 @@ reusable_vectorized_lnorm_params_t::get_kernel_ctx() const {

kernel_ctx.define_int("SG_SIZE", sg_size);
kernel_ctx.define_int("VECT_DT_N", vector_size);
kernel_ctx.define_int("SG_STRIDE", sg_size * vector_size);
kernel_ctx.define_int("N_UNROLL", unroll);
kernel_ctx.define_int("PVT_MEM_SIZE", private_mem_size);

kernel_ctx.define_int("WG_SIZE", wg_size);
if (select_work_group_kernel) {
kernel_ctx.add_option("-DGROUP_ADD=work_group_reduce_add");
kernel_ctx.define_int("GROUP_STRIDE", INT32_MAX);
} else {
kernel_ctx.add_option("-DGROUP_ADD=sub_group_reduce_add");
kernel_ctx.define_int("GROUP_STRIDE", sg_size * vector_size);
}

gws_params.def_kernel_macros(kernel_ctx);

@@ -258,17 +339,15 @@ status_t reusable_vectorized_layer_normalization_fwd_t::execute_forward(
lnorm_arg_list.append(pd()->desc()->layer_norm_epsilon);
lnorm_arg_list.append(src_scale);
lnorm_arg_list.append(dst_scale);
lnorm_arg_list.append((int)utils::div_up(
pd()->norm_axis(), conf.sg_size * conf.vector_size));
lnorm_arg_list.append(1.f / (pd()->norm_axis()));

lnorm_arg_list.append(rt_conf.gws_params.get());

compute::nd_range_t gws_nd_range_calc(
{static_cast<size_t>(conf.sg_size),
{static_cast<size_t>(conf.wg_size),
rt_conf.gws_params.nd_range.global_range().data()[1],
rt_conf.gws_params.nd_range.global_range().data()[2]},
{static_cast<size_t>(conf.sg_size), 1, 1});
{static_cast<size_t>(conf.wg_size), 1, 1});

return parallel_for(
ctx, gws_nd_range_calc, calculate_lnorm_kernel_, lnorm_arg_list);
19 changes: 16 additions & 3 deletions src/gpu/intel/ocl/reusable_vectorized_lnorm.hpp
@@ -56,11 +56,15 @@ struct reusable_vectorized_lnorm_params_t
compute::kernel_ctx_t get_kernel_ctx() const;

compute::dispatch_compile_params_t gws_params;

/// Number of work items in a sub-group
int sg_size;
uint32_t sg_size;

/// The number of work-items in the work group
uint32_t wg_size = 0;
Contributor

I suspect you don't actually need this variable. It's used in two places:

  1. Set as a build option to get passed to reqd_work_group_size: You can probably just remove this and performance changes will be minimal.
  2. Computing the nd_range_t: Reconstruct it directly in the execute function (based on select_work_group_kernel, vector_size, and pd())

If you can remove this, the kernel will be far more reusable.
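A minimal sketch of the second point (an illustration, not code from the PR): if `wg_size` is dropped from the compile params, the leading work-group dimension can be rebuilt in `execute_forward()` from `select_work_group_kernel`, `vector_size`, and `pd()`. Variable names follow the existing execute function; the exact placement is an assumption.

```cpp
// Hypothetical reconstruction inside execute_forward(), replacing conf.wg_size.
// Work-group kernel: one work-item per vector_size elements of the norm axis.
// Sub-group kernel: a single sub-group of sg_size work-items covers the axis.
const size_t lws0 = conf.select_work_group_kernel
        ? static_cast<size_t>(pd()->norm_axis() / conf.vector_size)
        : static_cast<size_t>(conf.sg_size);

compute::nd_range_t gws_nd_range_calc(
        {lws0, rt_conf.gws_params.nd_range.global_range().data()[1],
                rt_conf.gws_params.nd_range.global_range().data()[2]},
        {lws0, 1, 1});
```

This mirrors how the patch currently derives `conf.wg_size` at configuration time (norm_axis / vector_size for the work-group kernel, sg_size otherwise), only moved to the execution path.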


/// Number of elements to process in each work-item
int vector_size;
uint32_t vector_size;

/// The number of times the loops need to unroll
int unroll;
@@ -78,7 +82,16 @@ struct reusable_vectorized_lnorm_params_t
/// Saves the mean and variance to memory
bool save_stats = false;

uint8_t padding[4] = {false};
/// Select the work_group based reduction kernel
bool select_work_group_kernel = false;

/// Use the cl-intel-256-GRF-per-thread flag
bool large_grf = false;
Comment on lines +88 to +89
Contributor

I worry about using large GRF mode as a heuristic. On most intel GPUs, switching the GRF mode requires stalling the pipeline which can lead to performance losses. You can (probably) see this by running a benchdnn batch on layers that get small/large/small GRF modes, and you should see performance much lower than when they're run separately.

Usually, the GRF mode is passed in by the user as a GPU attr, and the kernels are just tasked with sticking to it.
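A sketch of the attribute-driven alternative the reviewer describes (an assumption, not part of the PR; `threads_per_eu()` is the accessor I believe `gpu_primitive_attr_t` exposes, but treat it as unverified):

```cpp
// Hypothetical: derive the GRF mode from the user-supplied GPU attribute
// instead of from the unroll heuristic, so the kernel never forces a
// pipeline-stalling GRF-mode switch on its own.
const auto *gpu_attr = utils::downcast<gpu_primitive_attr_t *>(
        pd->attr()->gpu_attr_.get());
// On Xe-class GPUs a large-GRF request is typically expressed as fewer
// threads per EU (4 instead of the default 8).
const bool user_requested_large_grf
        = gpu_attr && gpu_attr->threads_per_eu() == 4;
conf->large_grf = user_requested_large_grf;
```

With something like this, `-cl-intel-256-GRF-per-thread` would only be emitted when the user asked for large GRF, avoiding the small/large/small interleaving the reviewer warns about.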


/// The number of elements to allocate in the val array in the kernel
uint8_t private_mem_size = 0;

uint8_t padding[5] = {false};
};

struct reusable_vectorized_lnorm_runtime_params_t {