
Add decomposition for ONNXSoftmaxCrossEntropyLossOp #2968

Merged
merged 10 commits into onnx:main on Nov 12, 2024

Conversation

srcarroll
Contributor

No description provided.

@jenkins-droid
Collaborator

Can one of the admins verify this patch?

@srcarroll
Contributor Author

srcarroll commented Oct 7, 2024

i'm not really familiar with them so will have to look into it, but should i add to test/backend/inference_backend.py, or do there already exist relevant tests for this op (i only see shape inference tests)? would such a test belong somewhere else?

@srcarroll
Contributor Author

srcarroll commented Oct 7, 2024

also i'm not considering the ignore_index attribute here and will leave it as a todo. but i don't know of a good way to match fail this case. any suggestions?

Edit: actually i just realized it is an optional, not defaulted, attribute, so i know how to check. but if anyone has ideas on how to extend this to support it, i'm all ears. the simplest thing i can think of is to just subtract the slice at ignore_index from the result of ReduceSum.
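For reference, the intended ignore_index semantics (mask out ignored positions before reducing) can be written down in a few lines of NumPy. This is only an illustrative sketch of the math with made-up names, not the decomposition itself:

```python
import numpy as np

def softmax_xent(scores, labels, ignore_index=None, reduction="sum"):
    """Reference softmax cross-entropy over (N, C) scores and (N,) labels."""
    # numerically stable log-softmax
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    mask = np.ones_like(labels, dtype=bool)
    if ignore_index is not None:
        mask = labels != ignore_index
    # clamp ignored labels to a valid class before gathering; they are zeroed by the mask anyway
    safe = np.where(mask, labels, 0)
    per_elem = -log_softmax[np.arange(len(labels)), safe] * mask
    if reduction == "sum":
        return per_elem.sum()
    if reduction == "mean":
        return per_elem.sum() / mask.sum()
    return per_elem

scores = np.array([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
labels = np.array([0, 1])
loss_all = softmax_xent(scores, labels, reduction="sum")
# ignoring the second element leaves only the first element's loss
loss_ignored = softmax_xent(scores, np.array([0, -100]), ignore_index=-100, reduction="sum")
```

Masking before the reduction and subtracting the ignored slice after ReduceSum should give the same total for the sum reduction; they differ for mean, where the denominator must also exclude the ignored positions.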


@AlexandreEichenberger
Collaborator

@jenkins-droid test this please


@srcarroll
Contributor Author

srcarroll commented Oct 8, 2024

i think i misunderstood the mean reduction with weights case. looking into it now (my recent "fix" commit is wrong)

@srcarroll
Contributor Author

i think i misunderstood the mean reduction with weights case. looking into it now (my recent "fix" commit is wrong)

so i was just summing over the original weights tensor, but i should have been summing over W[n][d1][d2]...[dk] = weights[labels[n][d1][d2]...[dk]]. The simplest way i can think of to produce W is onnx.Einsum(one_hot_labels, weights) {equation = "ij...,j->i..."} (or just a matmul if the input is 2D). Any objection to this, or any other suggestions? It probably wouldn't be very performant since one_hot_labels is sparse, but the other solutions i can think of involve nasty indexing that i would like to avoid unless insisted upon.
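A quick NumPy sketch (illustrative only, toy shapes; not part of the PR) confirming that the proposed einsum over one-hot labels is equivalent to gathering the weights directly:

```python
import numpy as np

# toy sizes: batch N, classes C, one spatial dim D
N, C, D = 2, 3, 4
rng = np.random.default_rng(0)
labels = rng.integers(0, C, size=(N, D))
weights = rng.random(C)

# one-hot along the class axis: shape (N, C, D)
one_hot = (labels[:, None, :] == np.arange(C)[None, :, None]).astype(weights.dtype)

# W[n, d] = sum_c one_hot[n, c, d] * weights[c], i.e. equation "ij...,j->i..."
W = np.einsum("ij...,j->i...", one_hot, weights)

# same result as a direct gather
assert np.allclose(W, weights[labels])
```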

@AlexandreEichenberger
Collaborator

@jenkins-droid test this please


@srcarroll
Contributor Author

@jenkins-droid test this please

@AlexandreEichenberger you might want to cancel and rerun since i just pushed. also is there a rule of thumb for running tests? since this is still in draft i might push commits that shouldn't be tested yet. maybe i should just not have this drafted yet? sorry if i'm not following some etiquette for this


@AlexandreEichenberger
Collaborator

@jenkins-droid test this please


@srcarroll
Contributor Author

@AlexandreEichenberger should i undraft to get some feedback on the testing i asked about? I figured i should get feedback first but i don't know of a good person to ping for this.


srcarroll marked this pull request as ready for review October 16, 2024 23:44
@AlexandreEichenberger
Collaborator

@jenkins-droid test this please

@AlexandreEichenberger
Collaborator

since this is still in draft i might push commits that shouldn't be tested yet. maybe i should just not have this drafted yet? sorry if i'm not following some etiquette for this

We just run the CIs each time. I believe once a PR has been created, the CIs run on every push; if you just push to GitHub without creating a PR, then I think no CIs are run.

I would not worry about it, just do your work, and the CIs will help you along the way to see the progress of your code.

I also asked to add you to the list of approved folks who can run our CIs, as you seem quite active and are doing solid work. Thanks for your contributions, much appreciated.

@srcarroll
Contributor Author

I would not worry about it, just do your work, and the CIs will help you along the way to see the progress of your code.

I also asked to add you to the list of approved folks who can run our CIs, as you seem quite active and are doing solid work. Thanks for your contributions, much appreciated.

thanks for the response and kind words. that would be great

@AlexandreEichenberger
Collaborator

@srcarroll do you have someone in your org that could review the HLO specifics of your PRs?
I think that would help to have a second set of eyes on the HLO part. We can look at the more mundane mlir/onnx-mlir stuff for sure.

@srcarroll
Contributor Author

srcarroll commented Oct 18, 2024

@srcarroll do you have someone in your org that could review the HLO specifics of your PRs? I think that would help to have a second set of eyes on the HLO part. We can look at the more mundane mlir/onnx-mlir stuff for sure.

Unfortunately not. I work at a very small startup and I'm pretty much the only contributor to mlir projects at the moment. @brnorris03 does contribute updates regularly and has worked on some onnx conversion, so she knows enough MLIR to review that aspect. When you say the HLO specifics, do you mean how it affects conversion to stablehlo? I have some knowledge of the stablehlo dialect and have contributed some changes to that project, but I wouldn't call myself an expert in that area. Otherwise I'm honestly unsure how my PRs are related to HLO.

@AlexandreEichenberger
Collaborator

@srcarroll will merge once we get a free run; please let me know if you want to wait before merging.

AlexandreEichenberger left a comment

LGTM

@srcarroll
Contributor Author

@AlexandreEichenberger thanks! i'm fine with merging when you are. i just wasn't sure if i should add some semantics tests somewhere. for what it's worth, i have my own test suite (not upstreamable) and have confirmed i get correct results from this decomposition

@srcarroll
Contributor Author

Actually, looks like I forgot to add a check for the ignore_index attribute. I don't handle that case, so it should match fail. Let me update that real quick

@srcarroll
Contributor Author

@AlexandreEichenberger maybe you can help me understand something. I just noticed that i get the ignore_index attribute, from the onnx->onnx-mlir conversion, even when not wanted. It seems like it is always set to ignore_index=-100. Since onnx allows negative indices, there's no way to tell if ignore_index is actually desired or if it was just set to some default to ignore. Any thoughts?

@AlexandreEichenberger
Collaborator

AlexandreEichenberger commented Nov 12, 2024

@srcarroll that is something that I also do often: offline verification, and then I just add a lit test that makes sure the overall step remains the same as when I did the offline verification.

If you need help creating a lit test, here is what I do. I create a small mlir file with my test, put in the proper RUN command, then copy dummy test code, making sure that the function name matches, e.g.

// CHECK-DAG: #map = affine_map<(d0, d1, d2, d3) -> (d0, d3 floordiv 64, d1, d2 floordiv 32, d2 mod 32, d3 mod 64)>
// CHECK-DAG: #map1 = affine_map<(d0) -> (d0)>
// CHECK-LABEL:  func.func @test_onnx_sqrt_ztensor
// CHECK-SAME:   ([[PARAM_0_:%.+]]: memref<?x3x5x7xf16, #map>) -> memref<?x3x5x7xf16, #map> {
// CHECK:           [[VAR_c0_:%.+]] = arith.constant 0 : index
// CHECK:           [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_]] : memref<?x3x5x7xf16, #map>

Or you can add the new test case to an existing test file, even easier, and just make sure you have a CHECK that will fail. (Note also that multiple consecutive tests must be separated by a `// -----` line.)

then I run fixLitTest.py <file> -r > bibi.mlir

It will test the file, fix all of the individual tests that fail, and pipe them to stdout. Then I diff the original file against the piped output to make sure that the new stuff looks good. When I am satisfied, I just replace the old file with the new one.

Hope this helps.

@srcarroll
Contributor Author

srcarroll commented Nov 12, 2024

thanks. i actually have already added tests to onnx_decompose.mlir

@srcarroll
Contributor Author

srcarroll commented Nov 12, 2024

@AlexandreEichenberger maybe you can help me understand something. I just noticed that i get the ignore_index attribute, from the onnx->onnx-mlir conversion, even when not wanted. It seems like it is always set to ignore_index=-100. Since onnx allows negative indices, there's no way to tell if ignore_index is actually desired or if it was just set to some default to ignore. Any thoughts?

actually, looks like it doesn't come from the conversion; it shows up in the onnx graph exported from torch. for example

graph main_graph (
  %input_0[FLOAT, 64x200]
  %input_1[INT64, 64]
) {
  %2 = SoftmaxCrossEntropyLoss[ignore_index = -100, reduction = 'sum'](%input_0, %input_1)
  return %2
}

was generated from

torch.nn.functional.cross_entropy(data, target, reduction="sum")

This seems wrong to me. Am i missing something?

Update: Ah I just realized this isn't an onnx thing, it's a torch thing. Doesn't make sense that they would default to -100 though. unless valid values are supposed to be positive, this option is too vague as far as i can tell.

AlexandreEichenberger merged commit cb9a949 into onnx:main Nov 12, 2024
7 checks passed
@AlexandreEichenberger
Collaborator

Thanks for your contributions.

@jenkins-droid
Collaborator

Jenkins Linux ppc64le Build #15002 [push] Add decomposition for `O... started at 13:54

@jenkins-droid
Collaborator

Jenkins Linux amd64 Build #15972 [push] Add decomposition for `O... started at 12:39

@jenkins-droid
Collaborator

Jenkins Linux s390x Build #15975 [push] Add decomposition for `O... started at 13:39

@jenkins-droid
Collaborator

Jenkins Linux amd64 Build #15972 [push] Add decomposition for `O... passed after 1 hr 20 min

@jenkins-droid
Collaborator

Jenkins Linux s390x Build #15975 [push] Add decomposition for `O... passed after 1 hr 25 min

@jenkins-droid
Collaborator

Jenkins Linux ppc64le Build #15002 [push] Add decomposition for `O... passed after 2 hr 21 min
