diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index c82b487acff3..8e626f52d528 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -46,6 +46,14 @@
 __all__ = ["from_pytorch"]


+def _is_version_greater_than(ver):
+    import torch
+    from packaging import version
+
+    # Torch version > 1.4 changed upsampling API
+    return version.parse(torch.__version__) > version.parse(ver)
+
+
 # List ADT utilities
 def _infer_type_with_prelude(val, prelude):
     body = _infer_type(val, prelude.mod)
@@ -413,13 +421,18 @@ def _impl(inputs, input_types):
 def _split_with_sizes():
     def _impl(inputs, input_types):
         data = inputs[0]
+        sections = inputs[1]
         dim = int(inputs[2])
+
+        if len(sections) == 1:
+            # a special case used in torchvision detection models
+            return _expr.TupleWrapper(_expr.Tuple([data]), 1)
+
         split_index = 0
         indices = []
-        sections = inputs[1]
         for i in range(len(sections) - 1):
-            split_index += sections[i]
+            index, _ = try_infer_value(sections[i], lambda ret: int(ret))
+            split_index += index
             indices.append(split_index)

         return _op.split(data, indices, dim)
@@ -522,6 +535,9 @@ def _impl(inputs, input_types):

 def _where():
     def _impl(inputs, input_types):
+        if len(inputs) == 1:
+            return _nonzero(False)([inputs[0], True], input_types)
+
         cond = inputs[0]
         x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3])
         return _op.where(cond, x, y)
@@ -1865,11 +1881,8 @@ def func(x):
             return _op.image.resize(x, out_size, "NCHW", method, coord_trans)

         if _is_quantized_tensor(data, prelude):
-            import torch
-            from packaging import version
-
             # Torch version > 1.4 changed upsampling API
-            if version.parse(torch.__version__) > version.parse("1.4.0"):
+            if _is_version_greater_than("1.4.0"):
                 num_inputs = 7
             else:
                 num_inputs = 5
@@ -2172,9 +2185,11 @@ def _impl(inputs, input_types):
         data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])

         # strided slice to get the dynamic result
-        return get_relay_op("strided_slice")(
+        ret = get_relay_op("strided_slice")(
             data_slice, begin=_expr.const([0]), end=size, slice_mode="size"
         )
+        # in torchvision, indices from nms are int64
+        return _op.cast(ret, "int64")

     return _impl

@@ -2266,9 +2281,8 @@ def _impl(inputs, input_types):
         ret = _op.transform.argwhere(data)

         if is_numpy_style or (len(inputs) > 1 and inputs[1]):
-            # TODO(kevinthesun): Support this by adding unbind op
-            # ret = _unbind()([ret, 0], None)
-            raise RuntimeError("as_tuple is not supported yet for nonzero.")
+            return _unbind()([ret, 1], None)
+
         return ret

     return _impl
@@ -2335,6 +2349,21 @@ def _impl(inputs, input_types):
     return _impl


+def _numel():
+    def _impl(inputs, input_types):
+        return _op.ndarray_size(inputs[0])
+
+    return _impl
+
+
+def _empty():
+    def _impl(inputs, input_types):
+        shape = inputs[0]
+        return _op.zeros(shape, _convert_dtype_value(inputs[1]))
+
+    return _impl
+
+
 def _pytorch_result_type(dtypes, non_tensor_inputs):
     """This promotes TVM dtypes like PyTorch would"""
     import torch
@@ -2673,6 +2702,10 @@ def _get_convert_map(prelude, default_dtype):
         "aten::scatter": _scatter(),
         "aten::scalar_tensor": _scalar_tensor(),
         "aten::__interpolate": _interpolate(),
+        "aten::IntImplicit": _identity(),
+        "aten::tensor": _identity(),  # used for example in tensor(1.0)
+        "aten::numel": _numel(),
+        "aten::empty": _empty(),
     }
     return convert_map

@@ -2681,7 +2714,13 @@ def _run_jit_passes(graph):
     """ The inline pass is necessary to unwrap prim::CallMethod """
     import torch

-    torch._C._jit_pass_inline(graph)
+    if _is_version_greater_than("1.5.0"):
+        # This is required for torchvision detection models from 1.6 above
+        # It is the same as _jit_pass_inline, except that it has some special
+        # case behaviors for some ops such as aten::__interpolate()
+        torch._C._jit_pass_onnx_function_substitution(graph)
+    else:
+        torch._C._jit_pass_inline(graph)


 def _get_tensor_and_var(torch_tensor, name):
diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index ddbad12c0b53..40eba949b0ec 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -179,6 +179,8 @@ def no_data_full_shape_func(attrs, inputs, out_ndims):
     """
     Shape func for zeros and ones.
     """
+    if len(inputs) == 0:
+        return [_convert_shape(convert(attrs.shape))]
     return [_full_shape_func(inputs[0])]


diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 8c1143646426..54c3daf25385 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -2865,10 +2865,19 @@ class Where2(Module):
         def forward(self, *args):
             return torch.where(args[0] > 0, args[0], args[1])

+    class Where3(Module):
+        def forward(self, *args):
+            return torch.where(args[0])[0]
+
     x = torch.rand([3, 2]).float()
-    verify_model(Where1().float().eval(), input_data=[x])
+    verify_model(Where1(), input_data=[x])
     y = torch.rand([3, 2])
-    verify_model(Where2().float().eval(), input_data=[x, y])
+    verify_model(Where2(), input_data=[x, y])
+
+    # a single argument variant, equivalent to torch.nonzero(..., as_tuple=True)
+    inp = torch.rand([10])
+    inp[3:8] = 0
+    verify_trace_model(Where3(), [inp], ["llvm"])


 @tvm.testing.uses_gpu
@@ -3152,6 +3161,17 @@ def forward(self, data, index, src):
     verify_trace_model(Scatter(1), [in_data, in_index, in_src], ["llvm"])


+def test_numel():
+    class Numel(Module):
+        def forward(self, data):
+            return torch.tensor(torch.numel(data))
+
+    targets = _get_default_vm_targets()
+    verify_script_model(Numel(), [(1,)], targets)
+    verify_script_model(Numel(), [(3, 5)], targets)
+    verify_script_model(Numel(), [(3, 5, 8)], targets)
+
+
 def test_forward_pretrained_bert_base_uncased():
     ######################################################################
     # This is an example how to run BERT models using TVM
@@ -3455,6 +3475,7 @@ def expected(x_shape, y_shape):
     test_forward_unbind()
     test_forward_nonzero()
     test_forward_scatter()
+    test_numel()

     # Model tests
     test_resnet18()