Skip to content

Commit

Permalink
[Frontend][Pytorch] add suppport for 'aten::upsample_bicubic2d' (apac…
Browse files Browse the repository at this point in the history
…he#8648)

* fix

* lint
  • Loading branch information
hgt312 authored and ylc committed Jan 13, 2022
1 parent 2252df2 commit e0f0f3a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
11 changes: 7 additions & 4 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1808,7 +1808,7 @@ def get_upsample_out_size(self, inputs, method):
else:
out_size.append(size)
else:
scale_index = 3 if method == "linear" else 2
scale_index = 3 if method != "nearest_neighbor" else 2
scales = inputs[scale_index]
assert scales is not None, "neither out size nor scale provided"
assert isinstance(scales, list)
Expand All @@ -1823,7 +1823,7 @@ def upsample(inputs, input_types):
data = inputs[0]
out_size = self.get_upsample_out_size(inputs, method)

if len(inputs) > 2 and method == "linear":
if len(inputs) > 2 and method != "nearest_neighbor":
align_corners = inputs[2]
else:
align_corners = False
Expand All @@ -1836,7 +1836,9 @@ def upsample(inputs, input_types):
coord_trans = "half_pixel"

def func(x):
return _op.image.resize2d(x, out_size, "NCHW", method, coord_trans)
return _op.image.resize2d(
x, out_size, "NCHW", method, coord_trans, cubic_alpha=-0.75
)

if self.is_quantized_tensor(data):
# input qparams are manually appended by us
Expand Down Expand Up @@ -2212,7 +2214,7 @@ def interpolate(self, inputs, input_types):
else:
coord_trans = "half_pixel"

return _op.image.resize2d(data, out_size, "NCHW", method, coord_trans)
return _op.image.resize2d(data, out_size, "NCHW", method, coord_trans, cubic_alpha=-0.75)

def numel(self, inputs, input_types):
return _op.ndarray_size(inputs[0])
Expand Down Expand Up @@ -2780,6 +2782,7 @@ def create_convert_map(self):
"aten::clamp_": self.clamp,
"aten::detach": self.identity,
"aten::upsample_bilinear2d": self.make_upsample("linear"),
"aten::upsample_bicubic2d": self.make_upsample("cubic"),
"aten::upsample_nearest2d": self.make_upsample("nearest_neighbor"),
"aten::upsample_trilinear3d": self.make_upsample3d("linear"),
"aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"),
Expand Down
3 changes: 3 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,6 +1761,9 @@ def forward(self, x):
verify_model(Upsample(size=(64, 64), mode="bilinear", align_corners=True), inp)
verify_model(Upsample(scale=2, mode="bilinear", align_corners=True), inp)
verify_model(Upsample(size=(50, 50), mode="bilinear", align_corners=True), inp)
verify_model(Upsample(size=(64, 64), mode="bicubic", align_corners=True), inp)
verify_model(Upsample(scale=2, mode="bicubic", align_corners=True), inp)
verify_model(Upsample(size=(50, 50), mode="bicubic", align_corners=True), inp)


@tvm.testing.uses_gpu
Expand Down

0 comments on commit e0f0f3a

Please sign in to comment.