包含以下内容：
- sigmoid_f32_kernel
- sigmoid_f32x4_kernel（float4向量化版本）
- sigmoid_f16_kernel
- sigmoid_f16x2_kernel（half2向量化版本）
- sigmoid_f16x8_kernel（unpack版本）
- sigmoid_f16x8_pack_kernel（pack版本）
- PyTorch bindings
# Build/test only for the Ada architecture. Leaving TORCH_CUDA_ARCH_LIST unset
# compiles for every architecture by default (Volta, Ampere, Ada, Hopper, ...),
# which takes considerably longer.
export TORCH_CUDA_ARCH_LIST="Ada"
python3 sigmoid.py
输出:
-------------------------------------------------------------------------------------
S=1024, K=1024
out_f32: [0.68919814, 0.37103432], time:0.00563478ms
out_f32x4: [0.68919814, 0.37103432], time:0.00370646ms
out_f32_th: [0.68919814, 0.37103432], time:0.00576425ms
-------------------------------------------------------------------------------------
out_f16: [0.68896484, 0.37109375], time:0.00638676ms
out_f16x2: [0.68896484, 0.37109375], time:0.00354910ms
out_f16x8: [0.68896484, 0.37109375], time:0.00367451ms
out_f16x8pack: [0.68896484, 0.37109375], time:0.00370884ms
out_f16_th: [0.68896484, 0.37109375], time:0.00567985ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=1024, K=2048
out_f32: [0.38874274, 0.44405776], time:0.00722098ms
out_f32x4: [0.38874274, 0.44405776], time:0.00494838ms
out_f32_th: [0.38874274, 0.44405776], time:0.00566697ms
-------------------------------------------------------------------------------------
out_f16: [0.38867188, 0.4440918], time:0.00784016ms
out_f16x2: [0.38867188, 0.4440918], time:0.00748515ms
out_f16x8: [0.38867188, 0.4440918], time:0.00568652ms
out_f16x8pack: [0.38867188, 0.4440918], time:0.00422883ms
out_f16_th: [0.38867188, 0.4440918], time:0.00566697ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=1024, K=4096
out_f32: [0.59951311, 0.40737447], time:0.01265526ms
out_f32x4: [0.59951311, 0.40737447], time:0.00922751ms
out_f32_th: [0.59951305, 0.40737447], time:0.00842571ms
-------------------------------------------------------------------------------------
out_f16: [0.59960938, 0.40722656], time:0.01393223ms
out_f16x2: [0.59960938, 0.40722656], time:0.01035023ms
out_f16x8: [0.59960938, 0.40722656], time:0.00980020ms
out_f16x8pack: [0.59960938, 0.40722656], time:0.00676632ms
out_f16_th: [0.59960938, 0.4074707], time:0.00692296ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=1024
out_f32: [0.75156671, 0.7535817], time:0.00957298ms
out_f32x4: [0.75156671, 0.7535817], time:0.00495720ms
out_f32_th: [0.75156665, 0.7535817], time:0.00566721ms
-------------------------------------------------------------------------------------
out_f16: [0.75195312, 0.75341797], time:0.01125717ms
out_f16x2: [0.75195312, 0.75341797], time:0.00592303ms
out_f16x8: [0.75195312, 0.75341797], time:0.00506973ms
out_f16x8pack: [0.75195312, 0.75341797], time:0.00419903ms
out_f16_th: [0.75146484, 0.75341797], time:0.00567865ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=2048
out_f32: [0.63868785, 0.87350833], time:0.01265430ms
out_f32x4: [0.63868785, 0.87350833], time:0.00815344ms
out_f32_th: [0.63868785, 0.87350833], time:0.00842357ms
-------------------------------------------------------------------------------------
out_f16: [0.63867188, 0.87353516], time:0.01393294ms
out_f16x2: [0.63867188, 0.87353516], time:0.01325536ms
out_f16x8: [0.63867188, 0.87353516], time:0.00920200ms
out_f16x8pack: [0.63867188, 0.87353516], time:0.00658989ms
out_f16_th: [0.63867188, 0.87353516], time:0.00692558ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=4096
out_f32: [0.68829387, 0.86458832], time:0.02354550ms
out_f32x4: [0.68829387, 0.86458832], time:0.01642895ms
out_f32_th: [0.68829387, 0.86458838], time:0.01490569ms
-------------------------------------------------------------------------------------
out_f16: [0.68798828, 0.86474609], time:0.02601480ms
out_f16x2: [0.68798828, 0.86474609], time:0.01906967ms
out_f16x8: [0.68798828, 0.86474609], time:0.01742220ms
out_f16x8pack: [0.68798828, 0.86474609], time:0.01142383ms
out_f16_th: [0.68847656, 0.86474609], time:0.01189470ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=1024
out_f32: [0.73401731, 0.48190504], time:0.01741576ms
out_f32x4: [0.73401731, 0.48190504], time:0.00823951ms
out_f32_th: [0.73401731, 0.48190504], time:0.00842834ms
-------------------------------------------------------------------------------------
out_f16: [0.73388672, 0.48217773], time:0.02083182ms
out_f16x2: [0.73388672, 0.48217773], time:0.00901842ms
out_f16x8: [0.73388672, 0.48217773], time:0.00793672ms
out_f16x8pack: [0.73388672, 0.48217773], time:0.00658131ms
out_f16_th: [0.73388672, 0.48193359], time:0.00692177ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=2048
out_f32: [0.62901723, 0.40741974], time:0.02355123ms
out_f32x4: [0.62901723, 0.40741974], time:0.01470804ms
out_f32_th: [0.62901723, 0.40741974], time:0.01491427ms
-------------------------------------------------------------------------------------
out_f16: [0.62890625, 0.40722656], time:0.02601695ms
out_f16x2: [0.62890625, 0.40722656], time:0.02475381ms
out_f16x8: [0.62890625, 0.40722656], time:0.01649356ms
out_f16x8pack: [0.62890625, 0.40722656], time:0.01125836ms
out_f16_th: [0.62890625, 0.4074707], time:0.01189661ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=4096
out_f32: [0.61078298, 0.24497166], time:0.18634176ms
out_f32x4: [0.61078298, 0.24497166], time:0.18633699ms
out_f32_th: [0.61078292, 0.24497166], time:0.18860745ms
-------------------------------------------------------------------------------------
out_f16: [0.61083984, 0.24475098], time:0.05021238ms
out_f16x2: [0.61083984, 0.24475098], time:0.03640413ms
out_f16x8: [0.61083984, 0.24475098], time:0.03263068ms
out_f16x8pack: [0.61083984, 0.24475098], time:0.02065420ms
out_f16_th: [0.61083984, 0.24487305], time:0.02181959ms
-------------------------------------------------------------------------------------