[Bug] relay.transform.AnnotateSpans does not work on modules imported from Pytorch #8994

Closed
gilles-reservoir opened this issue Sep 13, 2021 · 4 comments · Fixed by #9015

gilles-reservoir commented Sep 13, 2021

Expected behavior

AnnotateSpans should work on modules imported from PyTorch.

Actual behavior

It fails because the PyTorch importer creates variables with "." in their names (e.g. %dense.weight). AnnotateSpans works by pretty-printing the module and then re-parsing it, and the parser (tokenizer?) chokes on variable names containing ".".

This makes it very difficult to debug other transformations when working with PyTorch models.

Environment

Ubuntu LTS, latest TVM main, torch 1.8.0

Steps to reproduce

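import torch
import torch.nn as nn
from tvm import relay

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 13)

    def forward(self, x):
        return self.dense(x)

net = Net()
data = torch.randn(17, 4)
traced = torch.jit.trace(net, data).eval()

# The importer names the Linear parameters %dense.weight and %dense.bias
mod, params = relay.frontend.from_pytorch(traced, [('x', tuple(data.shape))])
print(mod['main'].astext())
# Fails: re-parsing the printed module rejects the "." in %dense.weight
mod = relay.transform.AnnotateSpans()(mod)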

Output:

#[version = "0.0.5"]
fn (%x: Tensor[(17, 4), float32], %dense.weight: Tensor[(13, 4), float32], %dense.bias: Tensor[(13), float32]) {
  %0 = transpose(%dense.weight, axes=[1, 0]);
  %1 = transpose(%0, axes=[1, 0]);
  %2 = nn.dense(%x, %1, units=13);
  add(%2, %dense.bias)
}
error: expected a local variable found `.`
 --> GeneratedSource:115:48
     |  
 115 |  def @main(%x: Tensor[(17, 4), float32], %dense.weight: Tensor[(13, 4), float32], %dense.bias: Tensor[(13), float32]) {
     |                                                 ^                                                                      

Unfortunately it's not totally trivial to fix this because tuples use "." for element access. Maybe the PyTorch importer / all variable creation should excise "."s from variable names, but I could see that breaking scripts that currently work.
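To make the ambiguity concrete, here is a toy tokenizer (an illustration only, not TVM's actual tokenizer) in which a local variable is `%` followed by letters, digits, or `_`, and `.` is its own token. Under that assumption `%dense.weight` splits into `%dense`, `.`, `weight`, consistent with the error pointing at the `.`, and it has exactly the same token shape as a tuple projection such as `%t.1`.

```python
import re

# Toy tokenizer (NOT TVM's implementation): a local variable is '%'
# followed by letters, digits or '_'; '.' is always a separate token,
# as it would be if '.' were reserved for tuple projection.
TOKEN_RE = re.compile(r"%[A-Za-z0-9_]+|\.|[A-Za-z0-9_]+")

def tokenize(text):
    return TOKEN_RE.findall(text)

# The PyTorch-derived parameter name becomes three tokens, so a parser
# expecting a single local variable sees `%dense` and then a stray `.`.
print(tokenize("%dense.weight"))  # -> ['%dense', '.', 'weight']

# A tuple projection has the same shape at the token level, which is why
# simply allowing '.' inside identifiers would be ambiguous.
print(tokenize("%t.1"))  # -> ['%t', '.', '1']
print(tokenize("%features.1.conv.1_bias"))
# -> ['%features', '.', '1', '.', 'conv', '.', '1_bias']
```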

masahi self-assigned this Sep 13, 2021
masahi (Member) commented Sep 13, 2021

Thanks I'll take a look.

masahi (Member) commented Sep 14, 2021

There is a comment suggesting that "x.y" should ideally be parsed as an identifier:

// Right now we fail to parse `x.y`.

So we should fix the tokenizer / parser.

masahi (Member) commented Sep 14, 2021

Oh, in PT models we can have a variable name like %features.1.conv.1_bias. We cannot distinguish the '1' there from tuple access...

I'll write a pass to replace "." with "_" in a variable name, and let users optionally run it after PT model import.
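A minimal sketch of the renaming rule such a pass might apply, written as a plain-Python helper over variable names. The name `make_parser_friendly` and the collision policy (append `_1`, `_2`, ... if the sanitized name is already taken) are assumptions for illustration, not the pass that eventually landed. In TVM the rule would presumably be applied to each `relay.Var` via something like a Relay `ExprMutator`, and the same mapping would also need to be applied to the keys of the `params` dict returned by `from_pytorch`; that wiring is not shown.

```python
def make_parser_friendly(names):
    """Map each (distinct) variable name to a name without '.'.

    Hypothetical helper: '.' is replaced with '_'; names that already
    parse are kept unchanged, and a numeric suffix is appended when the
    sanitized name would clash with another name.
    """
    # Names without '.' are already parseable; reserve them up front.
    used = {n for n in names if "." not in n}
    mapping = {}
    for name in names:
        if "." not in name:
            mapping[name] = name
            continue
        candidate = base = name.replace(".", "_")
        suffix = 1
        while candidate in used:
            candidate = f"{base}_{suffix}"
            suffix += 1
        used.add(candidate)
        mapping[name] = candidate
    return mapping


names = ["x", "dense.weight", "dense.bias", "features.1.conv.1_bias", "dense_weight"]
print(make_parser_friendly(names))
```

Reserving the dot-free names first means an existing `dense_weight` keeps its name and the renamed `dense.weight` becomes `dense_weight_1`, rather than the other way around.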

gilles-reservoir (Author) commented

Thanks so much!
