[SWISH] support Swish F32/F16 kernel #85
LGTM ~
swish/README.md
Outdated
```bash
# Test only the Ada architecture. If unset, all architectures are compiled by default, which takes longer: Volta, Ampere, Ada, Hopper, ...
export TORCH_CUDA_ARCH_LIST=Ada
python3 relu.py
```
typo: relu -> swish ?
swish/swish.cu
Outdated
__global__ void swish_f16x8_kernel(half* x, half* y, int N) {
int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
half2 reg_x_0 = HALF2(x[idx + 0]);
half2 reg_x_1 = HALF2(x[idx + 2]);
Code style: this repository uses 2-space indentation.
swish/swish.cu
Outdated
half pack_x[8], pack_y[8];
LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]);

#pragma unroll
Align `#pragma unroll` with the `for` loop.
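For illustration, here is a minimal sketch of how a packed F16x8 kernel might look once both style comments are applied: 2-space indentation, and `#pragma unroll` at the same indentation level as the `for` it annotates. This is not the PR's actual code. The `LDST128BITS` definition, the `swish_half` helper, the kernel name `swish_f16x8_pack_sketch`, the bounds handling, and the choice to compute in FP32 are all assumptions made for this example.

```cuda
#include <cuda_fp16.h>

// Assumed definition: view an address as a single 128-bit (float4) value.
#define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])

// Hypothetical helper: swish(x) = x * sigmoid(x) = x / (1 + exp(-x)),
// evaluated in FP32 and converted back to half.
__device__ __forceinline__ half swish_half(half x) {
  float v = __half2float(x);
  return __float2half(v / (1.0f + expf(-v)));
}

// Each thread processes 8 halves with one 128-bit load and one 128-bit store.
// Assumes x and y are 16-byte aligned (true for cudaMalloc allocations).
__global__ void swish_f16x8_pack_sketch(half *x, half *y, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  __align__(16) half pack_x[8], pack_y[8];  // 8 x 16 bits = 128 bits each
  if (idx + 7 < N) {
    LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]);

    #pragma unroll
    for (int i = 0; i < 8; ++i) {
      pack_y[i] = swish_half(pack_x[i]);
    }

    LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]);
  } else {
    // Scalar fallback for a tail shorter than 8 elements.
    for (int i = idx; i < N; ++i) {
      y[i] = swish_half(x[i]);
    }
  }
}
```

A launch such as `swish_f16x8_pack_sketch<<<(N + 8 * 256 - 1) / (8 * 256), 256>>>(x, y, N)` would cover all `N` elements under these assumptions; the block size of 256 is likewise only an example.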
No description provided.