diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 380388a3df58..00f7eecc8abf 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1184,6 +1184,44 @@ def _impl(inputs, input_types): return _impl +def _norm(): + def _impl(inputs, input_types): + data = inputs[0] + axis = None + keepdims = False + if len(inputs) > 3: + axis = list(_infer_shape(inputs[2])) + keepdims = bool(inputs[3]) + + order = inputs[1] + if order == np.inf: + return _op.reduce.max(_op.abs(data), axis=axis, keepdims=keepdims) + elif order == np.NINF: + return _op.reduce.min(_op.abs(data), axis=axis, keepdims=keepdims) + else: + reci_order = _expr.const(1.0 / order) + order = _expr.const(order) + return _op.power(_op.reduce.sum(_op.power(_op.abs(data), order), + axis=axis, + keepdims=keepdims), + reci_order) + return _impl + + +def _frobenius_norm(): + def _impl(inputs, input_types): + data = inputs[0] + axis = None + keepdims = False + if len(inputs) > 2: + axis = list(_infer_shape(inputs[1])) + keepdims = bool(inputs[2]) + + return _op.sqrt(_op.reduce.sum((data * data), axis=axis, keepdims=keepdims)) + + return _impl + + def _std(): def _impl(inputs, input_types): data = inputs[0] @@ -1853,6 +1891,8 @@ def _get_convert_map(prelude): "aten::prod" : _reduce("prod"), "aten::argmin" : _reduce("argmin"), "aten::argmax" : _reduce("argmax"), + "aten::norm" : _norm(), + "aten::frobenius_norm" : _frobenius_norm(), "aten::std" : _std(), "aten::var" : _variance(), "aten::abs" : _unary("abs"), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index c9c76be47baa..86fb409d5d26 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -892,6 +892,91 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(LogSoftmax1().float().eval(), input_data=input_data) + +def test_forward_norm(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + + class Norm1(Module): + def forward(self, *args): + return torch.norm(args[0], p=float('inf'), dim=None, keepdim=False) + + class Norm2(Module): + def forward(self, *args): + return torch.norm(args[0], p=float('-inf'), dim=None, keepdim=False) + + class Norm3(Module): + def forward(self, *args): + return torch.norm(args[0], p=float('-inf'), dim=None, keepdim=True) + + class Norm4(Module): + def forward(self, *args): + return torch.norm(args[0], p=float('inf'), dim=(1, 2), keepdim=False) + + class Norm5(Module): + def forward(self, *args): + return torch.norm(args[0], p=float('inf'), dim=(1), keepdim=True) + + class Norm6(Module): + def forward(self, *args): + return torch.norm(args[0], p=float(0.5), dim=(1), keepdim=True) + + class Norm7(Module): + def forward(self, *args): + return torch.norm(args[0], p=float(1), dim=None, keepdim=False) + + class Norm8(Module): + def forward(self, *args): + return torch.norm(args[0], p=float(2.0), dim=(1), keepdim=True) + + class Norm9(Module): + def forward(self, *args): + return torch.norm(args[0], p=float(-0.5), dim=(1, 2), keepdim=True) + + class Norm10(Module): + def forward(self, *args): + return torch.norm(args[0], p=float(-2), dim=(1), keepdim=False) + + input_data = torch.rand(input_shape).float() + verify_model(Norm1().float().eval(), input_data=input_data) + verify_model(Norm2().float().eval(), input_data=input_data) + verify_model(Norm3().float().eval(), input_data=input_data) + verify_model(Norm4().float().eval(), input_data=input_data) + verify_model(Norm5().float().eval(), input_data=input_data) + verify_model(Norm6().float().eval(), input_data=input_data) + verify_model(Norm7().float().eval(), input_data=input_data) + verify_model(Norm8().float().eval(), input_data=input_data) + verify_model(Norm9().float().eval(), input_data=input_data) + verify_model(Norm10().float().eval(), input_data=input_data) + + +def test_forward_frobenius_norm(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + + class FroNorm1(Module): + def forward(self, *args): + return torch.norm(args[0]) + + class FroNorm2(Module): + def forward(self, *args): + return torch.norm(args[0], p='fro', dim=None, keepdim=True) + + class FroNorm3(Module): + def forward(self, *args): + return torch.norm(args[0], p='fro', dim=(1), keepdim=True) + + class FroNorm4(Module): + def forward(self, *args): + return torch.norm(args[0], dim=None, keepdim=False) + + input_data = torch.rand(input_shape).float() + verify_model(FroNorm1().float().eval(), input_data=input_data) + verify_model(FroNorm2().float().eval(), input_data=input_data) + verify_model(FroNorm3().float().eval(), input_data=input_data) + verify_model(FroNorm4().float().eval(), input_data=input_data) + + def test_forward_sigmoid(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -2421,6 +2506,8 @@ def test_forward_pretrained_bert_base_uncased(): test_forward_reduce_prod() test_forward_argmin() test_forward_argmax() + test_forward_norm() + test_forward_frobenius_norm() test_forward_std() test_forward_variance() test_forward_relu()