LSTMs are being unrolled and decomposed after conversion #53

Open
Doomski99 opened this issue Jun 13, 2024 · 0 comments

Description of the bug:

I'm submitting this bug at the request of @pkgoogle; we found it in #62275.

It seems that the converter performs two unintended actions when handling LSTMs:

  1. The LSTM is decomposed instead of being converted into the "UnidirectionalSequenceLSTM" operator, which is the default behavior in TensorFlow.
  2. The LSTM is unrolled without the user's consent. In TensorFlow, the layer's unroll argument enables unrolling, but it is off by default (see the sketch after this list). If the first bug is fixed, this one may no longer be relevant, unless the user wishes to manipulate the hidden states, since that forces the converter, at least in TensorFlow, to fall back to the decomposed operators inside a "While" loop, as I described in #62775. In that case, the choice between a loop and the unrolled version should be left to the user.
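
For reference, here is a minimal sketch of how unrolling is opted into on the Keras side (the layer size is arbitrary; only the unroll flag matters):

import tensorflow as tf

# Default: the recurrence stays loop-based (and can be fused into a single
# UnidirectionalSequenceLSTM op during TFLite conversion).
lstm_default = tf.keras.layers.LSTM(64, unroll=False)

# Opt-in only: the recurrence is unrolled across time steps.
lstm_unrolled = tf.keras.layers.LSTM(64, unroll=True)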

Actual vs expected behavior:

You can find the test code below.

Actual Behavior: With ai-edge-torch, we can clearly see that the LSTM is decomposed and unrolled:
[screenshot of the converted graph: decomposed, unrolled LSTM ops]

Expected Behavior: With TensorFlow:
[screenshot of the converted graph: a single fused UnidirectionalSequenceLSTM op]
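
To compare the two graphs without a visualizer, the operator lists of both converted files can be dumped with the TFLite analyzer (file names taken from the scripts below; the analyzer API is available in recent TensorFlow releases):

import tensorflow as tf

# Prints the operator breakdown of each .tflite file. The ai-edge-torch model
# shows the decomposed/unrolled ops, while the TensorFlow-converted model
# should contain a single fused LSTM op.
tf.lite.experimental.Analyzer.analyze(model_path="simple_model.tflite")
tf.lite.experimental.Analyzer.analyze(model_path="1x_LSTM_64_float32.tflite")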

Any other information you'd like to share?

Torch code:

import torch
from torch import nn
import ai_edge_torch


class SimpleModel(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # A single LSTM layer followed by a linear head.
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.d1 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # The LSTM returns the full output sequence plus the final states.
        x, (h0, c0) = self.lstm(x)
        x = self.d1(x)
        return x


model = SimpleModel(256, 64)
sample_inputs = (torch.randn(16, 43, 256),)

# Convert the eval-mode model and write it out as a .tflite file.
edge_model = ai_edge_torch.convert(model.eval(), sample_inputs)
edge_model.export("simple_model.tflite")
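
As a quick sanity check, the converted model can be compared against PyTorch numerically; this assumes the callable edge-model interface shown in the ai-edge-torch README:

import numpy as np

# Run both models on the same input; edge_model(...) executes the TFLite model.
torch_out = model(*sample_inputs).detach().numpy()
edge_out = edge_model(*sample_inputs)
print(np.max(np.abs(torch_out - edge_out)))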

TensorFlow code:

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, LSTM

model_name = "1x_LSTM_64_float32"
input_length = 256

class SimpleModel(Model):
  def __init__(self, input_shape, hidden_size):
    super().__init__()
    # A single LSTM layer followed by a dense head, mirroring the Torch model.
    self.lstm = LSTM(hidden_size, return_sequences=True, return_state=True, input_shape=[-1, input_shape])
    self.d1 = Dense(1, input_shape=[-1, hidden_size])

  def call(self, x):
    # return_state=True makes the layer also return the final hidden and cell states.
    x, h0, c0 = self.lstm(x)
    x = self.d1(x)
    return x

model = SimpleModel(input_length, 64)

# call() returns a single tensor, so there is only one value to unpack.
out = model(tf.random.uniform([16, 43, 256]))

print(np.mean(out))

model_path = f"{model_name}.tf"

# Trace a concrete function with a fixed input signature for saving.
run_model = tf.function(lambda x: model(x))
BATCH_SIZE = 16
STEPS = 43
INPUT_SIZE = 256
concrete_func = run_model.get_concrete_function(
    tf.TensorSpec([BATCH_SIZE, STEPS, INPUT_SIZE], tf.float32))

model.save(model_path, save_format='tf', signatures=concrete_func)

# Convert the SavedModel using only the builtin TFLite op set.
converter = tf.lite.TFLiteConverter.from_saved_model(model_path)
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS,
]

tflite_model = converter.convert()
open(f"{model_name}.tflite", "wb").write(tflite_model)

Python version: 3.11.9
ai_edge_torch version: 0.1.1 (installed in a fresh conda environment by following the instructions in the release section)
Operating system: Ubuntu 22.04.3 LTS in WSL2
