Skip to content

Commit

Permalink
Fix pytorch frontend prim::Constant issue (apache#6051)
Browse files Browse the repository at this point in the history
  • Loading branch information
jxx123 authored and trevor-m committed Jul 14, 2020
1 parent 0950905 commit 038dfc8
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2172,6 +2172,8 @@ def _get_constant(node):
return node.f(attr_name)
elif ty in ["TensorType", "CompleteTensorType"]:
tensor = node.t(attr_name)
if tensor.is_cuda:
tensor = tensor.cpu()
if len(tensor.shape) == 0: # tensor(0.1)
# TODO(t-vi): When is this needed?
return tensor.item()
Expand Down

0 comments on commit 038dfc8

Please sign in to comment.