Add ResNetUnit Python API #35426

Merged: 60 commits, Oct 15, 2021

Commits
57a6edc
first commit
ZzSean Aug 17, 2021
ae796a4
Support full compile
ZzSean Sep 3, 2021
889d430
add mean and var
ZzSean Sep 22, 2021
01841a6
modify based on 'resnet_unit_op' branch
ZzSean Sep 22, 2021
c65b5df
modify bn test
ZzSean Sep 23, 2021
af42837
sync cudnn_bn_add_relu_test
ZzSean Sep 23, 2021
a59f3f3
support bn+add+relu backward and compile success
ZzSean Sep 24, 2021
8b5ca97
add OpMaker and InferShape for backward
ZzSean Sep 24, 2021
b519f8c
add conv backward
ZzSean Sep 26, 2021
b0165ee
fix run error
ZzSean Sep 26, 2021
da04c42
add param dtype check
ZzSean Sep 26, 2021
3aaeb44
fix run error
ZzSean Sep 26, 2021
450c80c
fix
ZzSean Sep 26, 2021
236a1f6
fix bwd error
ZzSean Sep 26, 2021
adba900
decrease intermediate space and fix bad malloc
ZzSean Sep 27, 2021
ccb13a5
Enhance the cache of make fused cudnn ops.
Xreki Sep 27, 2021
ed67cf0
Merge branch 'develop' into cudnn_v8_fusion
Xreki Sep 27, 2021
c82017b
Fix a bug of std::vector.
Xreki Sep 27, 2021
5339c4e
Use independent forward and backward cache.
Xreki Sep 27, 2021
1a1c3bf
Enable inplace addto for resnet_unit_grad op.
Xreki Sep 28, 2021
688a779
python test
ZzSean Sep 28, 2021
686452f
Merge branch 'develop' into cudnn_v8_fusion
Xreki Sep 28, 2021
97e56eb
Merge branch 'develop' into cudnn_v8_fusion
Xreki Sep 29, 2021
34864b1
Merge branch 'cudnn_v8_fusion' of https://github.com/ZzSean/Paddle in…
Xreki Sep 29, 2021
097187b
change bn_add_relu backward implementation
ZzSean Oct 8, 2021
ac0da0e
Merge branch 'develop' into cudnn_v8_fusion
Xreki Oct 8, 2021
d19d0a5
Merge branch 'cudnn_v8_fusion' of https://github.com/ZzSean/Paddle in…
Xreki Oct 8, 2021
959b429
refactor the cudnn fusion op
ZzSean Oct 8, 2021
a267dbb
Polish the codes of cudnn_bn_add_relu_test.
Xreki Oct 8, 2021
dfc3c40
Merge branch 'cudnn_v8_fusion' of https://github.com/ZzSean/Paddle in…
Xreki Oct 8, 2021
2e77e80
add backward test for bn_add_relu
ZzSean Oct 8, 2021
1c06aa9
Add more test case.
Xreki Oct 8, 2021
4a52f07
Merge branch 'cudnn_v8_fusion' of https://github.com/ZzSean/Paddle in…
Xreki Oct 8, 2021
f50f291
delete debug info
ZzSean Oct 9, 2021
41a4a84
Polish the codes of resnet_unit_op and fix two bugs.
Xreki Oct 9, 2021
c7bf91b
Merge branch 'cudnn_v8_fusion' of https://github.com/ZzSean/Paddle in…
Xreki Oct 9, 2021
70b3f2d
add more test
ZzSean Oct 9, 2021
b3dd740
Polish and remove some unused codes.
Xreki Oct 9, 2021
e2c0172
Merge branch 'cudnn_v8_fusion' of https://github.com/ZzSean/Paddle in…
Xreki Oct 9, 2021
0d8eb92
Merge branch 'develop' into cudnn_v8_fusion
Xreki Oct 9, 2021
d0dee01
add attr is_test and delete unused code
ZzSean Oct 11, 2021
7bd83d4
change ptr to tensor in forward
ZzSean Oct 11, 2021
794565d
change ptr to tensor in backward
ZzSean Oct 11, 2021
f8a4847
update develop
ZzSean Oct 12, 2021
b35afee
Merge branch 'develop' into cudnn_v8_fusion
ZzSean Oct 12, 2021
13efb64
unify the format of the params of fusion inferences
ZzSean Oct 12, 2021
ab369a1
add more assert
ZzSean Oct 12, 2021
fce8a19
fix
ZzSean Oct 12, 2021
2ec78a4
delete unused code
ZzSean Oct 12, 2021
aa4cc57
add notes
ZzSean Oct 12, 2021
212ab3f
fix some potential bug
ZzSean Oct 12, 2021
0d5bdae
Merge branch 'develop' into cudnn_v8_fusion
ZzSean Oct 14, 2021
fd4e10b
modify
ZzSean Oct 14, 2021
33e261b
Revert "modify"
ZzSean Oct 14, 2021
a2b1044
Revert "Revert "modify""
ZzSean Oct 14, 2021
57289d7
fix bugs
ZzSean Oct 14, 2021
64968bf
Merge branch 'develop' into cudnn_v8_fusion
ZzSean Oct 14, 2021
c66c4b2
support inplace addto and add python API
ZzSean Oct 14, 2021
e755d16
delete assert
ZzSean Oct 14, 2021
d43a950
fix
ZzSean Oct 15, 2021

Files changed
@@ -179,7 +179,8 @@ void InplaceAddToOpPass::Run(Graph *graph) const {
out_var_ptr->GeneratedOp());

// NOTE(zhiqiu): currently, only conv2d_grad supports addto strategy
- if (right_generated_op->Name() != "conv2d_grad") {
+ if (right_generated_op->Name() != "conv2d_grad" &&
+     right_generated_op->Name() != "resnet_unit_grad") {
continue;
}

@@ -224,11 +225,13 @@ static bool IsValidConv2DGradDataGradNode(const Node &node) {
if (node.inputs.empty()) return false;
auto *generated_op = node.inputs[0];
auto *op_desc = generated_op->Op();
- if (op_desc == nullptr || op_desc->Type() != "conv2d_grad") {
+ if (op_desc == nullptr || (op_desc->Type() != "conv2d_grad" &&
+                            op_desc->Type() != "resnet_unit_grad")) {
return false;
}
const auto &outputs = op_desc->Outputs();
- auto iter = outputs.find(GradVarName("Input"));
+ std::string grad_var_name = op_desc->Type() == "conv2d_grad" ? "Input" : "X";
+ auto iter = outputs.find(GradVarName(grad_var_name));
return iter != outputs.end() && !iter->second.empty() &&
iter->second[0] == node.Name() &&
!op_desc->GetAttrIfExists<bool>("use_addto");
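The pass change above extends the inplace "addto" strategy, previously restricted to conv2d_grad, to resnet_unit_grad, and looks up the data-gradient output under "X" instead of "Input" for the fused op. Conceptually, addto lets the backward kernel accumulate its data gradient directly into an already-populated gradient buffer instead of writing to a temporary tensor and summing afterwards. A minimal NumPy sketch of that idea, with a placeholder matrix product standing in for the real convolution backward:

import numpy as np

def conv_backward_data(dout, weight):
    # placeholder for the real conv2d_grad / resnet_unit_grad data-gradient kernel
    return dout @ weight

dout = np.random.rand(2, 4).astype(np.float32)    # upstream gradient
weight = np.random.rand(4, 3).astype(np.float32)  # flattened filter
dx = np.ones((2, 3), dtype=np.float32)            # gradient buffer already holding another branch's contribution

use_addto = True
if use_addto:
    # addto: accumulate in place, skipping the temporary buffer and the extra add kernel
    dx += conv_backward_data(dout, weight)
else:
    # default: write to a fresh buffer, then combine branches with an elementwise add
    tmp = conv_backward_data(dout, weight)
    dx = dx + tmp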
5 changes: 3 additions & 2 deletions paddle/fluid/operators/fused/resnet_unit_op.cc
@@ -232,13 +232,14 @@ class ResNetUnitOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
+ AddAttr<bool>("use_addto", "").SetDefault(false);
AddAttr<std::string>("act_type", "The activation type to be fused.")
.SetDefault("relu");
AddComment(R"DOC(
- Fusion op of the basic unit of resnet block.
+ Fusion op of the basic unit of resnet block.

The implementation is based on the latest fusion op interface in cuDNN v8.0.
- For more details:
+ For more details:
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnFusedOps_t

)DOC");
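The use_addto attribute added above wires the inplace addto pass into this op's backward, while the DOC string describes the fused unit itself. As a rough illustration of what resnet_unit computes, an unfused reference built from standard Paddle layers might look like the sketch below; the layer objects (conv_x, bn_x, conv_z, bn_z) are assumed to be preconstructed paddle.nn.Conv2D and BatchNorm modules, and this shows only the math, not the fused cuDNN code path:

import paddle.nn.functional as F

def resnet_unit_reference(x, z, conv_x, bn_x, conv_z=None, bn_z=None,
                          fuse_add=False, has_shortcut=False):
    # main branch: Conv -> BatchNorm
    out = bn_x(conv_x(x))
    if has_shortcut:
        # shortcut branch gets its own Conv -> BatchNorm before the residual add
        out = out + bn_z(conv_z(z))
    elif fuse_add:
        # fuse_add without a shortcut: add the residual input directly
        out = out + z
    # activation (act_type, ReLU by default) closes the unit
    return F.relu(out)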
19 changes: 10 additions & 9 deletions paddle/fluid/operators/fused/resnet_unit_op.cu
@@ -55,7 +55,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
int padding = ctx.Attr<int>("padding");
int stride = ctx.Attr<int>("stride");
int stride_z = ctx.Attr<int>("stride_z");
- int dilate = ctx.Attr<int>("dilate");
+ int dilation = ctx.Attr<int>("dilation");
int group = ctx.Attr<int>("group");
double eps = static_cast<double>(ctx.Attr<float>("epsilon"));
double momentum = static_cast<double>(ctx.Attr<float>("momentum"));
@@ -87,7 +87,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
sum_x.Resize(param_dims);
sum_of_squares_x.Resize(param_dims);
CudnnNormConvolution<T> conv_x_op(dev_ctx, input_x_shape, filter_x_shape,
- output_shape, padding, stride, dilate,
+ output_shape, padding, stride, dilation,
group);
conv_x_op.Forward(dev_ctx, *input_x, *filter_x, conv_out_x, &sum_x,
&sum_of_squares_x);
@@ -129,8 +129,8 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
sum_z.Resize(param_dims);
sum_of_squares_z.Resize(param_dims);
CudnnNormConvolution<T> conv_z_op(dev_ctx, input_z_shape, filter_z_shape,
- output_shape, padding, stride_z, dilate,
- group);
+ output_shape, padding, stride_z,
+ dilation, group);
conv_z_op.Forward(dev_ctx, *input_z, *filter_z, conv_out_z, &sum_z,
&sum_of_squares_z);

@@ -189,7 +189,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
int padding = ctx.Attr<int>("padding");
int stride = ctx.Attr<int>("stride");
int stride_z = ctx.Attr<int>("stride_z");
- int dilate = ctx.Attr<int>("dilate");
+ int dilation = ctx.Attr<int>("dilation");
int group = ctx.Attr<int>("group");
double eps = static_cast<double>(ctx.Attr<float>("epsilon"));
double momentum = static_cast<double>(ctx.Attr<float>("momentum"));
@@ -263,7 +263,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
auto filter_z_shape = framework::vectorize<int>(filter_z->dims());
CudnnNormConvolutionGrad<T> conv_z_op(dev_ctx, z_shape, filter_z_shape,
output_shape, padding, stride_z,
- dilate, group);
+ dilation, group);
conv_z_op.Backward(dev_ctx, *z, *filter_z, conv_out_z_grad, z_grad,
filter_z_grad);
} else {
@@ -278,11 +278,12 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
}

// 2. Backward of Conv for x, get x_grad and filter_x_grad
+ bool use_addto = ctx.Attr<bool>("use_addto");
CudnnNormConvolutionGrad<T> conv_x_op(dev_ctx, x_shape, filter_x_shape,
- output_shape, padding, stride, dilate,
- group);
+ output_shape, padding, stride,
+ dilation, group);
conv_x_op.Backward(dev_ctx, *x, *filter_x, conv_out_x_grad, x_grad,
- filter_x_grad);
+ filter_x_grad, use_addto);
}
};

1 change: 1 addition & 0 deletions python/paddle/incubate/operators/__init__.py
@@ -14,3 +14,4 @@

from .softmax_mask_fuse_upper_triangle import softmax_mask_fuse_upper_triangle # noqa: F401
from .softmax_mask_fuse import softmax_mask_fuse # noqa: F401
+ from .resnet_unit import ResNetUnit #noqa: F401
269 changes: 269 additions & 0 deletions python/paddle/incubate/operators/resnet_unit.py
@@ -0,0 +1,269 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import collections
import itertools
import six
import math
import sys
import warnings
from functools import partial, reduce

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle import framework
from paddle.device import get_device, get_cudnn_version
from paddle.nn import initializer as I
from paddle.nn import Layer, LayerList
from paddle.fluid.layers import utils
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.param_attr import ParamAttr
from paddle import _C_ops
__all__ = ['resnet_unit', 'ResNetUnit']


def resnet_unit(x, filter_x, scale_x, bias_x, mean_x, var_x, z, filter_z,
scale_z, bias_z, mean_z, var_z, stride, stride_z, padding,
dilation, groups, momentum, eps, data_format, fuse_add,
has_shortcut, use_global_stats, is_test, act):

helper = LayerHelper('resnet_unit', **locals())
bn_param_dtype = fluid.core.VarDesc.VarType.FP32
bit_mask_dtype = fluid.core.VarDesc.VarType.INT32
out = helper.create_variable_for_type_inference(x.dtype)
bit_mask = helper.create_variable_for_type_inference(
dtype=bit_mask_dtype, stop_gradient=True)
# intermediate_out for x
conv_x = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
saved_mean_x = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
saved_invstd_x = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
running_mean_x = mean_x
running_var_x = var_x
# intermediate_out for z
conv_z = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
saved_mean_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
saved_invstd_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
running_mean_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True) if mean_z is None else mean_z
running_var_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True) if var_z is None else var_z

inputs = {
'X': x,
'FilterX': filter_x,
'ScaleX': scale_x,
'BiasX': bias_x,
'MeanX': mean_x,
'VarX': var_x,
'Z': z,
'FilterZ': filter_z,
'ScaleZ': scale_z,
'BiasZ': bias_z,
'MeanZ': mean_z,
'VarZ': var_z
}

attrs = {
'stride': stride,
'stride_z': stride_z,
'padding': padding,
'dilation': dilation,
'group': groups,
'momentum': momentum,
'epsilon': eps,
'data_format': data_format,
'fuse_add': fuse_add,
'has_shortcut': has_shortcut,
'use_global_stats': use_global_stats,
'is_test': is_test,
'act_type': act
}

outputs = {
'Y': out,
'BitMask': bit_mask,
'ConvX': conv_x,
'SavedMeanX': saved_mean_x,
'SavedInvstdX': saved_invstd_x,
'RunningMeanX': running_mean_x,
'RunningVarX': running_var_x,
'ConvZ': conv_z,
'SavedMeanZ': saved_mean_z,
'SavedInvstdZ': saved_invstd_z,
'RunningMeanZ': running_mean_z,
'RunningVarZ': running_var_z,
}

helper.append_op(
type='resnet_unit', inputs=inputs, outputs=outputs, attrs=attrs)

return out


class ResNetUnit(Layer):
r"""
******Temporary version******.
ResNetUnit is designed to optimize performance by using the cuDNN v8 fused-op API.
"""

def __init__(self,
num_channels_x,
num_filters,
filter_size,
stride=1,
momentum=0.9,
eps=1e-5,
data_format='NHWC',
act='relu',
fuse_add=False,
has_shortcut=False,
use_global_stats=False,
is_test=False,
filter_x_attr=None,
scale_x_attr=None,
bias_x_attr=None,
moving_mean_x_name=None,
moving_var_x_name=None,
num_channels_z=1,
stride_z=1,
filter_z_attr=None,
scale_z_attr=None,
bias_z_attr=None,
moving_mean_z_name=None,
moving_var_z_name=None):
super(ResNetUnit, self).__init__()
self._stride = stride
self._stride_z = stride_z
self._dilation = 1
self._kernel_size = utils.convert_to_list(filter_size, 2, 'kernel_size')
self._padding = (filter_size - 1) // 2
self._groups = 1
self._momentum = momentum
self._eps = eps
self._data_format = data_format
self._act = act
self._fuse_add = fuse_add
self._has_shortcut = has_shortcut
self._use_global_stats = use_global_stats
self._is_test = is_test

# check format
valid_format = {'NHWC'}
if data_format not in valid_format:
raise ValueError(
"conv_format must be one of {}, but got conv_format='{}'".
format(valid_format, data_format))

def _get_default_param_initializer(channels):
filter_elem_num = np.prod(self._kernel_size) * channels
std = (2.0 / filter_elem_num)**0.5
return I.Normal(0.0, std)

# initial filter
bn_param_dtype = fluid.core.VarDesc.VarType.FP32
bn_param_shape = [1, 1, 1, num_filters]
filter_x_shape = [num_filters, filter_size, filter_size, num_channels_x]
filter_z_shape = [num_filters, filter_size, filter_size, num_channels_z]

self.filter_x = self.create_parameter(
shape=filter_x_shape,
attr=filter_x_attr,
default_initializer=_get_default_param_initializer(num_channels_x))
self.scale_x = self.create_parameter(
shape=bn_param_shape,
attr=scale_x_attr,
dtype=bn_param_dtype,
default_initializer=I.Constant(1.0))
self.bias_x = self.create_parameter(
shape=bn_param_shape,
attr=bias_x_attr,
dtype=bn_param_dtype,
is_bias=True)
self.mean_x = self.create_parameter(
attr=ParamAttr(
name=moving_mean_x_name,
initializer=I.Constant(0.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.mean_x.stop_gradient = True
self.var_x = self.create_parameter(
attr=ParamAttr(
name=moving_var_x_name,
initializer=I.Constant(1.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.var_x.stop_gradient = True
if has_shortcut:
self.filter_z = self.create_parameter(
shape=filter_z_shape,
attr=filter_z_attr,
default_initializer=_get_default_param_initializer(
num_channels_z))
self.scale_z = self.create_parameter(
shape=bn_param_shape,
attr=scale_z_attr,
dtype=bn_param_dtype,
default_initializer=I.Constant(1.0))
self.bias_z = self.create_parameter(
shape=bn_param_shape,
attr=bias_z_attr,
dtype=bn_param_dtype,
is_bias=True)
self.mean_z = self.create_parameter(
attr=ParamAttr(
name=moving_mean_z_name,
initializer=I.Constant(0.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.mean_z.stop_gradient = True
self.var_z = self.create_parameter(
attr=ParamAttr(
name=moving_var_z_name,
initializer=I.Constant(1.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.var_z.stop_gradient = True
else:
self.filter_z = None
self.scale_z = None
self.bias_z = None
self.mean_z = None
self.var_z = None

def forward(self, x, z=None):
if self._fuse_add and z is None:
raise ValueError("z can not be None")

out = resnet_unit(
x, self.filter_x, self.scale_x, self.bias_x, self.mean_x,
self.var_x, z, self.filter_z, self.scale_z, self.bias_z,
self.mean_z, self.var_z, self._stride, self._stride_z,
self._padding, self._dilation, self._groups, self._momentum,
self._eps, self._data_format, self._fuse_add, self._has_shortcut,
self._use_global_stats, self._is_test, self._act)
return out
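
Finally, a minimal dygraph usage sketch of the new layer. The import path follows the __init__.py change above; the sizes are made up for illustration, and actually running this needs a CUDA build of Paddle with cuDNN v8, since the fused kernels target NHWC (typically float16) inputs:

import paddle
from paddle.incubate.operators import ResNetUnit

# hypothetical sizes: batch 8, 56x56 feature map, 64 input channels, 64 filters
unit = ResNetUnit(
    num_channels_x=64,
    num_filters=64,
    filter_size=3,
    stride=1,
    data_format='NHWC',
    act='relu')

x = paddle.randn([8, 56, 56, 64]).astype('float16')  # NHWC layout
y = unit(x)  # fused Conv2D + BatchNorm + ReLU in a single resnet_unit op
print(y.shape)  # [8, 56, 56, 64]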