diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index f2e24a128b95..da8c0944799a 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1369,7 +1369,7 @@ def _impl(inputs, input_types): return None return _impl -def _pad(): +def _pad(mode): def _impl(inputs, input_types): data = inputs[0] if isinstance(inputs[1], list): @@ -1394,9 +1394,11 @@ def _impl(inputs, input_types): # group into tuple of 2 ints paddings = [paddings[i:i + 2] for i in range(0, len(paddings), 2)] - pad_value = inputs[2] + if mode == "constant": + return _op.nn.pad(data, paddings, pad_value=inputs[2], pad_mode=mode) + else: + return _op.nn.pad(data, paddings, pad_mode=mode) - return _op.nn.pad(data, paddings, pad_value) return _impl @@ -1654,22 +1656,6 @@ def _impl(inputs, input_types): return _impl -def _reflection_pad2d(): - def _impl(inputs, input_types): - if isinstance(inputs[1], list): - pad_list = inputs[1] - else: - pad_list = list(_infer_shape(inputs[1])) - padding_left = pad_list[0] - padding_right = pad_list[1] - padding_top = pad_list[2] - padding_bottom = pad_list[3] - paddings = [[0, 0], [0, 0], [padding_top, padding_bottom], [padding_left, padding_right]] - - return _op.nn.mirror_pad(inputs[0], paddings, mode='REFLECT') - return _impl - - # Helper functions for operator implementation def _convert_dtype_value(val): convert_torch_dtype_map = {7:"torch.float64", @@ -1836,7 +1822,12 @@ def _get_convert_map(prelude): "aten::Int" : _int(), "prim::NumToTensor" : _numtotensor(), "prim::ImplicitTensorToNum" : _tensortonum(), - "aten::constant_pad_nd" : _pad(), + "aten::constant_pad_nd" : _pad("constant"), + "aten::reflection_pad1d" : _pad("reflect"), + "aten::reflection_pad2d" : _pad("reflect"), + "aten::replication_pad1d" : _pad("edge"), + "aten::replication_pad2d" : _pad("edge"), + "aten::replication_pad3d" : _pad("edge"), "aten::permute" : _transpose(prelude), "aten::sum" : _reduce("sum"), "aten::prod" : _reduce("prod"), @@ -1895,7 +1886,6 @@ def _get_convert_map(prelude): "aten::embedding" : _embedding(), "aten::one_hot" : _one_hot(), "aten::mm" : _matmul(prelude), - "aten::reflection_pad2d" : _reflection_pad2d(), "relay::tensor_array_stack" : _tensor_array_stack(prelude), "aten::add" : _add(prelude), "aten::add_" : _add(prelude), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index f6edbf119684..e41da7ecd799 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1116,6 +1116,15 @@ def test_forward_constant_pad3d(): verify_model(torch.nn.ConstantPad3d((3, 4, 5, 6, 0, 1), 3.5).eval(), inp) +def test_forward_reflection_pad1d(): + inp = torch.rand((1, 2, 4)) + verify_model(torch.nn.ReflectionPad1d(2).eval(), inp) + verify_model(torch.nn.ReflectionPad1d((3, 1)).eval(), inp) + + inp = torch.rand((2, 4, 5)) + verify_model(torch.nn.ReflectionPad1d((2, 3)).eval(), inp) + + def test_forward_reflection_pad2d(): inp = torch.rand((1, 1, 3, 3)) verify_model(torch.nn.ReflectionPad2d(2).eval(), inp) @@ -1125,6 +1134,33 @@ def test_forward_reflection_pad2d(): verify_model(torch.nn.ReflectionPad2d((1, 3, 2, 4)).eval(), inp) +def test_forward_replication_pad1d(): + inp = torch.rand((1, 2, 4)) + verify_model(torch.nn.ReplicationPad1d(2).eval(), inp) + verify_model(torch.nn.ReplicationPad1d((3, 1)).eval(), inp) + + inp = torch.rand((2, 4, 5)) + verify_model(torch.nn.ReplicationPad1d((2, 3)).eval(), inp) + + +def test_forward_replication_pad2d(): + inp = torch.rand((1, 1, 3, 3)) + verify_model(torch.nn.ReplicationPad2d(2).eval(), inp) + verify_model(torch.nn.ReplicationPad2d((1, 1, 2, 0)).eval(), inp) + + inp = torch.rand((2, 4, 5, 6)) + verify_model(torch.nn.ReplicationPad2d((1, 3, 2, 4)).eval(), inp) + + +def test_forward_replication_pad3d(): + inp = torch.rand((1, 1, 3, 3, 3)) + verify_model(torch.nn.ReplicationPad3d(3).eval(), inp) + verify_model(torch.nn.ReplicationPad3d((1, 1, 2, 2, 1, 1)).eval(), inp) + + inp = torch.rand((7, 5, 4, 5, 6)) + verify_model(torch.nn.ReplicationPad3d((2, 3, 2, 5, 1, 4)).eval(), inp) + + def test_forward_upsample3d(): inp = torch.arange(1, 9, dtype=torch.float32).view(1, 1, 2, 2, 2) verify_model(torch.nn.Upsample(scale_factor=2, mode='nearest').eval(), inp) @@ -2429,7 +2465,11 @@ def test_forward_pretrained_bert_base_uncased(): test_forward_constant_pad1d() test_forward_constant_pad2d() test_forward_constant_pad3d() + test_forward_reflection_pad1d() test_forward_reflection_pad2d() + test_forward_replication_pad1d() + test_forward_replication_pad2d() + test_forward_replication_pad3d() test_adaptive_pool3d() test_conv3d()