From 14157dd243f3e16ef411210da5dc4561467984fd Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Thu, 11 Jun 2020 14:10:27 +0200
Subject: [PATCH] Fix gelu in PyTorch frontend, tighten numerical checks
 (#5763)

Previously, the PyTorch frontend approximated gelu with fastgelu.
To provide a more faithful conversion, we implement gelu instead.

We also tighten the numerical comparisons between PyTorch and
TVM-from-PyTorch to 1e-5. The torchvision model tests need an
increased tolerance of 1e-4 to pass.

I had to throw in a few fixes for missing conversions (probably
needed because I am working with a very new PyTorch). I must admit
the GoogLeNet/NasNet test didn't run on my machine, probably due to
problems on my end.
---
 python/tvm/relay/frontend/pytorch.py          | 20 ++++++++------
 tests/python/frontend/pytorch/test_forward.py | 27 ++++++++++---------
 2 files changed, 26 insertions(+), 21 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 7b96530dc703..380388a3df58 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -481,7 +481,10 @@ def _impl(inputs, input_types):
             msg = "Data type %s could not be parsed in zeros op" % (type(data))
             raise AssertionError(msg)
 
-        dtype = _convert_data_type(_convert_dtype_value(inputs[2]))
+        if inputs[2] is not None:  # dtype given
+            dtype = _convert_data_type(_convert_dtype_value(inputs[2]))
+        else:
+            dtype = data.type_annotation.dtype
 
         return _op.full(_expr.const(fill_value), shape, dtype=dtype)
     return _impl
@@ -567,14 +570,13 @@ def _impl(inputs, input_types):
 
 def _gelu():
     def _impl(inputs, input_types):
-        import math
         data = inputs[0]
-
-        def _pow3(x):
-            return x * x * x
-        return _expr.const(0.5) * data * (_expr.const(1.0) +
-                                          _op.tanh(_expr.const(math.sqrt(2.0 / math.pi)) *
-                                                   (data + _expr.const(0.044715) * _pow3(data))))
+        # gelu is data * normcdf(data)
+        # normcdf is expressed through erf since we don't have that intrinsic
+        # note that there is also a fastgelu variant approximating normcdf
+        # with tanh and a third-order polynomial, but this is "true" gelu
+        return data * (_expr.const(0.5) +
+                       _op.erf(data * _expr.const(0.5**0.5)) * _expr.const(0.5))
     return _impl
 
 def _selu():
@@ -1839,6 +1841,7 @@ def _get_convert_map(prelude):
         "aten::Int"                             : _int(),
         "prim::NumToTensor"                     : _numtotensor(),
         "prim::ImplicitTensorToNum"             : _tensortonum(),
+        "aten::ScalarImplicit"                  : _tensortonum(),
         "aten::constant_pad_nd"                 : _pad("constant"),
         "aten::reflection_pad1d"                : _pad("reflect"),
         "aten::reflection_pad2d"                : _pad("reflect"),
@@ -1877,6 +1880,7 @@ def _get_convert_map(prelude):
         "aten::floor"                           : _unary("floor"),
         "aten::round"                           : _unary("round"),
         "aten::isfinite"                        : _unary("isfinite"),
+        "aten::isinf"                           : _unary("isinf"),
         "aten::isnan"                           : _unary("isnan"),
         "aten::clamp"                           : _clamp(),
         "aten::detach"                          : _identity(),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 3c7ff4fbbecb..c9c76be47baa 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -135,7 +135,8 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
 
 def verify_model(model_name, input_data=[],
                  custom_convert_map={},
-                 ctx_list=ctx_list()):
+                 ctx_list=ctx_list(),
+                 rtol=1e-5, atol=1e-5):
     """Assert that the output of a compiled model matches with that
     of its baseline."""
     if isinstance(model_name, str):
@@ -190,7 +191,7 @@ def verify_model(model_name, input_data=[],
 
         assert_shapes_match(baseline_output, compiled_output)
         tvm.testing.assert_allclose(baseline_output, compiled_output,
-                                    rtol=1e-3, atol=1e-3)
+                                    rtol=rtol, atol=atol)
 
     del model_name
     del baseline_model
@@ -1216,35 +1217,35 @@ def test_conv3d_transpose():
 # Model tests
 def test_resnet18():
     torch.set_grad_enabled(False)
-    verify_model("resnet18")
+    verify_model("resnet18", atol=1e-4, rtol=1e-4)
 
 def test_squeezenet1_0():
     torch.set_grad_enabled(False)
-    verify_model("squeezenet1_0")
+    verify_model("squeezenet1_0", atol=1e-4, rtol=1e-4)
 
 def test_squeezenet1_1():
     torch.set_grad_enabled(False)
-    verify_model("squeezenet1_1")
+    verify_model("squeezenet1_1", atol=1e-4, rtol=1e-4)
 
 def test_densenet121():
     torch.set_grad_enabled(False)
-    verify_model("densenet121")
+    verify_model("densenet121", atol=1e-4, rtol=1e-4)
 
 def test_inception_v3():
     torch.set_grad_enabled(False)
-    verify_model("inception_v3")
+    verify_model("inception_v3", atol=1e-4, rtol=1e-4)
 
 def test_googlenet():
     torch.set_grad_enabled(False)
-    verify_model("googlenet")
+    verify_model("googlenet", atol=1e-4, rtol=1e-4)
 
 def test_mnasnet0_5():
     torch.set_grad_enabled(False)
-    verify_model("mnasnet0_5")
+    verify_model("mnasnet0_5", atol=1e-4, rtol=1e-4)
 
 def test_mobilenet_v2():
     torch.set_grad_enabled(False)
-    verify_model("mobilenet_v2")
+    verify_model("mobilenet_v2", atol=1e-4, rtol=1e-4)
 
 """
 #TODO: Fix VGG and AlexNet issues (probably due to pooling)
@@ -1305,19 +1306,19 @@ def forward(self, inp):
 
     inp = [torch.rand((1, 3, 300, 300), dtype=torch.float)]
 
-    verify_model(SegmentationModelWrapper(fcn.eval()), inp)
+    verify_model(SegmentationModelWrapper(fcn.eval()), inp, atol=1e-4, rtol=1e-4)
 
     # depthwise + dilated convolution not supported on x86
     # see https://github.com/apache/incubator-tvm/issues/4962
     cuda_ctx = ("cuda", tvm.gpu(0))
     if cuda_ctx[1].exist:
-        verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx])
+        verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx], atol=1e-4, rtol=1e-4)
 
 
 def test_3d_models():
     input_shape = (1, 3, 4, 56, 56)
     resnet3d = torchvision.models.video.r3d_18(pretrained=True).eval()
-    verify_model(resnet3d, [torch.rand(input_shape)])
+    verify_model(resnet3d, [torch.rand(input_shape)], atol=1e-4, rtol=1e-4)
 
 
 def verify_script_model(pt_model, ishapes):
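
Note on the gelu change: gelu(x) = x * normcdf(x), and with
normcdf(x) = 0.5 * (1 + erf(x / sqrt(2))) this gives exactly the erf
expression emitted above. The following standalone sketch (not part of
the patch; it only assumes a PyTorch install) checks the erf form
against torch.nn.functional.gelu and shows that the tanh-based fastgelu
the frontend used to emit only agrees approximately:

    import math
    import torch

    x = torch.linspace(-5, 5, steps=101)

    # erf form, matching what the frontend now emits
    erf_gelu = x * (0.5 + 0.5 * torch.erf(x * 0.5 ** 0.5))
    # tanh-based fastgelu form the frontend emitted before this patch
    tanh_gelu = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi)
                                            * (x + 0.044715 * x ** 3)))

    # the erf form agrees with PyTorch's gelu to float32 roundoff
    assert torch.allclose(erf_gelu, torch.nn.functional.gelu(x), atol=1e-6)
    # the tanh form does not: its deviation is well above roundoff
    print((tanh_gelu - torch.nn.functional.gelu(x)).abs().max())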
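
The zeros/full fix above handles TorchScript graphs in which the
optional dtype argument is the constant None rather than a concrete
ScalarType; the converter now falls back to the input's dtype
annotation in that case. A tiny repro sketch (the function name fill is
hypothetical, for illustration only; it just assumes a recent PyTorch):

    import torch

    @torch.jit.script
    def fill():
        # dtype is omitted here, so the TorchScript graph passes a None
        # constant for the dtype argument of aten::full
        return torch.full([2, 2], 1.5)

    # printing the graph shows the None constant feeding aten::full,
    # the case the "inputs[2] is not None" check above guards against
    print(fill.graph)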
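
On the tightened tolerances: tvm.testing.assert_allclose delegates to
numpy.testing.assert_allclose, which passes elementwise when
|actual - desired| <= atol + rtol * |desired|. A small worked example
(standalone, assuming only numpy) of what moving from the old
rtol=atol=1e-3 to 1e-5 (or the per-model 1e-4) means in practice:

    import numpy as np

    desired = np.float32(10.0)
    actual = desired + np.float32(5e-4)  # absolute error of 5e-4

    # allowed error is atol + rtol * |desired| = 1e-4 + 1e-4 * 10 = 1.1e-3,
    # so the relaxed per-model tolerance accepts this result
    np.testing.assert_allclose(actual, desired, rtol=1e-4, atol=1e-4)

    # at the tightened default the allowance is 1e-5 + 1e-5 * 10 = 1.1e-4,
    # so the same result is now rejected
    try:
        np.testing.assert_allclose(actual, desired, rtol=1e-5, atol=1e-5)
    except AssertionError:
        print("fails at rtol=atol=1e-5")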