forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pytorch-vulkan][2/n] Height packing (pytorch#113883)
Summary: Enable logic for converting a channel-packed tensor into a height-packed one. It is not yet connected with the rest of the system. Test Plan: ``` (base) yipjustin@yipjustin-mac fbsource % buck2 run -c pt.has_backtraces=1 --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_quantized_api_test_binAppleMac\#macosx-arm64 -- --gtest_filter="*packing*" File changed: fbsource//xplat/caffe2/aten/src/ATen/test/vulkan_quantized_api_test.cpp Buck UI: https://www.internalfb.com/buck2/9a0d6bd6-e4a2-4d58-8f38-f806a0703122 Network: Up: 0B Down: 0B Jobs completed: 4. Time elapsed: 0.1s. BUILD SUCCEEDED Running main() from third-party/googletest/1.14.0/googletest/googletest/src/gtest_main.cc Note: Google Test filter = *packing* [==========] Running 1 test from 1 test suite. [----------] Global test environment set-up. [----------] 1 test from VulkanAPITest [ RUN ] VulkanAPITest.channel_to_height_packing_test [ OK ] VulkanAPITest.channel_to_height_packing_test (35 ms) [----------] 1 test from VulkanAPITest (35 ms total) [----------] Global test environment tear-down [==========] 1 test from 1 test suite ran. (36 ms total) [ PASSED ] 1 test. ``` Reviewed By: SS-JIA Differential Revision: D51379737 Pull Request resolved: pytorch#113883 Approved by: https://github.com/SS-JIA
- Loading branch information
1 parent
fdaddec
commit f8516ce
Showing
10 changed files
with
464 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
60 changes: 60 additions & 0 deletions
60
aten/src/ATen/native/vulkan/glsl/convert_channels_to_height_packed.glsl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#version 450 core
#define PRECISION $precision
#define FORMAT $format

#include "indexing.h"

layout(std430) buffer;

/*
 * Compute shader: gathers elements from a channel-packed 3D texture and
 * writes them out with four consecutive rows (the h index) stored in the
 * four components of each output texel.
 */

/* Qualifiers: layout - storage - precision - memory */

// Destination image; exactly one texel is stored per invocation.
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict image3D uOutput;
// Source texture holding the channel-packed tensor.
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
// Tensor sizes in nchw order (x = n, y = c, z = h, w = w).
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  ivec4 sizes;
} uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Reads one scalar of the channel-packed source tensor.
 *
 * nchw is the logical (n, c, h, w) index of the element. The helper
 * get_channel_packed_pos_from_index (not defined in this file; presumably
 * provided by indexing.h) turns it into an ivec4 whose xyz is the texel
 * coordinate to fetch and whose w selects the component inside that texel.
 */
float fetch_channel_packed(ivec4 nchw) {
  ivec4 texel_pos = get_channel_packed_pos_from_index(nchw, uBlock.sizes);
  vec4 texel = texelFetch(uInput, texel_pos.xyz, 0);
  return texel[texel_pos.w];
}

/*
 * The invocation at pos produces the output texel at pos:
 *   pos.x -> w
 *   pos.y -> group of four rows, h = 4 * pos.y + lane
 *   pos.z -> n * C + c, where C = uBlock.sizes.y
 *
 * NOTE(review): neither pos nor h is bounds-checked here; confirm how the
 * trailing lanes are expected to behave when h is not a multiple of 4.
 */
void main() {
  ivec3 pos = ivec3(gl_GlobalInvocationID);

  // uBlock.sizes.y is the c in nchw.
  int channels = uBlock.sizes.y;
  int n = pos.z / channels;
  int c = pos.z % channels;

  int first_h = pos.y * 4;
  int w = pos.x;

  // Lane i of the output texel holds the element at row first_h + i.
  vec4 out_texel;
  for (int lane = 0; lane < 4; ++lane) {
    out_texel[lane] = fetch_channel_packed(ivec4(n, c, first_h + lane, w));
  }

  imageStore(uOutput, pos, out_texel);
}
60 changes: 60 additions & 0 deletions
60
aten/src/ATen/native/vulkan/glsl/convert_channels_to_width_packed.glsl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#version 450 core
#define PRECISION $precision
#define FORMAT $format

#include "indexing.h"

layout(std430) buffer;

/*
 * Compute shader: gathers elements from a channel-packed 3D texture and
 * writes them out with four consecutive columns (the w index) stored in the
 * four components of each output texel.
 */

/* Qualifiers: layout - storage - precision - memory */

// Destination image; exactly one texel is stored per invocation.
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict image3D uOutput;
// Source texture holding the channel-packed tensor.
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
// Tensor sizes in nchw order (x = n, y = c, z = h, w = w).
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  ivec4 sizes;
} uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Reads one scalar of the channel-packed source tensor.
 *
 * nchw is the logical (n, c, h, w) index of the element. The helper
 * get_channel_packed_pos_from_index (not defined in this file; presumably
 * provided by indexing.h) turns it into an ivec4 whose xyz is the texel
 * coordinate to fetch and whose w selects the component inside that texel.
 */
float fetch_channel_packed(ivec4 nchw) {
  ivec4 texel_pos = get_channel_packed_pos_from_index(nchw, uBlock.sizes);
  vec4 texel = texelFetch(uInput, texel_pos.xyz, 0);
  return texel[texel_pos.w];
}

/*
 * The invocation at pos produces the output texel at pos:
 *   pos.x -> group of four columns, w = 4 * pos.x + lane
 *   pos.y -> h
 *   pos.z -> n * C + c, where C = uBlock.sizes.y
 *
 * NOTE(review): neither pos nor w is bounds-checked here; confirm how the
 * trailing lanes are expected to behave when w is not a multiple of 4.
 */
void main() {
  ivec3 pos = ivec3(gl_GlobalInvocationID);

  // uBlock.sizes.y is the c in nchw.
  int channels = uBlock.sizes.y;
  int n = pos.z / channels;
  int c = pos.z % channels;

  int h = pos.y;
  int first_w = pos.x * 4;

  // Lane i of the output texel holds the element at column first_w + i.
  vec4 out_texel;
  for (int lane = 0; lane < 4; ++lane) {
    out_texel[lane] = fetch_channel_packed(ivec4(n, c, h, first_w + lane));
  }

  imageStore(uOutput, pos, out_texel);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.