add custom init grad for backward function #31540
@@ -36,7 +36,7 @@ DECLARE_bool(sort_sum_gradient);
namespace paddle {
namespace imperative {

-void BasicEngine::Init(VarBase* var, bool retain_graph) {
+void BasicEngine::Init(VarBase* var, bool retain_graph, VarBase* grad_tensor) {
  retain_graph_ = retain_graph;
  init_node_ = var->GradVarBase()->GradNode();
  PADDLE_ENFORCE_EQ(var->GradVarBase()->GraphIsFreed(), false,
@@ -75,9 +75,15 @@ void BasicEngine::Init(VarBase* var, bool retain_graph) {
          << " as stop_gradient false";
  var->GradVarBase()->InnerSetOverridedStopGradient(false);
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(fwd_var.place());
-  grad_var->Resize(fwd_var.dims());
-  grad_var->mutable_data(fwd_var.place(), fwd_var.type());
-  operators::math::set_constant(*dev_ctx, grad_var, 1.0);
+  if (grad_tensor == nullptr) {
+    grad_var->Resize(fwd_var.dims());
+    grad_var->mutable_data(fwd_var.place(), fwd_var.type());
+    operators::math::set_constant(*dev_ctx, grad_var, 1.0);
+  } else {
+    paddle::framework::TensorCopy(
+        grad_tensor->Var().Get<framework::LoDTensor>(), fwd_var.place(),
+        *dev_ctx, grad_var);
+  }
}

void BasicEngine::CheckBackwardInputs(const OpBase& op) {

Review comment (on the TensorCopy branch): Do we need a check here that the dimensions of grad_tensor are consistent with the dimensions of var?
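In user-facing terms, this hunk changes how the reverse pass is seeded: with no grad_tensor, the initial gradient of the starting Tensor is filled with 1.0 (same shape and dtype as the forward output); with a grad_tensor, that tensor is copied in as the seed, so the computed gradients become the vector-Jacobian product with that seed. A minimal dygraph sketch of the intended behavior (illustrative only, not part of the diff; printed values assume the gradient-accumulation semantics described in the docstring below):

```python
import paddle

x = paddle.to_tensor([1., 2., 3.], stop_gradient=False)
y = x * x                          # dy_i/dx_i = 2 * x_i

# Default seed: ones_like(y)  ->  x.grad == 2 * x
y.backward()
print(x.grad)                      # roughly [2., 4., 6.]

x.clear_grad()
y = x * x

# Custom seed: grad_tensor    ->  x.grad == seed * 2 * x
seed = paddle.to_tensor([10., 10., 10.])
y.backward(seed)
print(x.grad)                      # roughly [20., 40., 60.]
```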
@@ -133,7 +133,7 @@ def set_value(self, value):
                         framework._current_expected_place())

    @framework.dygraph_only
-    def backward(self, retain_graph=False):
+    def backward(self, grad_tensor=None, retain_graph=False):
        """
        Run backward of current Graph which starts from current Tensor.

Review comment: This can only handle a single tensor, right? What if multiple grad tensors need to be handled?
Reply: Is handling multiple grad tensors a hard requirement here?
Reply: Could they be processed in a loop? The standalone backward interface is used quite infrequently. (See the sketch after this file's diff.)
@@ -142,17 +142,22 @@ def backward(self, retain_graph=False):
        You can clear gradient by ``Tensor.clear_grad()`` .

        Args:
+            grad_tensor(Tensor, optional): initial gradient values of the current Tensor. If `grad_tensor` is None,
+                the initial gradient values of the current Tensor would be a Tensor filled with 1.0;
+                if `grad_tensor` is not None, it must have the same shape as the current Tensor.
+                The default value is None.
+
            retain_graph(bool, optional): If False, the graph used to compute grads will be freed. If you would
                like to add more ops to the built graph after calling this method( :code:`backward` ), set the parameter
                :code:`retain_graph` to True, then the grads will be retained. Thus, setting it to False is much more memory-efficient.
                Defaults to False.

        Returns:
            NoneType: None

        Examples:
            .. code-block:: python

                import paddle
                x = paddle.to_tensor(5., stop_gradient=False)
                for i in range(5):
                    y = paddle.pow(x, 4.0)
@@ -168,15 +173,34 @@
                print("{}".format(x.grad))
                # 0.

+                grad_tensor = paddle.to_tensor(2.)
+                for i in range(5):
+                    y = paddle.pow(x, 4.0)
+                    y.backward(grad_tensor)
+                    print("{}: {}".format(i, x.grad))
+                # 0: [1000.]
+                # 1: [2000.]
+                # 2: [3000.]
+                # 3: [4000.]
+                # 4: [5000.]
+
        """
        if framework.in_dygraph_mode():
+            if grad_tensor is not None:
+                assert isinstance(
+                    grad_tensor, core.
+                    VarBase), "The type of grad_tensor must be paddle.VarBase"
+
+                assert grad_tensor.shape == self.shape, "Variable shape does not match: grad_tensor [ {} ] with shape {} does not match Variable [ {} ] with shape {}".format(
+                    grad_tensor.name, grad_tensor.shape, self.name, self.shape)
+
            if paddle.is_compiled_with_xpu():
                # TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
                scaled_loss = scale_loss(self)
                scaled_loss._run_backward(framework._dygraph_tracer(),
-                                         retain_graph)
+                                         retain_graph, grad_tensor)
            else:
-                self._run_backward(framework._dygraph_tracer(), retain_graph)
+                self._run_backward(framework._dygraph_tracer(), retain_graph,
+                                   grad_tensor)
        else:
            raise ValueError(
                "Variable.backward() is only available in DyGraph mode")
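Regarding the review thread above about multiple grad tensors: gradients accumulate across backward calls (the docstring example relies on this), so multiple outputs can be handled by looping over (output, seed) pairs, as the reviewer suggests. A hedged sketch, not part of the PR; retain_graph=True is used defensively in case the outputs share parts of the graph:

```python
import paddle

x = paddle.to_tensor(3., stop_gradient=False)
outputs = [x * 2., x * x]                       # two outputs built from the same leaf
seeds = [paddle.to_tensor(1.), paddle.to_tensor(10.)]

# One backward call per (output, seed) pair; x.grad accumulates the contributions:
# d(2x)/dx * 1 + d(x^2)/dx * 10 = 2 + 6 * 10 = 62
for out, seed in zip(outputs, seeds):
    out.backward(seed, retain_graph=True)

print(x.grad)  # roughly 62.
```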
@@ -0,0 +1,53 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+
+import paddle
+import paddle.fluid.dygraph as dg
+from op_test import OpTest
+
+
+class TestBackward(unittest.TestCase):
+    def setUp(self):
+        self._dtypes = ["float32", "float64"]
+        self._places = [paddle.CPUPlace()]
+        if paddle.is_compiled_with_cuda():
+            self._places.append(paddle.CUDAPlace(0))
+
+    def test_all_positive(self):
+        for dtype in self._dtypes:
+            x = np.random.random([2, 100]).astype(dtype)
+            y = np.random.random([100, 2]).astype(dtype)
+            z = np.matmul(x, y)
+            grad = np.random.random(z.shape).astype(dtype)
+            for place in self._places:
+                with dg.guard(place):
+                    x_tensor = paddle.to_tensor(x, stop_gradient=False)
+                    y_tensor = paddle.to_tensor(y)
+                    z_tensor = paddle.matmul(x_tensor, y_tensor)
+
+                    grad_tensor = paddle.to_tensor(grad)
+                    z_tensor.backward(grad_tensor)
+
+                    x_grad = np.matmul(grad, y.T)
+
+                    self.assertTrue(np.allclose(x_grad, x_tensor.grad))
+
+
+if __name__ == '__main__':
+    unittest.main()
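The unit test above only checks x's gradient. For completeness, a quick manual check (not part of the PR) of both matmul gradients under the same API, using the identities dL/dx = g @ y.T and dL/dy = x.T @ g for z = x @ y with upstream gradient g:

```python
import numpy as np
import paddle

x = np.random.random([2, 100]).astype("float32")
y = np.random.random([100, 2]).astype("float32")
g = np.random.random([2, 2]).astype("float32")

x_t = paddle.to_tensor(x, stop_gradient=False)
y_t = paddle.to_tensor(y, stop_gradient=False)   # unlike the test, y also requires grad
z_t = paddle.matmul(x_t, y_t)

z_t.backward(paddle.to_tensor(g))                # seed the backward pass with g

assert np.allclose(x_t.grad, np.matmul(g, y.T), atol=1e-5)   # dL/dx = g @ y.T
assert np.allclose(y_t.grad, np.matmul(x.T, g), atol=1e-5)   # dL/dy = x.T @ g
```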
Review comment (on the C++ change): grad_tensor could be given a default argument of nullptr.
Reply: The default argument is nullptr at the declaration.