diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 536e4145db292..f985a9010961e 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -959,6 +959,30 @@ struct LayerNormAttrs : public tvm::AttrsNode { }; // struct LayerNormAttrs +/*! \brief Attributes used in group_norm operator */ +struct GroupNormAttrs : public tvm::AttrsNode { + int num_groups; + int axis; + double epsilon; + bool center; + bool scale; + + TVM_DECLARE_ATTRS(GroupNormAttrs, "relay.attrs.GroupNormAttrs") { + TVM_ATTR_FIELD(num_groups).set_default(0) + .describe("Specify number of groups to separate the channels into."); + TVM_ATTR_FIELD(axis).set_default(1) + .describe("Specify which shape axis denotes the channel."); + TVM_ATTR_FIELD(epsilon).set_default(1e-5) + .describe("Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).set_default(true) + .describe("If true, add offset of beta to normalized tensor; " + "otherwise, beta is ignored."); + TVM_ATTR_FIELD(scale).set_default(true) + .describe("If true, multiply by gamma; otherwise, gamma is ignored."); + } +}; // struct GroupNormAttrs + + /*! \brief Attributes for LRN operator */ struct LRNAttrs : public tvm::AttrsNode { int size; diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index ed31d34c0661c..9da3ecfc54ffc 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -831,6 +831,26 @@ def _impl(inputs, input_types): scale=True) return _impl + +def _group_norm(): + def _impl(inputs, input_types): + data = inputs[0] + gamma = inputs[2] + beta = inputs[3] + num_groups = inputs[1] + epsilon = float(inputs[4]) + + return _op.nn.group_norm(data, + gamma=gamma, + beta=beta, + num_groups=num_groups, + axis=1, + epsilon=epsilon, + center=True, + scale=True) + return _impl + + def _transpose(prelude): def _impl(inputs, input_types): data = inputs[0] @@ -1630,6 +1650,7 @@ def _get_convert_map(prelude): "aten::batch_norm" : _batch_norm(), "aten::instance_norm" : _instance_norm(), "aten::layer_norm" : _layer_norm(), + "aten::group_norm" : _group_norm(), "aten::transpose" : _transpose(prelude), "aten::transpose_" : _transpose(prelude), "aten::t" : _transpose(prelude), diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index d0a81bccd085b..622b0faaccea2 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1708,6 +1708,75 @@ def layer_norm(data, return _make.layer_norm(data, gamma, beta, axis, epsilon, center, scale) +def group_norm(data, + gamma, + beta, + num_groups, + axis=1, + epsilon=1e-5, + center=True, + scale=True): + r""" + Group normalization normalizes over group of channels for each training examples. + We can say that, Group Norm is in between Instance Norm and Layer Norm. When we put + all the channels into a single group, group normalization becomes Layer normalization. + And, when we put each channel into different groups it becomes Instance normalization + + https://arxiv.org/pdf/1803.08494.pdf + + Applies group normalization to the n-dimensional input array by seperating the input channels + into 'num_groups' groups, each containing 'num_channels / num_groups' channels. + The mean and standard-deviation are calculated separately over the each group. gamma and + beta are learnable per-channel affine transform parameter vectors of size num_channels. + + .. math:: + + out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}} + * gamma + beta + + Unlike batch normalization, the mean and var are computed along a group of channels. + + If the input has size k on axis 1, then both gamma and beta have shape (k,). + + .. note:: + + This operator can be optimized away for inference. + + Parameters + ---------- + data : tvm.relay.Expr + Input to which group_norm will be applied. + + gamma : tvm.relay.Expr + The gamma scale factor. + + beta : tvm.relay.Expr + The beta offset factor. + + num_groups : int + The number of groups to separate the channels into. + + axis : int, optional, default=1 + The axis of the channels. + + epsilon : double, optional, default=1e-5 + Small float added to variance to avoid dividing by zero. + + center : boolean, optional, default=True + If True, add offset of beta to normalized tensor, If False, + beta is ignored. + + scale : boolean, optional, default=True + If True, multiply by gamma. If False, gamma is not used. + + Returns + ------- + result : tvm.relay.Expr + The normalized data. + """ + return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale) + + def batch_matmul(x, y): r""" Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index b9ba74f9e95d1..5cdca8011aa2d 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -852,6 +852,80 @@ RELAY_REGISTER_OP("nn.layer_norm") .set_support_level(1) .add_type_rel("LayerNorm", LayerNormRel); +// group_norm +TVM_REGISTER_NODE_TYPE(GroupNormAttrs); + +bool GroupNormRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 4); + const auto* data = types[0].as(); + if (data == nullptr) return false; + const GroupNormAttrs* param = attrs.as(); + int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size(); + CHECK(axis >= 0 && axis < (int)data->shape.size()); + reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype)); + reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype)); + reporter->Assign(types[3], TensorType(data->shape, data->dtype)); + + return true; +} + +Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups, + int axis, double epsilon, bool center, bool scale) { + auto attrs = make_object(); + attrs->num_groups = num_groups; + attrs->axis = axis; + attrs->epsilon = epsilon; + attrs->center = center; + attrs->scale = scale; + static const Op& op = Op::Get("nn.group_norm"); + return Call(op, {data, gamma, beta}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.nn._make.group_norm") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeGroupNorm, args, rv); + }); + +RELAY_REGISTER_OP("nn.group_norm") +.describe(R"code( +Group normalization normalizes over group of channels for each training examples. +We can say that, Group Norm is in between Instance Norm and Layer Norm. When we put +all the channels into a single group, group normalization becomes Layer normalization. +And, when we put each channel into different groups it becomes Instance normalization + +https://arxiv.org/pdf/1803.08494.pdf + +Applies group normalization to the n-dimensional input array by seperating the input channels +into 'num_groups' groups, each containing 'num_channels / num_groups' channels. +The mean and standard-deviation are calculated separately over the each group. gamma and +beta are learnable per-channel affine transform parameter vectors of size num_channels. + +.. math:: + + out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}} + * gamma + beta + +Unlike batch normalization, the mean and var are computed along a group of channels. + +If the input has size k on axis 1, then both gamma and beta have shape (k,). + +.. note:: + + This operator can be optimized away for inference. + +)code" TVM_ADD_FILELINE) +.set_attrs_type() +.set_num_inputs(3) +.add_argument("data", "Tensor", "Input to which group_norm will be applied.") +.add_argument("gamma", "Tensor", "The gamma scale factor.") +.add_argument("beta", "Tensor", "The beta offset factor.") +.set_support_level(1) +.add_type_rel("GroupNorm", GroupNormRel); + + // relay.nn.batch_matmul bool BatchMatmulRel(const Array& types, int num_inputs, diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc index d349fdddeeea8..a9ceec26ce06c 100644 --- a/src/relay/transforms/simplify_inference.cc +++ b/src/relay/transforms/simplify_inference.cc @@ -64,6 +64,66 @@ Expr BatchNormToInferUnpack(const Attrs attrs, return out; } + +Expr GroupNormToInferUnpack(const Attrs attrs, + Expr data, + Expr gamma, + Expr beta, + Type tdata) { + auto ttype = tdata.as(); + CHECK(ttype); + const auto param = attrs.as(); + CHECK(param); + + int ndim = ttype->shape.size(); + int axis = (param->axis < 0) ? param->axis + ndim : param->axis; + Array reduced_axes; + Array new_shape; + Array old_shape; + + int num_groups = param->num_groups; + int channel = ttype->shape[axis].as()->value; + + // old_shape = N, C, H, W + // new shape = N, num_groups, C/num_groups, H, W + // reduce_axes = axis of (C/num_groups, H, W) + for (int i = 0; i < ndim; ++i) { + auto val = ttype->shape[i].as()->value; + + // Save the old shape to reshape later + old_shape.push_back(val); + if (i == axis) { + new_shape.push_back(num_groups); + new_shape.push_back(channel / num_groups); + reduced_axes.push_back(i + 1); + continue; + } + if (i >= axis) { + reduced_axes.push_back(i + 1); + } + new_shape.push_back(val); + } + + data = Reshape(data, new_shape); + + Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast(param->epsilon)); + Expr mean = Mean(data, {reduced_axes}, true, false); + Expr var = Variance(data, mean, {reduced_axes}, true, false); + Expr denom = Sqrt(Add(var, epsilon)); + Expr out = Divide(Subtract(data, mean), denom); + + out = Reshape(out, old_shape); + + if (param->scale) { + out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis})); + } + if (param->center) { + out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis})); + } + + return out; +} + Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, @@ -143,6 +203,7 @@ class InferenceSimplifier : public ExprMutator { dropout_op_(Op::Get("nn.dropout")), instance_norm_op_(Op::Get("nn.instance_norm")), layer_norm_op_(Op::Get("nn.layer_norm")), + group_norm_op_(Op::Get("nn.group_norm")), l2_norm_op_(Op::Get("nn.l2_normalize")) {} Expr VisitExpr_(const TupleGetItemNode* n) final { @@ -170,6 +231,10 @@ class InferenceSimplifier : public ExprMutator { const auto* call = new_n.as(); return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], n->args[0]->checked_type()); + } else if (n->op == group_norm_op_) { + const auto* call = new_n.as(); + return GroupNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], + n->args[0]->checked_type()); } else if (n->op == instance_norm_op_) { const auto* call = new_n.as(); return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], @@ -189,6 +254,7 @@ class InferenceSimplifier : public ExprMutator { const Op& dropout_op_; const Op& instance_norm_op_; const Op& layer_norm_op_; + const Op& group_norm_op_; const Op& l2_norm_op_; std::unordered_map ty_map_; }; diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 573fa7e3b29dc..c692c5ef35d2c 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -717,6 +717,28 @@ def init_weight(m): init_weight(ln.eval()) verify_model(ln.eval(), input_data=inp) + +def test_forward_groupnorm(): + input_shape = [10, 6, 5, 5] + input_data = torch.rand(input_shape).float() + + # Separate 6 channels into 3 groups + verify_model(torch.nn.GroupNorm(3, 6).eval(), input_data=input_data) + + # Put all 6 channels into a single group (equivalent with LayerNorm) + verify_model(torch.nn.GroupNorm(1, 6).eval(), input_data=input_data) + + # Separate 6 channels into 6 groups (equivalent with InstanceNorm) + verify_model(torch.nn.GroupNorm(6, 6).eval(), input_data=input_data) + + input_shape = [1, 10, 4, 7] + input_data = torch.rand(input_shape).float() + verify_model(torch.nn.GroupNorm(1, 10).eval(), input_data=input_data) + verify_model(torch.nn.GroupNorm(2, 10).eval(), input_data=input_data) + verify_model(torch.nn.GroupNorm(5, 10).eval(), input_data=input_data) + verify_model(torch.nn.GroupNorm(10, 10).eval(), input_data=input_data) + + def test_forward_reshape(): torch.set_grad_enabled(False) input_shape = [2, 1, 10, 1, 10] @@ -1865,6 +1887,7 @@ def forward(self, *args): test_forward_batchnorm() test_forward_instancenorm() test_forward_layernorm() + test_forward_groupnorm() test_forward_transpose() test_forward_size() test_forward_view()