considerably speed up CPU matmul #411
Conversation
Would be nice to set up a dev/cpu area for CPU kernels, just like we have for CUDA, specifically for cases like this.
@ngc92 I don't see the same performance improvements as you do.
The MSVC compiler does a pretty good job optimizing matmul_forward_slow. The analyzer output reports "--- Analyzing function: matmul_forward_slow ---" and "--- Analyzing function: matmul_forward ---". Microsoft's compiler messages for loop vectorization are documented here: https://learn.microsoft.com/en-us/cpp/error-messages/tool-errors/vectorizer-and-parallelizer-messages?view=msvc-170
This seems to indicate more that MSVC is doing a terrible job with the second loop, not a particularly good job at the first loop. GCC produces the following beautiful assembly:

```asm
.L11:
        vmovups ymm0, YMMWORD PTR [rcx+rax]
        vfmadd231ps ymm1, ymm0, YMMWORD PTR [rsi+rax]
        vfmadd231ps ymm2, ymm0, YMMWORD PTR [r8+rax]
        vfmadd231ps ymm3, ymm0, YMMWORD PTR [r10+rax]
        vfmadd231ps ymm4, ymm0, YMMWORD PTR [r14+rax]
        vfmadd231ps ymm5, ymm0, YMMWORD PTR [r13+0+rax]
        vfmadd231ps ymm6, ymm0, YMMWORD PTR [r12+rax]
        vfmadd231ps ymm7, ymm0, YMMWORD PTR [rbx+rax]
        vfmadd231ps ymm8, ymm0, YMMWORD PTR [r11+rax]
        add rax, 32
        cmp r9, rax
        jne .L11
```

This relies critically on `-Ofast`, though, as it requires reordering the floating-point summation. Could this be just because

```asm
$LL13@matmul_for:
        lea eax, DWORD PTR [r14+r10]
        movsxd rcx, eax
        movsxd rax, r10d
        vmovss xmm0, DWORD PTR [r9+rcx*4]
        vfmadd231ss xmm1, xmm0, DWORD PTR [rdi+rax*4]
        lea eax, DWORD PTR [r10+rdx]
        movsxd rcx, eax
        lea eax, DWORD PTR [r15+r10]
        vfmadd231ss xmm2, xmm0, DWORD PTR [rdi+rcx*4]
        movsxd rcx, eax
        lea eax, DWORD PTR [r12+r10]
        vfmadd231ss xmm3, xmm0, DWORD PTR [rdi+rcx*4]
        movsxd rcx, eax
        lea eax, DWORD PTR [r10+rbp]
        vfmadd231ss xmm4, xmm0, DWORD PTR [rdi+rcx*4]
        movsxd rcx, eax
        lea eax, DWORD PTR [r10+r13]
        vfmadd231ss xmm5, xmm0, DWORD PTR [rdi+rcx*4]
        movsxd rcx, eax
        lea eax, DWORD PTR [r11+r10]
        vfmadd231ss xmm6, xmm0, DWORD PTR [rdi+rcx*4]
        movsxd rcx, eax
        lea eax, DWORD PTR [rbx+r10]
        inc r10d
        vfmadd231ss xmm7, xmm0, DWORD PTR [rdi+rcx*4]
        movsxd rcx, eax
        vfmadd231ss xmm8, xmm0, DWORD PTR [rdi+rcx*4]
        sub rsi, 1
        jne SHORT $LL13@matmul_for
```

Not as nice as GCC's code, but still decent enough. With 4 issue slots, the `lea`s and `mov`s shouldn't actually slow down the FMAs.
👍
@ngc92 Thanks for the link! Yeah, MSVC does a poor job with your version :(
It's still 3x slower than PyTorch... but it's a huge jump 😄

Better way to run benchmarks:

Benchmark:

PyTorch: #253
@azret
Not sure how much we care about perf of the CPU version. This is trying to give a substantial boost, while still keeping complexity in check: without the comments, the matmul still fits on a single screen.
While it would have been possible either to just print an error message for "bad" shapes, or to write a more generic function that handles the unrollable part followed by an epilogue loop, I think the approach of having a slow but very simple version and a fast but shape-restricted version is better. It keeps the fast algorithm more readable, and having the slow version there as a reference has pedagogical value in itself.
I have also removed the helper row pointers (out_bt), because they make it more difficult to see the strides of the individual memory accesses.

On my system, I observed about a 20% end-to-end speedup (~5s to ~4s per step) with this change.