[torch] Add linear operator support (apache#7569)
apivovarov authored and Trevor Morris committed May 6, 2021
1 parent d1009a2 commit 960f43a
Showing 2 changed files with 49 additions and 0 deletions.
15 changes: 15 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -1374,6 +1374,20 @@ def avg_pool3d(self, inputs, input_types):
            count_include_pad=count_include_pad,
        )

    def linear(self, inputs, input_types):
        # https://pytorch.org/docs/stable/nn.functional.html#linear
        # 0 - input
        # 1 - weight
        bias = inputs[2]
        mm_out = self.matmul(inputs[:2], input_types[:2])
        if isinstance(bias, _expr.Expr):
            bias_ndims = len(self.infer_shape_with_prelude(bias))
            if bias_ndims == 1:
                return _op.nn.bias_add(mm_out, bias)
            mm_dtype = self.infer_type_with_prelude(mm_out).dtype
            return self.add([mm_out, bias], [mm_dtype, input_types[2]])
        return mm_out

    def dropout(self, inputs, input_types):
        data = inputs[0]
        rate = float(inputs[1])
@@ -2289,6 +2303,7 @@ def create_convert_map(self):
"aten::softplus": self.softplus,
"aten::avg_pool2d": self.avg_pool2d,
"aten::avg_pool3d": self.avg_pool3d,
"aten::linear": self.linear,
"aten::dropout": self.dropout,
"aten::dropout_": self.dropout,
"aten::feature_dropout": self.dropout,
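For reference, torch.nn.functional.linear computes output = input @ weight.T + bias. The converter above reuses the existing matmul converter for the first two inputs and then attaches the bias: nn.bias_add when the bias is 1-D, otherwise a broadcast add in the matmul output's dtype. The snippet below is a minimal hand-written Relay sketch of that same lowering, not part of the commit; shapes and variable names are illustrative, and it uses nn.dense, which is one Relay form the matmul converter can emit for 2-D operands.

import tvm
from tvm import relay

# Illustrative shapes: a (2, 4) input, a (3, 4) weight, a (3,) bias.
x = relay.var("input", shape=(2, 4), dtype="float32")
w = relay.var("weight", shape=(3, 4), dtype="float32")
b = relay.var("bias", shape=(3,), dtype="float32")

# nn.dense computes input @ weight.T, matching F.linear's matmul step.
y = relay.nn.dense(x, w)
# A 1-D bias goes through bias_add; an n-D bias would use a broadcast add.
y = relay.nn.bias_add(y, b)

func = relay.Function([x, w, b], y)
print(func)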
34 changes: 34 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -24,6 +24,7 @@
import torch
import torchvision
from torch.nn import Module
from torch.nn import functional as F
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
@@ -1459,6 +1460,39 @@ def forward(self, *args):
    assert not any([op.name == "multiply" for op in list_ops(mod["main"])])


@tvm.testing.uses_gpu
def test_forward_linear():
    torch.set_grad_enabled(False)

    class Linear(Module):
        def forward(self, input, weight, bias):
            return F.linear(input, weight, bias)

    class LinearNoBias(Module):
        def forward(self, input, weight):
            return F.linear(input, weight)

    input2d = torch.rand([2, 2]).float()
    weight1d = torch.rand([2]).float()
    weight2d = torch.rand([2, 2]).float()
    bias1d = torch.rand([2]).float()
    bias2d = torch.rand([2, 2]).float()
    # 2D input, 2D weight, 1D bias
    verify_model(Linear(), input_data=[input2d, weight2d, bias1d])
    # 2D input, 2D weight, 2D bias
    verify_model(Linear(), input_data=[input2d, weight2d, bias2d])
    # 2D input, 2D weight, no bias
    verify_model(LinearNoBias(), input_data=[input2d, weight2d])
    # 2D input, 1D weight, 1D bias is not supported by torch.linear()
    # 2D input, 1D weight, no bias
    verify_model(LinearNoBias(), input_data=[input2d, weight1d])
    # TODO: Add the following cases when matmul(1D, _) is supported by TVM
    # 1D input, 2D weight, 1D bias
    # 1D input, 2D weight, no bias
    # 1D input, 1D weight, scalar bias
    # 1D input, 1D weight, no bias


@tvm.testing.uses_gpu
def test_forward_dropout():
    torch.set_grad_enabled(False)
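With the aten::linear mapping in place, a TorchScript trace containing F.linear can be imported through the usual relay.frontend.from_pytorch path. The end-to-end sketch below is illustrative only; the module, input names, shapes, and target are assumptions and are not part of this commit.

import torch
import tvm
from tvm import relay

# Hypothetical module mirroring the pattern used in test_forward_linear.
class Dense(torch.nn.Module):
    def forward(self, x, w, b):
        return torch.nn.functional.linear(x, w, b)

x, w, b = torch.rand(2, 4), torch.rand(3, 4), torch.rand(3)
scripted = torch.jit.trace(Dense(), (x, w, b))

# from_pytorch takes a list of (input name, shape) pairs for the graph inputs.
shape_list = [("x", (2, 4)), ("w", (3, 4)), ("b", (3,))]
mod, params = relay.frontend.from_pytorch(scripted, shape_list)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)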
