Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PYTORCH]Repeat, Reciprocal & Reshape Op support #5280

Merged
merged 1 commit into from
Apr 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,34 @@ def _impl(inputs, input_types):
return _op.transform.take(data, index, axis=dim)
return _impl

def _reciprocal():
def _impl(inputs, input_types):
data = inputs[0]
return _expr.const(1.0) / data
return _impl

def _repeat():
def _impl(inputs, input_types):
data = inputs[0]
reps = _get_dims(inputs[1])
return _op.transform.tile(data, reps=reps)
return _impl

def _repeat_interleave():
def _impl(inputs, input_types):
data = inputs[0]
if isinstance(inputs[1], int):
repeats = inputs[1]
axis = inputs[2]
else:
msg = "Only repeat with one value as repeat is currently supported."
raise AssertionError(msg)
if axis is None: # Flatten the data if no axis is given from torch
data = _op.transform.reshape(data, [-1])
axis = 0
return _op.transform.repeat(data, repeats=repeats, axis=axis)
return _impl

def _ones():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down Expand Up @@ -675,6 +703,16 @@ def _impl(inputs, input_types):
return _op.transform.reshape(data, new_shape)
return _impl

def _reshape():
def _impl(inputs, input_types):
data = inputs[0]
if isinstance(inputs[1], list):
new_shape = inputs[1]
else:
new_shape = _infer_shape(inputs[1])
return _op.transform.reshape(data, new_shape)
return _impl

def _clone():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down Expand Up @@ -1082,6 +1120,9 @@ def _wrap_const(c):
"aten::div_" : _elemwise("divide"),
"aten::ones" : _ones(),
"aten::zeros" : _zeros(),
"aten::reciprocal" : _reciprocal(),
"aten::repeat" : _repeat(),
"aten::repeat_interleave" : _repeat_interleave(),
"aten::to" : _to(),
"aten::squeeze" : _squeeze(),
"aten::unsqueeze" : _unsqueeze(),
Expand Down Expand Up @@ -1122,6 +1163,7 @@ def _wrap_const(c):
"aten::addmm" : _dense(),
"aten::size" : _size(),
"aten::view" : _view(),
"aten::reshape" : _reshape(),
"aten::clone" : _clone(),
"aten::log_softmax" : _log_softmax(),
"aten::sigmoid" : _sigmoid(),
Expand Down
75 changes: 75 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,61 @@ def forward(self, *args):
verify_model(Multiply3().float().eval(), input_data=input_data)
verify_model(Multiply4().float().eval(), input_data=input_data)

def test_forward_reciprocal():
torch.set_grad_enabled(False)
input_shape = [2, 1, 10, 1, 10]
class Reciprocal1(Module):
def forward(self, *args):
return args[0].reciprocal()

input_data = torch.rand(input_shape).float()
verify_model(Reciprocal1().float().eval(), input_data=input_data)

def test_forward_repeat():
torch.set_grad_enabled(False)
input_shape = [1, 3]
class Repeat1(Module):
def forward(self, *args):
return args[0].repeat(1, 1)

class Repeat2(Module):
def forward(self, *args):
return args[0].repeat(4, 2)

class Repeat3(Module):
def forward(self, *args):
return args[0].repeat(4, 2, 1)

input_data = torch.rand(input_shape).float()
verify_model(Repeat1().float().eval(), input_data=input_data)
verify_model(Repeat2().float().eval(), input_data=input_data)
verify_model(Repeat3().float().eval(), input_data=input_data)

def test_forward_repeat_interleave():
torch.set_grad_enabled(False)
input_shape = [2, 2, 3]
class RepeatInterleave1(Module):
def forward(self, *args):
return args[0].repeat_interleave(2)

class RepeatInterleave2(Module):
def forward(self, *args):
return args[0].repeat_interleave(3, dim=0)

class RepeatInterleave3(Module):
def forward(self, *args):
return args[0].repeat_interleave(2, dim=1)

class RepeatInterleave4(Module):
def forward(self, *args):
return args[0].repeat_interleave(4, dim=2)

input_data = torch.rand(input_shape).float()
verify_model(RepeatInterleave1().float().eval(), input_data=input_data)
verify_model(RepeatInterleave2().float().eval(), input_data=input_data)
verify_model(RepeatInterleave3().float().eval(), input_data=input_data)
verify_model(RepeatInterleave4().float().eval(), input_data=input_data)

def test_forward_unsqueeze():
torch.set_grad_enabled(False)
input_shape = [10, 10]
Expand Down Expand Up @@ -600,6 +655,22 @@ def init_weight(m):
init_weight(ln.eval())
verify_model(ln.eval(), input_data=inp)

def test_forward_reshape():
torch.set_grad_enabled(False)
input_shape = [2, 1, 10, 1, 10]
new_shape = [2, 1, 10, 10]
class Reshape1(Module):
def forward(self, *args):
return args[0].reshape(new_shape)

class Reshape2(Module):
def forward(self, *args):
return args[0].reshape([-1])

input_data = torch.rand(input_shape).float()
verify_model(Reshape1().float().eval(), input_data=input_data)
verify_model(Reshape2().float().eval(), input_data=input_data)

def test_forward_transpose():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
Expand Down Expand Up @@ -1151,6 +1222,10 @@ def forward(self, xs):
test_forward_add()
test_forward_subtract()
test_forward_multiply()
test_forward_reshape()
test_forward_reciprocal()
test_forward_repeat()
test_forward_repeat_interleave()
test_forward_squeeze()
test_forward_unsqueeze()
test_forward_concatenate()
Expand Down