Skip to content

Commit

Permalink
[Unity] [Bugfix] Fix TypeError in TVM PyTorch frontend for LayerNorm …
Browse files Browse the repository at this point in the history
…operator (#15902)
  • Loading branch information
Thrsu authored Oct 10, 2023
1 parent ec1184e commit b138005
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
12 changes: 9 additions & 3 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,9 +975,15 @@ def _layer_norm(self, node: fx.node.Node) -> relax.Var:
# functional.layer_norm
if node.target not in self.named_modules:
# static or symbolic
normalized_shape = (
node.args[1] if type(node.args[1]) == tuple else self.env[node.args[1]]
)
arg = node.args[1]
if isinstance(arg, tuple):
value = arg
else:
try:
value = self.env[arg]
except TypeError:
value = tuple(arg)
normalized_shape = value
dim_num = len(normalized_shape)
axes = list(range(-dim_num, 0))

Expand Down
40 changes: 40 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,46 @@ def main(
binding = {}
verify_model(model, input_info, binding, expected2)

class LayerNorm3(Module):
def __init__(self, shape):
super().__init__()
self.shape = shape
self.weight = torch.nn.Parameter(torch.ones(shape))
self.bias = torch.nn.Parameter(torch.zeros(shape))

def forward(self, input):
return torch.nn.functional.layer_norm(input, self.shape, self.weight, self.bias, 1e-5)

@tvm.script.ir_module
class expected3:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
w1: R.Tensor([10, 10], dtype="float32"),
w2: R.Tensor([10, 10], dtype="float32"),
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm(
input_1,
w1,
w2,
axes=[-2, -1],
epsilon=1e-05,
center=True,
scale=True,
)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

model = LayerNorm3([10, 10])
binding = {
"w1": model.weight.detach().numpy(),
"w2": model.bias.detach().numpy(),
}
verify_model(model, input_info, binding, expected3)


def test_cross_entropy():
input_info = [([3, 2], "float32"), ([3], "int32")]
Expand Down

0 comments on commit b138005

Please sign in to comment.