Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SWISH] support Swish F32/F16 kernel #85

Merged
merged 6 commits into from
Oct 17, 2024
Merged

Conversation

wangzijian1010
Copy link
Contributor

No description provided.

@DefTruth
Copy link
Owner

LGTM ~

swish/README.md Outdated
```bash
# 只测试Ada架构 不指定默认编译所有架构 耗时较长: Volta, Ampere, Ada, Hopper, ...
export TORCH_CUDA_ARCH_LIST=Ada
python3 relu.py
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: relu -> swish ?

swish/swish.cu Outdated
__global__ void swish_f16x8_kernel(half* x, half* y, int N) {
int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
half2 reg_x_0 = HALF2(x[idx + 0]);
half2 reg_x_1 = HALF2(x[idx + 2]);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

代码风格,本仓库使用2空格作为缩进

swish/swish.cu Outdated
half pack_x[8], pack_y[8];
LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]);

#pragma unroll
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pragma unroll和for对齐

@DefTruth DefTruth changed the title [SWISH][Half] support Swish kernel [SWISH] support Swish F32/F16 kernel Oct 17, 2024
@DefTruth DefTruth merged commit c4db4f8 into DefTruth:main Oct 17, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants