Commit

wip
nihui committed Jun 30, 2023
1 parent a5ca028 commit 6417d6d
Showing 3 changed files with 383 additions and 6 deletions.
66 changes: 61 additions & 5 deletions src/layer/vulkan/convolution_vulkan.cpp
@@ -666,13 +666,41 @@ int Convolution_vulkan::create_pipeline(const Option& _opt)
// dst = pa-pb-kw-kh-inch/pa-outch/pb
if (opt.use_sgemm_convolution && !is_conv1x1s1d1 && num_input >= 16 && num_output >= 16)
{
bool use_cooperative_matrix_16_16 = vkdev->info.support_cooperative_matrix_16_16_16() && opt.use_cooperative_matrix && !opt.use_image_storage && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 16 == 0 && num_output % 16 == 0;
bool use_cooperative_matrix = vkdev->info.support_cooperative_matrix_16_8_8() && opt.use_cooperative_matrix && !opt.use_image_storage && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 8 == 0 && num_output % 8 == 0;
-if (use_cooperative_matrix)
+if (use_cooperative_matrix_16_16)
{
-// dst = 8b-8a-maxk-inch/8a-outch/8b
+// dst = 16b-16a-maxk-inch/16a-outch/16b
Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output);

weight_data_packed.create(maxk * num_input / 16, num_output / 16, (size_t)4 * 16 * 16, 16 * 16);

for (int q = 0; q + 15 < num_output; q += 16)
{
float* g00 = weight_data_packed.row(q / 16);

for (int p = 0; p + 15 < num_input; p += 16)
{
for (int k = 0; k < maxk; k++)
{
for (int i = 0; i < 16; i++)
{
for (int j = 0; j < 16; j++)
{
const float* k00 = weight_data_r2.channel(q + j).row(p + i);
g00[0] = k00[k];
g00++;
}
}
}
}
}
}
else if (use_cooperative_matrix)
{
// dst = 8b-8a-maxk-inch/8a-outch/8b
Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output);

weight_data_packed.create(maxk * num_input / 8, num_output / 8, (size_t)4 * 8 * 8, 8 * 8);

for (int q = 0; q + 7 < num_output; q += 8)
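As a side note, the destination layout named in the comment above (16b-16a-maxk-inch/16a-outch/16b) can be sanity-checked with a little index arithmetic: walking the five loops from outermost to innermost, the sequential g00 writes place the weight for (out_ch, in_ch, k) at the flat offset computed below. This is a standalone sketch for illustration only; packed_offset is a made-up helper, not part of the commit.

```cpp
#include <cstddef>

// Sketch: flat float offset of weight (out_ch, in_ch, k) inside
// weight_data_packed for the 16x16 layout. Loop variables map as
// q -> out_ch tile, p -> in_ch tile, i -> in_ch % 16, j -> out_ch % 16.
static size_t packed_offset(int out_ch, int in_ch, int k, int maxk, int num_input)
{
    const size_t tile_row = (size_t)(out_ch / 16) * (num_input / 16); // q / 16 rows of tiles
    const size_t tile = (tile_row + in_ch / 16) * maxk + k;           // 16x16 tile index
    return tile * 16 * 16 + (size_t)(in_ch % 16) * 16 + (out_ch % 16);
}
```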
@@ -830,8 +858,12 @@ int Convolution_vulkan::create_pipeline(const Option& _opt)
}
else if (opt.use_sgemm_convolution && !is_conv1x1s1d1 && num_input >= 16 && num_output >= 16)
{
bool use_cooperative_matrix_16_16 = vkdev->info.support_cooperative_matrix_16_16_16() && opt.use_cooperative_matrix && !opt.use_image_storage && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 16 == 0 && num_output % 16 == 0;
bool use_cooperative_matrix = vkdev->info.support_cooperative_matrix_16_8_8() && opt.use_cooperative_matrix && !opt.use_image_storage && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 8 == 0 && num_output % 8 == 0;

NCNN_LOGE("use_cooperative_matrix_16_16 = %d", use_cooperative_matrix_16_16);
NCNN_LOGE("use_cooperative_matrix = %d", use_cooperative_matrix);

// check blob shape
if (!vkdev->shape_support_image_storage(shape_bordered_packed) || !vkdev->shape_support_image_storage(out_shape_packed))
{
@@ -885,13 +917,22 @@ int Convolution_vulkan::create_pipeline(const Option& _opt)
if (elempack == 4 && out_elempack == 8) shader_type_index = LayerShaderType::convolution_pack4to8_gemm;
if (elempack == 8 && out_elempack == 4) shader_type_index = LayerShaderType::convolution_pack8to4_gemm;

-if (use_cooperative_matrix)
+if (use_cooperative_matrix_16_16)
{
shader_type_index = LayerShaderType::convolution_pack4_gemm_cm_16_16_16;
}
else if (use_cooperative_matrix)
{
shader_type_index = LayerShaderType::convolution_pack4_gemm_cm_16_8_8;
}

pipeline_convolution_gemm = new Pipeline(vkdev);
-if (use_cooperative_matrix)
+if (use_cooperative_matrix_16_16)
{
// TODO proper unroll y
pipeline_convolution_gemm->set_local_size_xyz(32, 1, 1); // 16_16_16 ly*1
}
else if (use_cooperative_matrix)
{
// TODO proper unroll y
pipeline_convolution_gemm->set_local_size_xyz(32, 4, 1); // 16_8_8 ly*4
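For orientation: assuming a subgroup size of 32 (the common case these cooperative-matrix shaders target), local size (32, 1, 1) means one subgroup per workgroup for the 16_16_16 path, while (32, 4, 1) packs four subgroups per workgroup for 16_8_8, matching the ly*1 / ly*4 comments. A minimal sketch of that arithmetic, with the subgroup size as a stated assumption:

```cpp
// Assumption: subgroupSize == 32, so local_size_x = 32 is exactly one subgroup.
struct LocalSize { int x, y, z; };

static int subgroups_per_workgroup(LocalSize ls, int subgroup_size = 32)
{
    return (ls.x / subgroup_size) * ls.y * ls.z;
}
// 16_16_16: {32, 1, 1} -> 1 subgroup per workgroup ("ly*1")
// 16_8_8:   {32, 4, 1} -> 4 subgroups per workgroup ("ly*4")
```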
@@ -1476,8 +1517,12 @@ int Convolution_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCom
}
if (opt.use_sgemm_convolution && !is_conv1x1s1d1 && channels * elempack >= 16 && num_output >= 16)
{
bool use_cooperative_matrix_16_16 = vkdev->info.support_cooperative_matrix_16_16_16() && opt.use_cooperative_matrix && !opt.use_image_storage && !opt.use_shader_pack8 && opt.use_fp16_storage && channels * elempack % 16 == 0 && num_output % 16 == 0;
bool use_cooperative_matrix = vkdev->info.support_cooperative_matrix_16_8_8() && opt.use_cooperative_matrix && !opt.use_image_storage && !opt.use_shader_pack8 && opt.use_fp16_storage && channels * elempack % 8 == 0 && num_output % 8 == 0;

NCNN_LOGE("use_cooperative_matrix_16_16 = %d", use_cooperative_matrix_16_16);
NCNN_LOGE("use_cooperative_matrix = %d", use_cooperative_matrix);

// gemm
top_blob.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
if (top_blob.empty())
@@ -1504,7 +1549,18 @@ int Convolution_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCom
dispatcher.h = top_blob.c;
dispatcher.c = 1;

-if (use_cooperative_matrix)
+if (use_cooperative_matrix_16_16)
{
// dispatcher.w = ((top_blob.w * top_blob.h + 15) / 16 + 3) / 4 * 32;
// dispatcher.h = (top_blob.c + 1) / 2;
// dispatcher.c = 1;
dispatcher.w = ((top_blob.w * top_blob.h + 15) / 16 + 1) / 2 * 32;
// dispatcher.w = (top_blob.w * top_blob.h + 15) / 16 * 32;
dispatcher.h = ((top_blob.c + 3) / 4 + 1) / 2;
// dispatcher.h = (top_blob.c + 3) / 4;
dispatcher.c = 1;
}
else if (use_cooperative_matrix)
{
dispatcher.w = ((top_blob.w * top_blob.h + 15) / 16 + 3) / 4 * 32;
dispatcher.h = (top_blob.c + 1) / 2;
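Reading the arithmetic back (an inference from the expressions, not stated in the commit): with top_blob.w * top_blob.h flattened output columns, both paths first count 16-wide tiles; the 16_8_8 path then covers four tiles per subgroup in x and two pack-4 channel rows per y step, while the new 16_16_16 path covers two tiles per subgroup in x and two groups of four pack-4 channels per y step. A hedged recomputation, with illustrative names:

```cpp
#include <cstdio>

// Sketch: recompute both dispatch sizes from an output shape.
// outc_pack4 stands for top_blob.c with out_elempack == 4 (an assumption of this sketch).
static void print_dispatch(int outw, int outh, int outc_pack4)
{
    const int tiles = (outw * outh + 15) / 16; // 16-wide cooperative-matrix tiles

    // 16_8_8 path: 4 tiles per subgroup (x), 2 channel rows per y step
    const int w8 = (tiles + 3) / 4 * 32;
    const int h8 = (outc_pack4 + 1) / 2;

    // 16_16_16 path: 2 tiles per subgroup (x), 2 of the (c+3)/4 channel groups per y step
    const int w16 = (tiles + 1) / 2 * 32;
    const int h16 = ((outc_pack4 + 3) / 4 + 1) / 2;

    printf("16_8_8: %d x %d    16_16_16: %d x %d\n", w8, h8, w16, h16);
}
```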
@@ -104,7 +104,7 @@ void main()
sum3 = fcoopmatNV<32, gl_ScopeSubgroup, 16, 16>(0.f);
}

-int N = psc(c) / 4;
+const int N = psc(c) / 4;

int z = 0;
for (; z + (UNROLL_INCH - 1) < N; z += UNROLL_INCH)
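The loop this hunk touches follows the usual unroll-with-remainder pattern over N = psc(c) / 4 packed input slices. In outline (transcribed to C++ for illustration; UNROLL_INCH, the tail loop, and the loop bodies are placeholders inferred from the pattern, not shown in the diff):

```cpp
// Generic shape of the shader's channel loop: a main body that consumes
// UNROLL_INCH slices per iteration, then a scalar tail for the remainder.
void accumulate(int N)
{
    const int UNROLL_INCH = 4; // assumed unroll factor; the real shader sets its own
    int z = 0;
    for (; z + (UNROLL_INCH - 1) < N; z += UNROLL_INCH)
    {
        // load UNROLL_INCH input slices and issue the cooperative-matrix MACs
    }
    for (; z < N; z++)
    {
        // remainder: one slice at a time
    }
}
```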