OPT-1.3b performance tracker #1589
Comments
I also noticed a significant performance delta for larger sequence lengths. Is this expected? End-to-end execution of the sequence length 128 version of OPT takes about 9 times longer than the tiny seqlen-8 version from above. Here is a tracy profile for the model executed end-to-end. To reproduce this trace:
Here is a screenshot of the dispatches from the profile statistics, ordered by total runtime:
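A rough sketch of what that reproduction workflow looks like (these are not the exact commands from the post above; flag spellings and file names are assumptions based on IREE around the time of this thread and may differ across versions):

```bash
# Assumes an IREE runtime built with Tracy instrumentation enabled.

# 1. Compile the model for the CPU backend (data-tiling/microkernel flags optional).
iree-compile opt-1_3b_causallm_128_torch.mlir \
  --iree-hal-target-backends=llvm-cpu \
  --iree-llvmcpu-target-cpu=skylake-avx512 \
  -o opt.vmfb

# 2. Run the benchmark with the Tracy-instrumented runtime, keeping the process
#    alive until the profiler has drained all events.
TRACY_NO_EXIT=1 iree-benchmark-module \
  --module=opt.vmfb \
  --device=local-task \
  --function=forward \
  --input="1x128xi64=0" \
  --input="1x128xi64=0"

# 3. In another terminal, capture the profile to a .tracy file.
iree-tracy-capture -o opt128.tracy
```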
Do you have tuned results for dispatches 17 and 18? It looks like we should boil these down first.
Tuning these dispatches for tile/workgroup sizes does not yield a significant performance delta, unfortunately.
cc @MaheshRavishankar @hanhanW Current status: the sequence length = 8 case should see significant improvements from @bjacob's work; see this discord thread. We are discussing the matmul cases where M > 16, as the associated optimization space is drastically different from the narrow matmul cases. (Ideally the existing codegen will yield acceptable performance with the right flags -- I will update this issue with confirmation on whether the 'wider' matmuls are performant with the correct flags.)
Thanks @monorimet for the summary; FTR, the above-linked discord thread also contains suggestions of flags to use for the M >= 16 case. Looking forward to hearing how that performs; we can discuss next steps from there.
Here is an updated set of artifacts and tracy profile for the case where sequence length (M dimension) = 128, at fp16 precision.

Case 1 - without microkernels:
- The .mlir can be found here: opt-1_3b_causallm_128_torch.mlir
- The .vmfb is located here: opt-1_3b_causallm_128_torch_cpu-task.vmfb
- The tracy profile is also provided here: opt128_noukernels.tracy

Case 2 - with microkernels enabled: In this case, I was surprised to see a sharp decrease in performance. I assume that we are simply not in a case where this flag helps, but in case it seems wrong I wanted to share my results:
- The .vmfb is located here: opt-1_3b_causallm_128_torch_cpu-task_ukernels.vmfb
- The tracy profile is also provided here: opt128_ukernels.tracy
The likely reason why microkernels decrease performance here is the f16 element type. Just checking - is it intentional that this model is at f16 precision rather than fp32?
I am also confused to see the f16 precision. Let me make sure this .mlir is at fp32 precision and I will update accordingly. Edit: Yes, my mistake, this .mlir is in half precision. I'll post again with the fp32 profiles in a few minutes.
Running the same flags as above with the fp32 OPT .mlir results in a segfault in iree-benchmark-module. I will remove flags one at a time to see if any specific one is the culprit. Here are the commands I'm using where the segfaults occur (these worked in fp16):
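A hedged reconstruction of what those invocations typically look like (not the exact commands used here; flag spellings follow IREE around the time of this thread and may have been renamed since, and file names are placeholders):

```bash
# Hypothetical sketch: compile the fp32 model with data tiling and microkernels,
# then benchmark on the local-task driver. Dropping flags one at a time from the
# compile step is how the culprit flag can be isolated.
iree-compile opt-1_3b_causallm_128_torch_fp32.mlir \
  --iree-hal-target-backends=llvm-cpu \
  --iree-llvmcpu-target-cpu=skylake-avx512 \
  --iree-flow-enable-data-tiling \
  --iree-llvmcpu-enable-microkernels \
  -o opt_fp32.vmfb

iree-benchmark-module \
  --module=opt_fp32.vmfb \
  --device=local-task \
  --function=forward \
  --input="1x128xi64=0" \
  --input="1x128xi64=0"
```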
The same occurs without the microkernels flag. I am also noticing stack allocations that exceed the default stack allocation limit.
It seems the data tiling flag is the one triggering the segfault.
That is the flag that triggers everything. Without that flag, the other flags don't do anything. I will take a look, but it will take me some time (a day or so) to get around to it. We need to track down the stack allocation issue as well.
OK. In the meantime I will find out if the sequence length (M dimension in matmuls) is relevant to the segfaults -- 128 is a somewhat arbitrary choice, so if we can find some other M > 16 that works with data tiling, then we can be at least temporarily unblocked.
It seems the data tiling flag now causes this segfault for all sequence lengths. I will see what I can do to bisect the problem and, if possible, isolate the problematic dispatches.
@monorimet I can take a look at this this week. I have to finish a few things before I can get to it. So I'll post here when I get to this.
I was able to compile and run dispatch 25 with the following input:
Yielding the following results:
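As a rough illustration of that workflow (not the actual reproducer; file names, entry-point name, and input values below are placeholders), a dispatch pulled out into its own module can be compiled and run standalone roughly like this:

```bash
# Hypothetical sketch: compile the extracted dispatch for the CPU backend and feed
# it a representative input via iree-run-module.
iree-compile dispatch_25_repro.mlir \
  --iree-hal-target-backends=llvm-cpu \
  --iree-llvmcpu-target-cpu=skylake-avx512 \
  -o dispatch_25_repro.vmfb

iree-run-module \
  --module=dispatch_25_repro.vmfb \
  --device=local-task \
  --function=main \
  --input="128x2048xf32=1.0"
```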
This behavior is replicated on sequence lengths 8, 16, 32, and 128 -- it isn't very surprising, as the data tiling is happening at the flow level. I will back up to the flow level and do some more poking around, and try a few other cases tomorrow in f16 / compare with pytorch in the meantime -- please let me know if I can focus my efforts anywhere to help @MaheshRavishankar
Could you just create an issue with the dispatch itself on IREE? If you look at the IR after dispatch formation, you should be able to pull the problematic dispatch out into a standalone reproducer.
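A minimal sketch of how that IR can be dumped to isolate the dispatch (generic MLIR printing options; the data-tiling flag spelling is an assumption from IREE of this era):

```bash
# Hypothetical sketch: print the IR after every pass so the problematic dispatch can
# be copied out into a standalone reproducer. The dump is large, so redirect stderr.
iree-compile opt-1_3b_causallm_128_torch.mlir \
  --iree-hal-target-backends=llvm-cpu \
  --iree-flow-enable-data-tiling \
  --mlir-print-ir-after-all \
  -o /dev/null 2> ir_after_all.txt
```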
OK, I have filed the issue in IREE. For the segfaults we can move discussion to that issue until resolved.
I have adapted a script for running a perf comparison (SHARK/IREE vs. PyTorch) for OPT-1.3b causal LM inference. Without data tiling (with SHARK's default IREE CPU flags -- I can dig these out if anyone is interested) we achieve the following:
@monorimet - With iree-org/iree#14398 now fixed, here are some benchmark results to give a flavor of the performance to expect. Note - testing on an Intel Skylake Xeon CPU with AVX-512, compiling with --iree-llvmcpu-target-cpu=skylake-avx512. Command lines as in the original PR description above.
So, data-tiling alone is a ~8x speedup. Ukernels alone are not yet good, but I'll get to that now, and it will be at least as fast as non-ukernels and in some cases faster. What's almost certainly happening here is that this particular model is f32, and f32 matmuls on ISAs like AVX-512 are what default codegen is good at. As soon as we depart from that, e.g. f16, things are more challenging for default codegen and the ukernels become more of a win.
Nice! Are these the narrow-shapes case? Mostly I've seen ukernels doing rather well :)
Oh I get it now - no, it's not narrow (it's the sequence length 128 case).
What's the target here? Without ukernels but with data tiling, this is at 72us (maybe 515us is just so bad that this is an unfair comparison).
I'll let Nod decide if there's a specific target; all I know is we still have plenty of room to run :-)
The first case we wanted to meet/beat PyTorch performance on was OPT-1.3b in fp32 precision. With data tiling, I do see a significant improvement in e2e execution time (tested on this IREE SHA). The following are e2e benchmarks on OPT-1.3b in fp32 precision, at sequence length 128, with AVX-512 instructions enabled. I will preface each result with the reproduction commands (see the sketch after the case list below). Link to opt_1-3b_causallm_128_torch.mlir

Case 1: No data tiling, no microkernels:
Case 2: Data tiling, no ukernels:
Case 3: Data tiling and ukernels:
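A rough sketch of what the three configurations look like as iree-compile invocations (flag spellings are assumptions based on IREE around this time and may have been renamed since; output file names are placeholders; benchmarking then uses iree-benchmark-module as shown earlier in the thread):

```bash
# Case 1: no data tiling, no microkernels (baseline codegen).
iree-compile opt_1-3b_causallm_128_torch.mlir \
  --iree-hal-target-backends=llvm-cpu \
  --iree-llvmcpu-target-cpu=skylake-avx512 \
  -o opt_base.vmfb

# Case 2: data tiling, no microkernels.
iree-compile opt_1-3b_causallm_128_torch.mlir \
  --iree-hal-target-backends=llvm-cpu \
  --iree-llvmcpu-target-cpu=skylake-avx512 \
  --iree-flow-enable-data-tiling \
  -o opt_dt.vmfb

# Case 3: data tiling and microkernels.
iree-compile opt_1-3b_causallm_128_torch.mlir \
  --iree-hal-target-backends=llvm-cpu \
  --iree-llvmcpu-target-cpu=skylake-avx512 \
  --iree-flow-enable-data-tiling \
  --iree-llvmcpu-enable-microkernels \
  -o opt_dt_uk.vmfb
```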
So in my case the tiled + ukernels mode seems to produce the best results. Examining the performance delta vs. PyTorch is a bit tricky -- we have to feed an input of 127 tokens for the PyTorch model to be comparable to our sequence length 128 model. Evidently padding with the tokenizer doesn't stop PyTorch from using the smallest possible model for optimal performance; this is generally equivalent in behavior to the dynamic path in the torch-mlir/IREE stack. Since we are looking at sequence length 128, I've just run the performance comparison with excerpts of the Declaration of Independence with 127 words each, for 5 iterations:
4.24 vs. 3.9 is really quite good! Is there anything I've missed in latest IREE, from these reproducers, that could push us past PyTorch performance?
iree-org/iree#13822 might help a bit
This is because this is an f32 model, which is the case default codegen handles best. When the data type is not f32, the microkernels become more of a win. When the shapes are narrow (when the sequence length is small), codegen tends to adapt gracefully, but microkernels don't currently have fast code for narrow cases, so that's the other thing I want to fix very soon in microkernels. Further performance gains beyond that point will come from:
So if I understand this log correctly, the 3.9 is the PyTorch number?
Thanks. I am building with tracy on latest IREE to get a trace of the tiled ukernels case, so I will share results with that flag included.
With latest ToT that flag is on by default.
Yes, that is the value we get from the PyTorch runtime.
Sorry to say I seem to get better performance with
With debug iree-runtime builds, the results are a bit more sporadic, so the two cases seem quite similar (I can share those numbers if desired). Tracy profile for the fastest configuration (latest IREE, seqlen 128, fp32, avx512): (link)
Do the matrix multiplications in this model involve a matrix of constant weights (as is the case in many NN inference workloads), as opposed to runtime values being multiplied by runtime values (as is the case in some recent NN architectures like transformers)? If some matmul operands are constant data, then the corresponding set_encoding dispatches are running on constant data and are prime candidates for being constant-evaluated.
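If experimenting with that suggestion, a minimal sketch is just adding constant evaluation on top of the data-tiling configuration. Assumptions: that `--iree-opt-const-eval` is the relevant flag, and that the other flag spellings below match this era of IREE; as discussed further down, a flag flip alone may not be enough here.

```bash
# Hypothetical sketch: enable IREE's constant evaluation alongside data tiling so
# that dispatches operating purely on constant weights (e.g. set_encoding on the
# constant weight matrices) could be folded at compile time.
iree-compile opt_1-3b_causallm_128_torch.mlir \
  --iree-hal-target-backends=llvm-cpu \
  --iree-llvmcpu-target-cpu=skylake-avx512 \
  --iree-flow-enable-data-tiling \
  --iree-llvmcpu-enable-microkernels \
  --iree-opt-const-eval \
  -o opt_dt_uk_consteval.vmfb
```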
Hmm yes it really is lots of big constant matrices (just looking at the IR):

```mlir
func.func @forward(%arg0: tensor<1x128xi64>, %arg1: tensor<1x128xi64>) -> tensor<1x128x50272xf16> {
  %cst = arith.constant dense_resource<__elided__> : tensor<2048xf16>
  %cst_0 = arith.constant dense_resource<__elided__> : tensor<2048xf16>
  %cst_1 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf16>
  %cst_2 = arith.constant dense_resource<__elided__> : tensor<8192xf16>
  %cst_3 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf16>
  %cst_4 = arith.constant dense_resource<__elided__> : tensor<2048xf16>
  %cst_5 = arith.constant dense_resource<__elided__> : tensor<2048xf16>
```

For example, one of these constant weight tensors feeds a matmul, and that in turn becomes an operand to a set_encoding op. So when IREE data-tiles that matmul, it creates a set_encoding dispatch whose input is entirely constant data.
Is the flag all that's necessary to constant-evaluate the set_encoding dispatches?
Hum, nothing off the top of my head. Need to look into this.
There are a couple of things we need to do to get the const eval to work here. It's not a simple flag flip. For one, we need a way for the
@MaheshRavishankar thanks for the explanation. If you file an issue with a ~4x expanded version of that to get me started, I might be able to try. I realized meanwhile that we also needed
I don't know what the full solution is, but it is definitely worth starting an issue, describing what I know, and getting Ben/Stella's help on the rest. Stay tuned.
There is already an issue for this: iree-org/iree#11360. I'll add some things there.
For OPT-1.3b (fp32) we would like to burn down performance at the dispatch level.
Here is a tracy profile for the model executed end-to-end.
To reproduce this trace:
Here is a screenshot of the dispatches from the profile statistics, ordered by total runtime: