Skip to content

Commit

Permalink
[CI] Torch 1.7 update to mainline (apache#6828)
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi authored and trevor-m committed Dec 4, 2020
1 parent 97ea828 commit 91ff038
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 38 deletions.
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

// NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. -->
ci_lint = "tlcpack/ci-lint:v0.62"
ci_gpu = "tlcpack/ci-gpu:v0.71"
ci_gpu = "tlcpack/ci-gpu:v0.72"
ci_cpu = "tlcpack/ci-cpu:v0.71"
ci_wasm = "tlcpack/ci-wasm:v0.70"
ci_i386 = "tlcpack/ci-i386:v0.71"
Expand Down
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_onnx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ pip3 install onnxruntime==1.0.0
# not expose that in the wheel!!!
pip3 install future

pip3 install torch==1.4.0 torchvision==0.5.0
pip3 install torch==1.7.0 torchvision==0.8.1
60 changes: 35 additions & 25 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import itertools
import logging
import sys
import math

import numpy as np

Expand Down Expand Up @@ -168,7 +169,6 @@ def _min():

def _unary(name):
def _impl(inputs, input_types):
input_type = input_types[0]
# this is just to ensure tensor input
(data,) = _pytorch_promote_types(inputs[:1], input_types[:1])
return get_relay_op(name)(data)
Expand Down Expand Up @@ -1552,7 +1552,7 @@ def _impl(inputs, input_types):
axis = None
keepdims = False
if len(inputs) > 2:
axis = inputs[1]
axis = inputs[1] if len(inputs[1]) > 0 else None
keepdims = bool(inputs[2])

return _op.sqrt(_op.reduce.sum((data * data), axis=axis, keepdims=keepdims))
Expand Down Expand Up @@ -1847,18 +1847,33 @@ def _impl(inputs, input_types):
return _impl


def _upsample(method, prelude):
def _impl(inputs, input_types):
out_size = []
def _get_upsample_out_size(inputs, method):
# This assumes a static shape
out_size = []
if inputs[1] is not None:
for size in inputs[1]:
if not isinstance(size, int):
out_size.append(int(_infer_value(size, {}).asnumpy()))
else:
out_size.append(size)
else:
scale_index = 3 if method in ["bilinear", "trilinear"] else 2
scales = inputs[scale_index]
assert scales is not None, "neither out size nor scale provided"
assert isinstance(scales, list)
ishape = _infer_shape(inputs[0])
for i, scale in enumerate(scales):
out_size.append(int(math.floor(float(ishape[2 + i]) * scale)))

return out_size


def _upsample(method, prelude):
def _impl(inputs, input_types):
data = inputs[0]
out_size = _get_upsample_out_size(inputs, method)

if len(inputs) > 2:
if len(inputs) > 2 and method == "bilinear":
align_corners = inputs[2]
else:
align_corners = False
Expand All @@ -1874,35 +1889,24 @@ def func(x):
return _op.image.resize(x, out_size, "NCHW", method, coord_trans)

if _is_quantized_tensor(data, prelude):
# Torch version > 1.4 changed upsampling API
if is_version_greater_than("1.4.0"):
num_inputs = 7
else:
num_inputs = 5

assert len(inputs) == num_inputs, "Input quant param not found in op inputs"

# input qparams are manually appended by us
assert isinstance(inputs[-2], float)
assert isinstance(inputs[-1], int)
input_scale = _expr.const(inputs[-2])
input_zero_point = _expr.const(inputs[-1])
return qnn_torch.quantized_upsample(data, input_scale, input_zero_point, func)

return func(data)

return _impl


def _upsample3d(method):
def _impl(inputs, input_types):
if isinstance(inputs[1], _expr.Var):
out_size = _infer_shape(inputs[1])
elif _is_int_seq(inputs[1]):
out_size = inputs[1]
elif isinstance(inputs[1], list):
infer_res = [_infer_value(size, {}) for size in inputs[1]]
out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res]

data = inputs[0]
out_size = _get_upsample_out_size(inputs, method)

if len(inputs) > 2:
if len(inputs) > 2 and method == "trilinear":
align_corners = inputs[2]
else:
align_corners = False
Expand Down Expand Up @@ -1983,8 +1987,7 @@ def _impl(inputs, input_types):

def _logical_not():
def _impl(inputs, input_types):
data = inputs[0]

data = _wrap_const(inputs[0])
return _op.logical_not(_op.cast(data, "bool"))

return _impl
Expand Down Expand Up @@ -2732,6 +2735,7 @@ def _get_convert_map(prelude, default_dtype):
"aten::empty": _empty(),
"aten::bincount": _bincount(),
"aten::scatter_add": _scatter_add(),
"aten::__not__": _logical_not(),
}
return convert_map

Expand Down Expand Up @@ -2798,6 +2802,7 @@ def _report_missing_conversion(op_names, convert_map):
"prim::ListUnpack",
"prim::TupleConstruct",
"prim::TupleUnpack",
"prim::RaiseException",
"prim::If",
"prim::Loop",
]
Expand Down Expand Up @@ -2903,6 +2908,8 @@ def _get_operator_nodes(nodes):
ops = []
# Traverse nodes and add to graph
for node in nodes:
if node.outputsSize() == 0:
continue
if node.outputsSize() > 1:
node_name = "_".join(_get_output_names(node))
else:
Expand Down Expand Up @@ -3286,6 +3293,9 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude, defau
else:
unpacked = _unpack_tuple(inputs[0])
outputs.update(zip(_get_output_names(op_node), unpacked))
elif operator == "prim::prim::RaiseException":
logging.warning("raising exceptions is ignored")
outputs[node_name] = None
elif operator == "prim::If":
if_out = convert_if(op_node, outputs, convert_map, prelude, default_dtype=default_dtype)
outputs[node_name] = if_out
Expand Down
3 changes: 2 additions & 1 deletion tests/python/frontend/pytorch/qnn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,8 @@ def get_imagenet_input():
# disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug
# See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756
# ("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
("googlenet", qgooglenet(pretrained=True), per_channel),
# tracing quantized googlenet broken as of v1.6
# ("googlenet", qgooglenet(pretrained=True), per_channel),
]

results = []
Expand Down
4 changes: 2 additions & 2 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2535,7 +2535,7 @@ def test_forward_linspace():

class Linspace1(Module):
def forward(self, *args):
return torch.linspace(5, 10)
return torch.linspace(5, 10, steps=100)

class Linspace2(Module):
def forward(self, *args):
Expand All @@ -2559,7 +2559,7 @@ def forward(self, *args):

class Linspace7(Module):
def forward(self, *args):
return torch.linspace(1, 4, dtype=torch.float32)
return torch.linspace(1, 4, steps=100, dtype=torch.float32)

class Linspace8(Module):
def forward(self, *args):
Expand Down
5 changes: 3 additions & 2 deletions tests/python/unittest/test_auto_scheduler_layout_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,6 @@ def test_correctness_layout_rewrite_insert_transform_stage():

if __name__ == "__main__":
test_apply_steps_with_layout_rewrite()
test_correctness_layout_rewrite_rewrite_for_preTransformed()
test_correctness_layout_rewrite_insert_transform_stage()
# Disable for now due to being flaky on i386
# test_correctness_layout_rewrite_rewrite_for_preTransformed()
# test_correctness_layout_rewrite_insert_transform_stage()
6 changes: 3 additions & 3 deletions tutorials/frontend/deploy_object_detection_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@
.. code-block:: bash
pip install torch==1.4.0
pip install torchvision==0.5.0
pip install torch==1.7.0
pip install torchvision==0.8.1
or please refer to official site
https://pytorch.org/get-started/locally/
PyTorch versions should be backwards compatible but should be used
with the proper TorchVision version.
Currently, TVM supports PyTorch 1.4 and 1.3. Other versions may
Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may
be unstable.
"""

Expand Down
6 changes: 3 additions & 3 deletions tutorials/frontend/from_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@
.. code-block:: bash
pip install torch==1.4.0
pip install torchvision==0.5.0
pip install torch==1.7.0
pip install torchvision==0.8.1
or please refer to official site
https://pytorch.org/get-started/locally/
PyTorch versions should be backwards compatible but should be used
with the proper TorchVision version.
Currently, TVM supports PyTorch 1.4 and 1.3. Other versions may
Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may
be unstable.
"""

Expand Down

0 comments on commit 91ff038

Please sign in to comment.