Description of the bug:
I'm submitting this bug at the request of @pkgoogle; we found it in #62275.
It seems that the converter performs two unintended actions when handling LSTMs:
1. The LSTM is being decomposed instead of being converted into the "UnidirectionalSequenceLSTM" operator, which is the default behavior in TensorFlow.
2. The LSTM is being unrolled without the user's consent. In TensorFlow, one of the LSTM's arguments allows unrolling, but it is off by default. If the first bug is fixed, this one may no longer be relevant unless the user wishes to manipulate the hidden states, since that forces the compiler, at least in TensorFlow, to switch to the decomposed operators inside a "While" loop, as I described in #62775. In that case, it should be up to the user to choose between a loop and the unrolled version.
Actual vs expected behavior:
You can find the test code below.
Actual Behavior: With ai-edge-torch, the exported graph clearly shows the LSTM being decomposed and unrolled.
Expected Behavior: With TensorFlow, the LSTM is converted to a single "UnidirectionalSequenceLSTM" operator.
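One way to compare the two exported graphs is TensorFlow's TFLite model analyzer (a sketch, not part of the original report; the file name matches the export call below):

import tensorflow as tf

# Print the op-level summary of the exported model. With the expected behavior,
# the listing contains a single UNIDIRECTIONAL_SEQUENCE_LSTM op; with the
# decomposed and unrolled graph, it is typically a long per-timestep chain of
# FULLY_CONNECTED, ADD, MUL, LOGISTIC, TANH, ... ops instead.
tf.lite.experimental.Analyzer.analyze(model_path="simple_model.tflite")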
Any other information you'd like to share?
Torch code:
import torch
from torch import nn
import ai_edge_torch

class SimpleModel(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.d1 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        x, (h0, c0) = self.lstm(x)
        x = self.d1(x)
        return x

model = SimpleModel(256, 64)
sample_inputs = (torch.randn(16, 43, 256),)

edge_model = ai_edge_torch.convert(model.eval(), sample_inputs)
edge_model.export("simple_model.tflite")
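Tensorflow code:
For comparison, a minimal Keras equivalent of the model above (a sketch: layer sizes mirror the Torch model, shapes and file names are illustrative, and the TFLite converter is used with its default settings). tf.keras.layers.LSTM defaults to unroll=False, and converting it this way yields a single UnidirectionalSequenceLSTM op rather than a decomposed graph:

import tensorflow as tf

# Keras counterpart of SimpleModel (sketch; shapes chosen for illustration)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(16, 256)),                   # (timesteps, features)
    tf.keras.layers.LSTM(64, return_sequences=True),   # unroll=False by default
    tf.keras.layers.Dense(1),
])

# Default conversion settings: the Keras LSTM is fused into
# a UnidirectionalSequenceLSTM op in the resulting .tflite file.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open("simple_model_tf.tflite", "wb") as f:
    f.write(tflite_model)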
Python version: 3.11.9
ai_edge_torch version: 0.1.1 (installed in a fresh conda environment by following the instructions in the release section).
Operating system: Ubuntu 22.04.3 LTS in WSL2