Skip to content

Commit

Permalink
[Torch] Fix cast to long (#6301)
Browse files Browse the repository at this point in the history
* [Torch] fix cast to long

* retrigger
  • Loading branch information
masahi authored Aug 19, 2020
1 parent 7aa2de3 commit 939a42b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
7 changes: 5 additions & 2 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .. import op as _op
from ..ty import TupleType, TensorType, Any
from ..loops import while_loop
from .. import transform
from .common import get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
Expand Down Expand Up @@ -1507,14 +1508,16 @@ def _impl(inputs, input_types):
cast_func = {
6: float,
3: int,
4: int
}
cast_func_expr = {
6: lambda x: _op.cast(x, "float32"),
3: lambda x: _op.cast(x, "int32"),
4: lambda x: _op.cast(x, "int64"),
}
if inputs[1] in cast_func and not isinstance(data, _expr.Expr):
return cast_func[inputs[1]](data)
elif inputs[1] in cast_func and isinstance(data, _expr.Expr):
elif inputs[1] in cast_func_expr and isinstance(data, _expr.Expr):
return cast_func_expr[inputs[1]](data)
return data

Expand Down Expand Up @@ -2668,4 +2671,4 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None, default_d

mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])

return mod, tvm_params
return transform.RemoveUnusedFunctions()(mod), tvm_params
16 changes: 6 additions & 10 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,31 +1296,27 @@ def forward(self, x):
def test_to():
""" test for aten::to(...) """
class ToCPU(Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.to("cpu")

class ToFloat(Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.float()

class ToInt(Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.int()

class ToLong(Module):
def forward(self, x):
return x.long()

verify_model(ToCPU().eval(), torch.rand((1, 3, 32, 32)))
verify_model(ToFloat().eval(), torch.zeros((1, 3, 32, 32), dtype=torch.int))
verify_model(ToFloat().eval(), torch.tensor(2, dtype=torch.int))
verify_model(ToInt().eval(), torch.zeros((1, 3, 32, 32)))
verify_model(ToInt().eval(), torch.tensor(2.0))
verify_model(ToInt().eval(), torch.tensor(0.8))
verify_model(ToLong().eval(), torch.tensor(0.8))


def test_adaptive_pool3d():
Expand Down

0 comments on commit 939a42b

Please sign in to comment.