
Disable Materializing Grads #6822

Merged: 11 commits into thiagofc/ortmodule-api on Mar 8, 2021

Conversation

centwang (Contributor) commented Feb 26, 2021

Disable materializing grads in forward so that a None gradient is not converted to a tensor filled with zeros before backward is called; we materialize the output grads ourselves. If an output grad is None when backward is called and it is a direct input of the backward graph, we materialize it as an all-zero tensor with the same shape as the output. Otherwise we use a scalar-zero tensor, because in that case an Add node has been inserted, and a scalar-zero tensor is always acceptable as one of the inputs of the Add node.
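For context, here is a minimal toy sketch of the PyTorch mechanism this relies on (not ORTModule's actual implementation): with `set_materialize_grads(False)`, the grads of unused outputs arrive in `backward` as None, and the function decides how to materialize them.

```python
import torch

class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Without this call, autograd converts a None output grad into
        # an all-zero tensor before backward is invoked.
        ctx.set_materialize_grads(False)
        return x * 2, x * 3

    @staticmethod
    def backward(ctx, grad_a, grad_b):
        # Unused outputs arrive as None; a scalar zero broadcasts
        # safely as one operand of the additions below.
        grad = torch.zeros(())
        if grad_a is not None:
            grad = grad + 2 * grad_a
        if grad_b is not None:
            grad = grad + 3 * grad_b
        return grad

x = torch.ones(3, requires_grad=True)
a, b = TwoOutputs.apply(x)
a.sum().backward()   # grad_b arrives as None, not as all zeros
print(x.grad)        # tensor([2., 2., 2.])
```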

This PR also contains a fix for the module gradient builder. With the fix, we no longer ignore any output grads from PyTorch: if an output grad is also the output of a node in the graph, we add an Add node to sum that node output with the grad tensor from PyTorch.

Take BERT-large as an example: change the user outputs to return all of loss, prediction_scores and seq_relationship_score, but call backward from loss only. Without set_materialize_grads(False) the max batch size is 80 (on a 32GB V100); with the flag set to False the max batch size reaches 88. The Python debugger also confirms that with the flag set to False, the output grads for prediction_scores and seq_relationship_score are None instead of all-zero tensors.
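Roughly, the measurement scenario looks like this (the model and input names below are placeholders, not the actual test code):

```python
# Hypothetical driver; the real BERT-large inputs are elided here.
loss, prediction_scores, seq_relationship_score = ort_model(
    input_ids, attention_mask, token_type_ids)
loss.backward()  # the grads of the other two outputs stay None
```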

SherlockNoMad (Contributor):

It's great that this worked!

SherlockNoMad previously approved these changes Feb 26, 2021
mrry (Contributor) commented Feb 27, 2021

This is a really nice saving, but can you explain what happens in the following corner case, starting from your example:

Take BERT-large as an example: change the user outputs to return all of loss, prediction_scores and seq_relationship_score

Now we call backward() on a tensor that's derived from loss and prediction_scores, e.g. with this silly example:

loss, prediction_scores, _ = model.forward()  # model inputs omitted for brevity
new_loss = loss + prediction_scores.sum()
new_loss.backward()

Do we do the right thing in this case? It doesn't look like this should change the control flow through ORTModule compared to calling loss.backward(), so I'd be surprised if it worked without more changes....

centwang (Contributor, author) commented Mar 1, 2021

@mrry Thanks for your example. For this case, both loss_grad and prediction_scores_grad will be all-one tensors, and seq_relationship_score_grad will be None from PyTorch. Setting the flag to False keeps seq_relationship_score_grad as None instead of changing it to an all-zero tensor, and that part works, since the current implementation needs loss_grad only.
However, this reveals another big issue in our gradient graph builder: the current implementation is wrong for this case. Since prediction_scores_grad is an output of a node in the graph, the graph builder ignores the all-one prediction_scores_grad from PyTorch. The right behavior is to add a Sum node that sums the node output and the all-one tensor from PyTorch to produce the final prediction_scores_grad.
I modified the graph builder to add the Sum node, but there seems to be another issue during Yield op execution, so the result is still not correct. I will keep debugging this.
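A plain-PyTorch sketch (toy numbers, not the BERT graph) of why the Sum is needed: when an output both feeds the loss internally and receives its own incoming grad, the two gradient contributions must be accumulated rather than one of them dropped.

```python
import torch

x = torch.ones(2, requires_grad=True)
prediction_scores = x * 3            # stand-in for a graph-internal output
loss = prediction_scores.sum() * 5   # the loss also depends on it
new_loss = loss + prediction_scores.sum()
new_loss.backward()
# Through loss: 3 * 5 = 15; direct path: 3. Total 18 per element.
print(x.grad)  # tensor([18., 18.])
```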

mrry (Contributor) commented Mar 1, 2021

Thanks for digging into this @iK1D! There are probably some interesting complexity-efficiency tradeoffs to be made here, for example:

  1. (Easiest, least efficient) We could be pessimistic and always assume that any of the forward outputs could have a gradient, but I think that would require us to materialize the grads in every case, and thus we'd waste memory and compute.
  2. (Harder, more efficient) Since the set of forward outputs that have gradients is probably pretty stable from one run to the next, we could optimistically specialize the graph for the common case, and provide some way to fall back to (1) if we get unexpected gradients.
  3. (Hardest, most(?) efficient, and probably needs more thought!) We could disable materializing grads, as you have in this PR, and put some If nodes into the graph to deal with the None case. For example, each backward input could become a tuple of (bool, value), or some other representation that captures the fact that it is "optional". Then we could wrap the backward computation for each input in the true branch of an If block, with a passthrough in the false branch. Depending on how efficient the control-flow implementation is, we might still want to specialize as in (2). (A rough sketch of this optional encoding follows the list.)
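A rough Python sketch of the optional encoding in (3), purely illustrative (the names are made up):

```python
from typing import NamedTuple
import torch

class OptionalGrad(NamedTuple):
    """Hypothetical (bool, value) pair for an optional backward input."""
    present: bool
    value: torch.Tensor  # ignored when present is False

def consume(opt: OptionalGrad, passthrough: torch.Tensor) -> torch.Tensor:
    # Mirrors the proposed If node: the true branch consumes the grad,
    # the false branch is a passthrough.
    return passthrough + opt.value if opt.present else passthrough
```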

I'd prefer us to implement (1) at first, because we can always specialize manually by wrapping the original module in another module that drops the irrelevant outputs (sketched below).

There are probably other approaches that could work here as well, but I wanted to share a brain dump....
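The manual specialization mentioned above could look something like this hypothetical wrapper (names invented for illustration):

```python
import torch

class LossOnly(torch.nn.Module):
    """Drops outputs we never backprop through, so only one grad input
    ever reaches the backward graph."""
    def __init__(self, inner: torch.nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, *args, **kwargs):
        loss, _scores, _relationship = self.inner(*args, **kwargs)
        return loss
```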

SherlockNoMad (Contributor):

@mrry, the auxiliary loss in the MoE model is actually constructed in a similar way to the one you showed above, so this is not an uncommon case. Sample here: https://aiinfra.visualstudio.com/Lotus/_git/MOE?path=%2Fmoe_module%2Fmoe.py&version=GBmain&line=33&lineEnd=34&lineStartColumn=1&lineEndColumn=1&lineStyle=plain&_a=contents

@iK1D, indeed there is a bug in how we currently construct the gradient graph.

Since we can't tell which module outputs would end up getting the gradient, I think we have two options:

  1. Always assume all outputs have gradients, and trim the graph on the first backward() call.
  2. Defer building the training graph until the backward() call.

+@satyajandhyala for more perspectives, since he has also looked at this problem before.

centwang (Contributor, author) commented Mar 2, 2021

I am thinking of Derek's option (1), but with a trick. We still set materialize_grads to False; if a grad is None, we create a scalar-zero tensor as its grad and pass it to ORT. If we can somehow make SetOutputMLValue work (the shape is inconsistent), I think the Sum is quick, as we have a scalar optimization for the calculation.
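The trick works because a 0-dim zero tensor broadcasts against any shape, so the inserted Sum/Add stays valid even though the stand-in grad is not full-shape. In plain PyTorch terms (a sketch, not ORT code):

```python
import torch

scalar_zero = torch.zeros(())          # stand-in for a None grad
internal_grad = torch.randn(4, 128)    # grad produced inside the graph
total = internal_grad + scalar_zero    # broadcasts; equals internal_grad
assert torch.equal(total, internal_grad)
```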

centwang requested a review from a team as a code owner on March 2, 2021 09:34
centwang dismissed SherlockNoMad's stale review on March 2, 2021 09:41:

New changes need to be reviewed again.

if (!hasShape(*typeProto)) {
continue;
}
propagateShapeFromInputToOutput(ctx, j, i);
propagateShapeFromInputToOutput(ctx, i, i);
Contributor: The shape can be different, right? For the scalar 0?

Contributor (author): Scalar-0 is just a workaround during execution; logically the shape is the same. Setting it here helps infer other shape info in the backward graph.

Contributor (author): Aishwarya found an issue in T5 where mem_plan reuses this scalar tensor. I think mem_plan has a bug, but for now I will not set the shape for outputs that might be scalar. Will check the mem_plan issue later.

loss.backward()

N, D_in, H, D_out = 32, 784, 500, 10
pt_model = NeuralNetMultiplePositionalArgumentsMultipleOutputs1(D_in, H, D_out).to(device)
Contributor: Nit: a better name would be ...MultiOutputsWithDependency.

Contributor: The same applies to NeuralNetMultiplePositionalArgumentsMultipleOutputs0: ...MultiOutputsWithoutDependency.

SherlockNoMad (Contributor) left a comment: Overall, it looks good to me. Thanks a lot for the fix!

SherlockNoMad (Contributor) commented Mar 4, 2021

With this PR, we can rely on two assumptions that reduce the complexity:

  1. Yield's inputs and outputs always have a 1-to-1 mapping.
  2. We no longer need to reorder the graph outputs, so Yield's input order should exactly match the module's forward output order, and Yield's output order should match the module's backward input order.

@@ -156,7 +158,10 @@ def assert_gradients_match_and_reset_gradient(ort_model, pt_model, reset_gradien
pt_name, pt_param = pt_named_param

assert pt_name in ort_name
assert torch.allclose(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol)
if pt_name in none_pt_params:
    assert pt_param.grad is None
Contributor: Nit: shall we also assert that ort_param.grad is all zeros?

run_step0(pt_model, pt_x1, pt_x2)
run_step0(ort_model, ort_x1, ort_x2)

assert torch.allclose(ort_x1.grad, pt_x1.grad)
Contributor: Let's also assert on the forward results? I caught a mismatch in the forward results in test_input_requires_grad_backward_creates_input_grad_as_required1; see
https://msdata.visualstudio.com/Vienna/_sprints/taskboard/ONNX%20Training/Vienna/Cobalt?workitem=1064208

run_step(pt_model, pt_x1, pt_x2)
run_step(ort_model, ort_x1, ort_x2)

_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
Contributor: Also assert the forward results?

"required_grad",
"The indices of the outputs that require gradient outputs.",
AttributeProto::INTS)
.Attr("full_shape_outputs", "The indices of the outputs that must have full shape.", AttributeProto::INTS)
Contributor: Also add an assert in Yield's kernel? The output shape should match the input shape if "full_shape" is true.

SherlockNoMad (Contributor) left a comment: The remaining comments are non-blocking; they are mostly defensive-programming suggestions.

centwang merged commit 56c5620 into thiagofc/ortmodule-api on Mar 8, 2021.
centwang deleted the weicwang/disable_materialize_grads branch on March 8, 2021 08:56.