Disable Materializing Grads #6822
Conversation
It's great that this worked!
This is a really nice saving, but can you explain what happens in the following corner case, starting from your example? Now we call:

loss, prediction_scores, _ = model.forward()
new_loss = loss + prediction_scores.sum()
new_loss.backward()

Do we do the right thing in this case? It doesn't look like this should change the control flow through
@mrry Thanks for your example. Actually, for this case both loss_grad and prediction_scores_grad will be all-1 tensors, and seq_relationship_score_grad will be None from PT. Setting the flag to False keeps seq_relationship_score_grad as None instead of changing it to an all-0 tensor, and it will work, as the current implementation needs loss_grad only.
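For concreteness, here is a minimal, runnable sketch of that behavior with a hypothetical toy autograd Function (ToyForward and its arithmetic are illustrative stand-ins, not the actual ORTModule code):

```python
import torch

class ToyForward(torch.autograd.Function):
    """Hypothetical stand-in for a three-output forward; not the ORTModule implementation."""

    @staticmethod
    def forward(ctx, x):
        ctx.set_materialize_grads(False)  # None grads stay None in backward
        ctx.shape, ctx.numel = x.shape, x.numel()
        loss = x.mean()
        prediction_scores = 2 * x
        seq_relationship_score = 3 * x
        return loss, prediction_scores, seq_relationship_score

    @staticmethod
    def backward(ctx, loss_grad, prediction_scores_grad, seq_relationship_score_grad):
        # For new_loss = loss + prediction_scores.sum():
        #   loss_grad                    -> scalar all-1 tensor
        #   prediction_scores_grad       -> all-1 tensor
        #   seq_relationship_score_grad  -> None (kept as None, not materialized to zeros)
        grad_x = torch.zeros(ctx.shape)
        if loss_grad is not None:
            grad_x = grad_x + loss_grad / ctx.numel
        if prediction_scores_grad is not None:
            grad_x = grad_x + 2 * prediction_scores_grad
        if seq_relationship_score_grad is not None:
            grad_x = grad_x + 3 * seq_relationship_score_grad
        return grad_x

x = torch.randn(2, 4, requires_grad=True)
loss, prediction_scores, seq_relationship_score = ToyForward.apply(x)
new_loss = loss + prediction_scores.sum()
new_loss.backward()
```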
Thanks for digging into this @iK1D! There are probably some interesting complexity-efficiency tradeoffs to be made here, for example:
I'd prefer us to implement (1) in the first place, because we can always specialize the code manually by wrapping the original module in another module that drops the irrelevant outputs. There are probably other approaches that could work here as well, but I wanted to share a brain dump.
@mrry, the auxiliary loss in the MoE model is actually constructed in a similar way to the one you showed above, so this is not an uncommon case. Sample here: https://aiinfra.visualstudio.com/Lotus/_git/MOE?path=%2Fmoe_module%2Fmoe.py&version=GBmain&line=33&lineEnd=34&lineStartColumn=1&lineEndColumn=1&lineStyle=plain&_a=contents @iK1D, indeed there is a bug in how we currently construct the gradient graph. Since we can't tell which module outputs would end up getting the gradient, I think we have two options:
+@satyajandhyala for more perspectives, since he has also looked at this problem before.
I am thinking of Derek's #1, but with a trick. We still set materialize_grads to False; if a grad is None, we create a scalar-0 tensor as its grad and pass it to ORT. If we can somehow make SetOutputMLValue work (as the shape is inconsistent), I think the Sum is quick because we have scalar optimization for the calculation.
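A small torch-level illustration of why the scalar-0 grad is workable (the shapes and names here are made up; the real Add/Sum runs inside ORT's backward graph):

```python
import torch

grad_from_pt = None  # PT passes None because this output received no gradient
node_grad = torch.randn(4, 16, 32)  # hypothetical grad produced inside the backward graph

# A scalar-0 tensor is a valid Add operand for any shape thanks to broadcasting,
# and adding a scalar is cheap compared with materializing a full zero tensor.
scalar_zero = torch.zeros((), dtype=node_grad.dtype)
summed = node_grad + (grad_from_pt if grad_from_pt is not None else scalar_zero)
assert summed.shape == node_grad.shape
```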
if (!hasShape(*typeProto)) {
  continue;
}
propagateShapeFromInputToOutput(ctx, j, i);
propagateShapeFromInputToOutput(ctx, i, i);
The shape can be different, right? For scalar 0?
Scalar-0 is just a workaround during execution. Logically the shape is the same. Setting it here can help infer other shape info in the backward graph.
Aishwarya found an issue in T5 where mem_plan reuses this scalar tensor. I think mem_plan may have a bug, but here I will not set the shape for outputs that might be scalar. Will check the mem_plan issue later.
loss.backward()

N, D_in, H, D_out = 32, 784, 500, 10
pt_model = NeuralNetMultiplePositionalArgumentsMultipleOutputs1(D_in, H, D_out).to(device)
Nit: a better name would be ...MultiOutputsWithDependency.
The same applies to NeuralNetMultiplePositionalArgumentsMultipleOutputs0: ...MultiOutputsWithoutDependency.
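For reference, a minimal sketch of what the two flavors of test model look like (hypothetical simplified versions, not the actual test classes in the PR):

```python
import torch

class MultiOutputsWithoutDependency(torch.nn.Module):
    """Two outputs computed independently of each other."""
    def __init__(self, d_in, h, d_out):
        super().__init__()
        self.fc1 = torch.nn.Linear(d_in, h)
        self.fc2 = torch.nn.Linear(d_in, d_out)

    def forward(self, x1, x2):
        return self.fc1(x1), self.fc2(x2)

class MultiOutputsWithDependency(torch.nn.Module):
    """The second output is computed from the first one."""
    def __init__(self, d_in, h, d_out):
        super().__init__()
        self.fc1 = torch.nn.Linear(d_in, h)
        self.fc2 = torch.nn.Linear(h, d_out)

    def forward(self, x1, x2):
        out1 = self.fc1(x1 + x2)
        out2 = self.fc2(out1)
        return out1, out2
```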
Overall, it looks good to me. Thanks a lot for the fix!
With this PR, we can rely on two assumptions that will reduce the complexity.
@@ -156,7 +158,10 @@ def assert_gradients_match_and_reset_gradient(ort_model, pt_model, reset_gradien
    pt_name, pt_param = pt_named_param

    assert pt_name in ort_name
    assert torch.allclose(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol)
    if pt_name in none_pt_params:
        assert pt_param.grad is None
Nit: shall we also assert that ort_param.grad is all zeros?
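A sketch of what that could look like, extending the branch above (whether the ORT-side grads for these parameters come out as all zeros is an assumption here, not taken from the PR):

```python
if pt_name in none_pt_params:
    assert pt_param.grad is None
    # Suggested additional check: the ORT-side grad is expected to be all zeros here.
    assert torch.allclose(ort_param.grad, torch.zeros_like(ort_param.grad))
```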
run_step0(pt_model, pt_x1, pt_x2)
run_step0(ort_model, ort_x1, ort_x2)

assert torch.allclose(ort_x1.grad, pt_x1.grad)
Let's also assert on the forward results? I caught a mismatch on the forward result in test_input_requires_grad_backward_creates_input_grad_as_required1; see
https://msdata.visualstudio.com/Vienna/_sprints/taskboard/ONNX%20Training/Vienna/Cobalt?workitem=1064208
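For example, something along these lines (a sketch; that run_step0 returns the forward prediction is an assumption, not taken from the test code):

```python
pt_prediction = run_step0(pt_model, pt_x1, pt_x2)
ort_prediction = run_step0(ort_model, ort_x1, ort_x2)

# Also compare the forward results, not only the gradients.
assert torch.allclose(ort_prediction, pt_prediction)
```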
run_step(pt_model, pt_x1, pt_x2)
run_step(ort_model, ort_x1, ort_x2)

_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
also assert forward results?
"required_grad", | ||
"The indices of the outputs that require gradient outputs.", | ||
AttributeProto::INTS) | ||
.Attr("full_shape_outputs", "The indices of the outputs that must have full shape.", AttributeProto::INTS) |
Also add an assert in Yield's kernel? The output shape should match the input shape if "full_shape" is true.
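Restating the intent of that check as a hedged Python-level sketch (the helper and its arguments are hypothetical; the actual suggestion is an assert inside the C++ Yield kernel):

```python
def check_full_shape_grads(materialized_grads, module_outputs, full_shape_outputs):
    """Hypothetical sanity check mirroring the suggested kernel assert at the Python level."""
    for idx, (grad, out) in enumerate(zip(materialized_grads, module_outputs)):
        if idx in full_shape_outputs:
            assert grad.shape == out.shape, (
                f"full-shape output {idx} has grad shape {tuple(grad.shape)}, "
                f"expected {tuple(out.shape)}")
```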
The remaining comments are non-blocking; they are mostly about defensive programming practice.
Disable materializing grads in forward so that None objects are not converted to tensors filled with zeros prior to calling backward. We materialize the output grads ourselves: if an output grad is None when calling backward and it is a direct input of the backward graph, we materialize it as an all-zero tensor with the same shape as the output; otherwise we use a scalar-0 tensor, because in that case an Add node is added and the scalar-0 tensor is always acceptable as one of the inputs of that Add node.
This PR also contains a fix for the module gradient graph builder. With the fix, we no longer ignore any output grads from PT; if an output is also the output of a node, we add an Add node to combine the node output with the grad tensor from PT.
Take BERT-large as an example: change the user outputs to return all of loss, prediction_scores, and seq_relationship_score, but call backward from loss only. Without set_materialize_grads(False), the max batch size is 80 (on a 32 GB V100); when it is set to False, the max batch size can reach 88. The Python debugger also confirms that when it is set to False, the output grads for prediction_scores and seq_relationship_score are None instead of all-zero tensors.
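A condensed sketch of that materialization policy on the Python side (the helper name, arguments, and structure are assumptions for illustration, not the PR's exact code):

```python
import torch

def materialize_output_grads(output_grads, module_outputs, direct_backward_input_indices):
    """Illustrative helper (not the PR's exact code) for the materialization policy above."""
    materialized = []
    for idx, (grad, out) in enumerate(zip(output_grads, module_outputs)):
        if grad is not None:
            materialized.append(grad)
        elif idx in direct_backward_input_indices:
            # Direct input of the backward graph: needs a full-shape all-zero tensor.
            materialized.append(torch.zeros_like(out))
        else:
            # Only feeds an added Add node: a scalar-0 tensor is enough, since
            # broadcasting makes it a valid operand regardless of the other input's shape.
            materialized.append(torch.zeros((), dtype=out.dtype, device=out.device))
    return materialized
```

The scalar-zero branch is what keeps memory flat for unused outputs: nothing of the output's full shape is ever allocated unless the grad truly feeds the backward graph directly.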