
Replace Int64 with Int32 for edge #246

Open · rfechtner opened this issue Sep 21, 2024 · 4 comments
Labels: type:support (For use-related issues)

rfechtner commented Sep 21, 2024

Description of the bug:

Hi,

I am trying to convert a PyTorch model to TFLite which uses torch.argmax(..).indices and torch.gather(..), hence creating LongTensors (Int64). As my targeted runtime delegate does not support any int64 ops (including cast int64 -> int32), I am looking to replace the int64 ops with corresponding int32 ones.

Minimal reproducible example:

import ai_edge_torch
import torch

sample_inputs = (torch.randn(1, 3, 224, 224),)

class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, tensor):
    select = tensor.max(dim=1).indices.unsqueeze(0)
    return torch.gather(tensor, dim=1, index=select)

model = Model().eval()
edge_model = ai_edge_torch.convert(model, sample_inputs)
edge_model(*sample_inputs)

In the past I have done this via an intermediate ONNX model representation, where I modified the relevant nodes and then converted the ONNX model to TFLite, but with this new framework I'd hoped to get rid of the ONNX step.
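For reference, that ONNX node surgery looked roughly like this (an illustrative sketch only; the paths are placeholders and the exact set of nodes to rewrite depends on the model):

import onnx
from onnx import TensorProto

# Illustrative: retarget Cast nodes that emit int64 so they emit int32.
model = onnx.load("model.onnx")  # placeholder path
for node in model.graph.node:
  if node.op_type == "Cast":
    for attr in node.attribute:
      if attr.name == "to" and attr.i == TensorProto.INT64:
        attr.i = TensorProto.INT32
onnx.save(model, "model_int32.onnx")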

I have tried to replace the torch.argmax() with tf.math.argmax(.., output_type=tf.int32), or the numpy equivalent which supports specifying the output array's dtype, but that fails during torch.export() and results in

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py in unimplemented(msg, from_exc)
    219     if from_exc is not _NOTHING:
    220         raise Unsupported(msg) from from_exc
--> 221     raise Unsupported(msg)

Unsupported: 'skip function argmax_v2 in file /usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py'

from user code:
   File "<ipython-input-3-7b1c313a80d9>", line 10, in forward
    idx = tf.math.argmax(tensor.detach().numpy(), output_tzpe=tf.int32)

One remaining avenue I can think of is post-processing the resulting flatbuffer representation and replacing the int64 ops there, but that seems quite brittle and overly complicated.
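If it came to that, a rough sketch using TF's flatbuffer utilities might look like the following. Note this only retags the dtype field; any int64 constant buffers would still need their raw bytes repacked, which is part of why it feels brittle. Paths are placeholders.

from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.tools import flatbuffer_utils

# Retag every INT64 tensor as INT32 (does NOT repack constant buffers).
model = flatbuffer_utils.read_model("model.tflite")  # placeholder path
for subgraph in model.subgraphs:
  for tensor in subgraph.tensors:
    if tensor.type == schema_fb.TensorType.INT64:
      tensor.type = schema_fb.TensorType.INT32
flatbuffer_utils.write_model(model, "model_int32.tflite")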

Any other suggestions? Or is there a way to dynamically replace functions?

Note: I had to pin tf-nightly==2.18.0.dev20240722, otherwise the export fails with:


---------------------------------------------------------------------------
ConverterError                            Traceback (most recent call last)
<ipython-input> in <cell line: 15>()
     13 
     14 model = Model().eval()
---> 15 edge_model = ai_edge_torch.convert(model, sample_inputs)
     16 edge_model(*sample_inputs)

12 frames
/usr/local/lib/python3.10/dist-packages/ai_edge_torch/_convert/converter.py in convert(module, sample_args, sample_kwargs, quant_config, dynamic_shapes, _ai_edge_converter_flags)
    239     _ai_edge_converter_flags = {}
    240 
--> 241   return Converter().convert(
    242       module,
    243       sample_args,

/usr/local/lib/python3.10/dist-packages/ai_edge_torch/_convert/converter.py in convert(self, module, sample_args, sample_kwargs, quant_config, dynamic_shapes, _ai_edge_converter_flags)
    161             " specified."
    162         )
--> 163     return conversion.convert_signatures(
    164         self._signatures,
    165         quant_config=quant_config,

/usr/local/lib/python3.10/dist-packages/ai_edge_torch/_convert/conversion.py in convert_signatures(signatures, quant_config, _tfl_converter_flags)
    102   # Apply default fx passes
    103   exported_programs = list(map(_run_convert_passes, exported_programs))
--> 104   tflite_model = lowertools.exported_programs_to_tflite(
    105       exported_programs,
    106       signatures,

/usr/local/lib/python3.10/dist-packages/ai_edge_torch/lowertools/_shim.py in exported_programs_to_tflite(exported_programs, signatures, quant_config, _tfl_converter_flags)
     73   )
     74 
---> 75   return utils.merged_bundle_to_tfl_model(
     76       merged_bundle,
     77       signatures,

/usr/local/lib/python3.10/dist-packages/ai_edge_torch/lowertools/torch_xla_utils.py in merged_bundle_to_tfl_model(merged_bundle, signatures, quant_config, _tfl_converter_flags)
    271     conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
    272 
--> 273     tflite_model = converter.convert()
    274 
    275     if (

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/lite.py in wrapper(self, *args, **kwargs)
   1236   def wrapper(self, *args, **kwargs):
   1237     # pylint: disable=protected-access
-> 1238     return self._convert_and_export_metrics(convert_func, *args, **kwargs)
   1239     # pylint: enable=protected-access
   1240 

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/lite.py in _convert_and_export_metrics(self, convert_func, *args, **kwargs)
   1188     self._save_conversion_params_metric()
   1189     start_time = time.process_time()
-> 1190     result = convert_func(self, *args, **kwargs)
   1191     elapsed_time_ms = (time.process_time() - start_time) * 1000
   1192     if result:

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/lite.py in convert(self)
   1570     del trackable_obj
   1571     gc.collect()
-> 1572     return self._convert_from_saved_model(graph_def)
   1573 
   1574 

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/lite.py in _convert_from_saved_model(self, graph_def)
   1428     converter_kwargs.update(quant_mode.converter_flags())
   1429 
-> 1430     result = _convert_saved_model(**converter_kwargs)
   1431     return self._optimize_tflite_model(
   1432         result,

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/convert_phase.py in wrapper(*args, **kwargs)
    210         else:
    211           report_error_message(str(converter_error))
--> 212         raise converter_error from None  # Re-throws the exception.
    213       except Exception as error:
    214         report_error_message(str(error))

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/convert_phase.py in wrapper(*args, **kwargs)
    203     def wrapper(*args, **kwargs):
    204       try:
--> 205         return func(*args, **kwargs)
    206       except ConverterError as converter_error:
    207         if converter_error.errors:

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/convert.py in convert_saved_model(**kwargs)
   1043   model_flags = build_model_flags(**kwargs)
   1044   conversion_flags = build_conversion_flags(**kwargs)
-> 1045   data = convert(
   1046       model_flags,
   1047       conversion_flags,

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/convert.py in convert(model_flags, conversion_flags, input_data_str, debug_info_str, enable_mlir_converter)
    374               enable_mlir_converter,
    375           )
--> 376       raise converter_error
    377 
    378   return _run_deprecated_conversion_binary(

ConverterError: Could not translate MLIR to FlatBuffer.:0: error: loc(callsite(callsite(callsite("__main__.Model;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_5"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_11"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): 'vhlo.iota_v1' op is not part of the vhlo support yet.
:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
:0: note: loc(callsite(callsite(callsite("__main__.Model;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_5"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_11"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): see current operation: %12 = "vhlo.iota_v1"() <{iota_dimension = #vhlo.integer_v1<0 : i64>}> : () -> tensor<224xui32>
:0: error: failed while converting: 'main': 
:0: note: see current operation: 
"func.func"() <{arg_attrs = [{tf_saved_model.index_path = ["args_0"]}], function_type = (tensor<1x3x224x224xf32>) -> tensor<1x1x224x224xf32>, res_attrs = [{tf_saved_model.index_path = ["output_0"]}], sym_name = "main"}> ({
^bb0(%arg0: tensor<1x3x224x224xf32>):
  %0 = "arith.constant"() <{value = dense<[1, 1, 224, 224, 1]> : tensor<5xi32>}> : () -> tensor<5xi32>
  %1 = "arith.constant"() <{value = dense<[1, 1, 224, 224]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %2 = "arith.constant"() <{value = dense<[1, 1, 1, 224, 1]> : tensor<5xi32>}> : () -> tensor<5xi32>
  %3 = "arith.constant"() <{value = dense<[1, 1, 224, 1, 1]> : tensor<5xi32>}> : () -> tensor<5xi32>
  %4 = "arith.constant"() <{value = dense<[1, 1, 224, 224, 1]> : tensor<5xi64>}> : () -> tensor<5xi64>
  %5 = "arith.constant"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
  %6 = "arith.constant"() <{value = dense<0> : tensor<1x1x224x224x1xui32>}> : () -> tensor<1x1x224x224x1xui32>
  %7 = "tfl.arg_max"(%arg0, %5) : (tensor<1x3x224x224xf32>, tensor<1xi32>) -> tensor<1x224x224xi32>
  %8 = "tfl.cast"(%7) : (tensor<1x224x224xi32>) -> tensor<1x224x224xi64>
  %9 = "tfl.reshape"(%8, %1) : (tensor<1x224x224xi64>, tensor<4xi32>) -> tensor<1x1x224x224xi64>
  %10 = "tfl.cast"(%9) : (tensor<1x1x224x224xi64>) -> tensor<1x1x224x224xui32>
  %11 = "tfl.reshape"(%10, %0) : (tensor<1x1x224x224xui32>, tensor<5xi32>) -> tensor<1x1x224x224x1xui32>
  %12 = "vhlo.iota_v1"() <{iota_dimension = #vhlo.integer_v1<0 : i64>}> : () -> tensor<224xui32>
  %13 = "tfl.reshape"(%12, %3) : (tensor<224xui32>, tensor<5xi32>) -> tensor<1x1x224x1x1xui32>
  %14 = "tfl.broadcast_to"(%13, %4) : (tensor<1x1x224x1x1xui32>, tensor<5xi64>) -> tensor<1x1x224x224x1xui32>
  %15 = "tfl.reshape"(%12, %2) : (tensor<224xui32>, tensor<5xi32>) -> tensor<1x1x1x224x1xui32>
  %16 = "tfl.broadcast_to"(%15, %4) : (tensor<1x1x1x224x1xui32>, tensor<5xi64>) -> tensor<1x1x224x224x1xui32>
  %17 = "tfl.concatenation"(%6, %11, %14, %16) <{axis = 4 : i32, fused_activation_function = "NONE"}> : (tensor<1x1x224x224x1xui32>, tensor<1x1x224x224x1xui32>, tensor<1x1x224x224x1xui32>, tensor<1x1x224x224x1xui32>) -> tensor<1x1x224x224x4xui32>
  %18 = "tfl.cast"(%17) : (tensor<1x1x224x224x4xui32>) -> tensor<1x1x224x224x4xi64>
  %19 = "tfl.gather_nd"(%arg0, %18) : (tensor<1x3x224x224xf32>, tensor<1x1x224x224x4xi64>) -> tensor<1x1x224x224xf32>
  "func.return"(%19) : (tensor<1x1x224x224xf32>) -> ()
}) {tf.entry_function = {control_outputs = "", inputs = "serving_default_args_0:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} : () -> ()

rfechtner changed the title from "Prevent / Replace Int64" to "Replace Int64 with Int32 for ArmNN" Sep 21, 2024
rfechtner changed the title from "Replace Int64 with Int32 for ArmNN" to "Replace Int64 with Int32 for edge" Sep 21, 2024
pkgoogle self-assigned this Sep 23, 2024
pkgoogle (Contributor) commented Sep 23, 2024

Hi @rfechtner, I was actually not able to replicate this issue when using the latest code in main, i.e.:

# navigate to ai-edge-torch repo
git switch main # if not already in the main branch
git pull # update to latest code
pip install -e .
pip install tensorflow-cpu # avoids an import conflict; the latest code works better with torch-XLA this way
# run your script

Can you give that a try? Let me know what goes wrong if you do. I also recommend using a fresh venv/conda environment to rule out odd conflicts. I should note I'm using Python 3.11, if that makes a difference.

pkgoogle added the status:awaiting user response, status:more data needed, and type:support labels Sep 23, 2024
rfechtner (Author) commented Sep 23, 2024

Hi @pkgoogle, thanks for the swift reply.

I've created a clean env following your suggestions. Same behaviour: I can convert the PyTorch model just fine, but the exported model contains Int64 tensors (as torch.max() returns a LongTensor):

[Model Explorer graph of the converted model, showing the Int64 tensors]

but I want to avoid Int64 ops. I was trying to replace the torch function with TensorFlow ops, where I can specify the output type, e.g.:

import tensorflow as tf

class ModelInt32TF(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, tensor):
    return tf.math.argmax(
        tensor.detach().numpy(), axis=1, output_type=tf.int32
    )

model_int32_tf = ModelInt32TF().eval()
edge_model_int32_tf = ai_edge_torch.convert(model_int32_tf, sample_inputs)
edge_model_int32_tf(*sample_inputs)

which yields the error mentioned above:

---------------------------------------------------------------------------
Unsupported                               Traceback (most recent call last)
<ipython-input-31-b7e7c53e3cf8> in <cell line: 11>()
      9 
     10 model_int32_tf = ModelInt32TF().eval()
---> 11 edge_model_int32_tf = ai_edge_torch.convert(model_int32_tf, sample_inputs)
     12 edge_model_int32_tf(*sample_inputs)

35 frames
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py in unimplemented(msg, from_exc)
    219     if from_exc is not _NOTHING:
    220         raise Unsupported(msg) from from_exc
--> 221     raise Unsupported(msg)
    222 
    223 

Unsupported: 'skip function argmax_v2 in file /usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py'

from user code:
   File "<ipython-input-31-b7e7c53e3cf8>", line 6, in forward
    return tf.math.argmax(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Note: I can replace

select = tensor.max(dim=1).indices.unsqueeze(0)

with

select = np.empty(.., dtype=np.int32)
np.argmax(tensor, axis=1, out=select)

but torch.gather() and np.take_along_axis() (the latter gets converted to the former) keep requiring a Long index tensor...
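One int64-free alternative I can think of (an untested sketch, assuming selection along dim=1; MaskSelect is a hypothetical name): avoid gather entirely and pick the argmax channel with a broadcast comparison mask:

import torch

class MaskSelect(torch.nn.Module):
  # Hypothetical module: equivalent to gather(tensor, 1, argmax) built
  # from compare + multiply + sum, with int32 indices throughout.
  def forward(self, tensor):
    B, C, H, W = tensor.shape
    mode = tensor.max(dim=1).indices.to(torch.int32)          # (B, H, W), int32
    channels = torch.arange(C, dtype=torch.int32).view(1, C, 1, 1)
    mask = (channels == mode.unsqueeze(1)).to(tensor.dtype)   # (B, C, H, W)
    return (tensor * mask).sum(dim=1, keepdim=True)           # (B, 1, H, W)

In the MLIR dump above, tfl.arg_max already yields i32 before the inserted cast to i64, so the .to(torch.int32) round-trip may fold away after conversion, though that is not guaranteed.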

rfechtner (Author) commented

Using np.argmax(..) instead of tf.math.argmax() brings me a step further:

import ai_edge_torch
import numpy as np
import torch

sample_inputs = (torch.randn(1, 3, 224, 224),)

class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, tensor):
    B, C, H, W = tensor.shape
    mode = np.empty((B, H, W), dtype=np.int32)
    np.argmax(tensor.detach().numpy(), axis=1, out=mode)
    mode = torch.from_numpy(mode).unsqueeze(0)

    return torch.gather(tensor, dim=1, index=mode.long())

model = Model().eval()
edge_model = ai_edge_torch.convert(model, sample_inputs)
edge_model(*sample_inputs)

This allows me to create the index tensor with dtype int32, but torch.gather() still requires a LongTensor index.
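A minimal check of that requirement (the exact error text may vary by torch version):

import torch

t = torch.randn(1, 3, 4, 4)
idx = torch.zeros(1, 1, 4, 4, dtype=torch.int32)
torch.gather(t, 1, idx)  # RuntimeError: gather(): Expected dtype int64 for index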

Environment: pip freeze

absl-py==1.4.0
accelerate==0.34.2
ai-edge-litert-nightly==1.0.1.dev20240924
ai-edge-model-explorer==0.1.12
ai-edge-model-explorer-adapter==0.1.5
ai-edge-quantizer-nightly==0.0.1.dev20240924
-e git+https://github.com/google-ai-edge/ai-edge-torch.git@c9973d2e7423e86f420576c0e5cac1181f79ac0e#egg=ai_edge_torch
aiohappyeyeballs==2.4.0
aiohttp==3.10.5
aiosignal==1.3.1
alabaster==0.7.16
albucore==0.0.16
albumentations==1.4.15
altair==4.2.2
annotated-types==0.7.0
anyio==3.7.1
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
array_record==0.5.1
arviz==0.19.0
astropy==6.1.3
astropy-iers-data==0.2024.9.16.0.32.21
astunparse==1.6.3
async-timeout==4.0.3
atpublic==4.1.0
attrs==24.2.0
audioread==3.0.1
autograd==1.7.0
babel==2.16.0
backcall==0.2.0
beautifulsoup4==4.12.3
bidict==0.23.1
bigframes==1.17.0
bigquery-magics==0.2.0
bleach==6.1.0
blinker==1.4
blis==0.7.11
blosc2==2.0.0
bokeh==3.4.3
bqplot==0.12.43
branca==0.7.2
build==1.2.2
CacheControl==0.14.0
cachetools==5.5.0
catalogue==2.0.10
certifi==2024.8.30
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.3.2
chex==0.1.86
clarabel==0.9.0
click==8.1.7
cloud-tpu-client==0.10
cloudpathlib==0.19.0
cloudpickle==2.2.1
cmake==3.30.3
cmdstanpy==1.2.4
colorcet==3.1.0
colorlover==0.3.0
colour==0.1.5
community==1.0.0b1
confection==0.1.5
cons==0.4.6
contextlib2==21.6.0
contourpy==1.3.0
cryptography==43.0.1
cuda-python==12.2.1
cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-24.4.1-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=57366e7ef09dc63e0b389aff20df6c37d91e2790065861ee31a4720149f5b694
cufflinks==0.17.3
cupy-cuda12x==12.2.0
cvxopt==1.3.2
cvxpy==1.5.3
cycler==0.12.1
cymem==2.0.8
Cython==3.0.11
dask==2024.8.0
datascience==0.17.6
db-dtypes==1.3.0
dbus-python==1.2.18
debugpy==1.6.6
decorator==4.4.2
defusedxml==0.7.1
distributed==2024.8.0
distro==1.7.0
dlib==19.24.2
dm-tree==0.1.8
docstring_parser==0.16
docutils==0.18.1
dopamine_rl==4.0.9
duckdb==1.1.0
earthengine-api==1.0.0
easydict==1.13
ecos==2.0.14
editdistance==0.8.1
eerepr==0.0.4
einops==0.8.0
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
entrypoints==0.4
et-xmlfile==1.1.0
etils==1.9.4
etuples==0.3.9
eval_type_backport==0.2.0
exceptiongroup==1.2.2
fastai==2.7.17
fastcore==1.7.8
fastdownload==0.0.7
fastjsonschema==2.20.0
fastprogress==1.0.3
fastrlock==0.8.2
filelock==3.16.1
firebase-admin==6.5.0
Flask==2.2.5
flatbuffers==24.3.25
flax==0.8.4
folium==0.17.0
fonttools==4.53.1
frozendict==2.4.4
frozenlist==1.4.1
fsspec==2024.6.1
future==1.0.0
gast==0.6.0
gcsfs==2024.6.1
GDAL==3.6.4
gdown==5.2.0
geemap==0.34.2
gensim==4.3.3
geocoder==1.38.1
geographiclib==2.0
geopandas==1.0.1
geopy==2.4.1
gin-config==0.5.0
glob2==0.7
google==2.0.3
google-ai-generativelanguage==0.6.6
google-api-core==1.34.1
google-api-python-client==1.8.0
google-auth==2.27.0
google-auth-httplib2==0.2.0
google-auth-oauthlib==1.2.1
google-cloud-aiplatform==1.67.1
google-cloud-bigquery==3.25.0
google-cloud-bigquery-connection==1.15.5
google-cloud-bigquery-storage==2.26.0
google-cloud-bigtable==2.26.0
google-cloud-core==2.4.1
google-cloud-datastore==2.19.0
google-cloud-firestore==2.16.1
google-cloud-functions==1.16.5
google-cloud-iam==2.15.2
google-cloud-language==2.13.4
google-cloud-pubsub==2.23.1
google-cloud-resource-manager==1.12.5
google-cloud-storage==2.8.0
google-cloud-translate==3.15.5
google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz#sha256=deb182392f5f78765ea686f1200ff7cfd42e31bdf8d172a68d6a29f657e1fe18
google-crc32c==1.6.0
google-generativeai==0.7.2
google-pasta==0.2.0
google-resumable-media==2.7.2
googleapis-common-protos==1.65.0
googledrivedownloader==0.4
graphviz==0.20.3
greenlet==3.1.0
grpc-google-iam-v1==0.13.1
grpcio==1.64.1
grpcio-status==1.48.2
gspread==6.0.2
gspread-dataframe==3.3.1
gym==0.25.2
gym-notices==0.0.8
h5netcdf==1.3.0
h5py==3.11.0
holidays==0.57
holoviews==1.19.1
html5lib==1.1
httpimport==1.4.0
httplib2==0.22.0
huggingface-hub==0.24.7
humanize==4.10.0
hyperopt==0.2.7
ibis-framework==8.0.0
idna==3.10
imageio==2.35.1
imageio-ffmpeg==0.5.1
imagesize==1.4.1
imbalanced-learn==0.12.3
imgaug==0.4.0
immutabledict==4.2.0
importlib_metadata==8.5.0
importlib_resources==6.4.5
imutils==0.5.4
inflect==7.4.0
iniconfig==2.0.0
intel-cmplr-lib-ur==2024.2.1
intel-openmp==2024.2.1
ipyevents==2.0.2
ipyfilechooser==0.6.0
ipykernel==5.5.6
ipyleaflet==0.19.2
ipyparallel==8.8.0
ipython==7.34.0
ipython-genutils==0.2.0
ipython-sql==0.5.0
ipytree==0.2.2
ipywidgets==7.7.1
itsdangerous==2.2.0
jax==0.4.26
jaxlib @ https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.26+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl#sha256=813cf1fe3e7ca4dbf5327d6e7b4fc8521e92d8bba073ee645ae0d5d036a25750
jedi==0.19.1
jeepney==0.7.1
jellyfish==1.1.0
jieba==0.42.1
Jinja2==3.1.4
joblib==1.4.2
jsonpickle==3.3.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter-client==6.1.12
jupyter-console==6.1.0
jupyter-leaflet==0.19.2
jupyter-server==1.24.0
jupyter_core==5.7.2
jupyterlab_pygments==0.3.0
jupyterlab_widgets==3.0.13
kaggle==1.6.17
kagglehub==0.3.0
keras==3.4.1
keras-nightly==3.5.0.dev2024092403
keyring==23.5.0
kiwisolver==1.4.7
langcodes==3.4.0
language_data==1.2.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lazy_loader==0.4
libclang==18.1.1
librosa==0.10.2.post1
lightgbm==4.5.0
linkify-it-py==2.0.3
llvmlite==0.43.0
locket==1.0.0
logical-unification==0.4.6
lxml==4.9.4
marisa-trie==1.2.0
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.7.1
matplotlib-inline==0.1.7
matplotlib-venn==1.1.1
mdit-py-plugins==0.4.2
mdurl==0.1.2
miniKanren==1.0.3
missingno==0.5.2
mistune==0.8.4
mizani==0.11.4
mkl==2024.2.2
ml-dtypes==0.4.1
mlxtend==0.23.1
more-itertools==10.5.0
moviepy==1.0.3
mpmath==1.3.0
msgpack==1.0.8
multidict==6.1.0
multipledispatch==1.0.0
multitasking==0.0.11
murmurhash==1.0.10
music21==9.1.0
namex==0.0.8
natsort==8.4.0
nbclassic==1.1.0
nbclient==0.10.0
nbconvert==6.5.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
nibabel==5.2.1
nltk==3.8.1
notebook==6.5.5
notebook_shim==0.2.4
numba==0.60.0
numexpr==2.10.1
numpy==1.26.4
nvidia-nccl-cu12==2.23.4
nvtx==0.2.10
oauth2client==4.1.3
oauthlib==3.2.2
opencv-contrib-python==4.10.0.84
opencv-python==4.10.0.84
opencv-python-headless==4.10.0.84
openpyxl==3.1.5
opt-einsum==3.3.0
optax==0.2.2
optree==0.12.1
orbax-checkpoint==0.6.4
osqp==0.6.7.post0
packaging==24.1
pandas==2.1.4
pandas-datareader==0.10.0
pandas-gbq==0.23.1
pandas-stubs==2.1.4.231227
pandocfilters==1.5.1
panel==1.4.5
param==2.1.1
parso==0.8.4
parsy==2.1
partd==1.4.2
pathlib==1.0.1
patsy==0.5.6
peewee==3.17.6
pexpect==4.9.0
pickleshare==0.7.5
pillow==10.4.0
pip-tools==7.4.1
platformdirs==4.3.6
plotly==5.24.1
plotnine==0.13.6
pluggy==1.5.0
polars==1.6.0
pooch==1.8.2
portpicker==1.5.2
prefetch_generator==1.0.3
preshed==3.0.9
prettytable==3.11.0
proglog==0.1.10
progressbar2==4.5.0
prometheus_client==0.20.0
promise==2.3
prompt_toolkit==3.0.47
prophet==1.1.5
proto-plus==1.24.0
protobuf==3.20.3
psutil==5.9.5
psycopg2==2.9.9
ptyprocess==0.7.0
py-cpuinfo==9.0.0
py4j==0.10.9.7
pyarrow==14.0.2
pyarrow-hotfix==0.6
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycocotools==2.0.8
pycparser==2.22
pydantic==2.9.2
pydantic_core==2.23.4
pydata-google-auth==1.8.2
pydot==3.0.1
pydot-ng==2.0.0
pydotplus==2.0.2
PyDrive==1.3.1
PyDrive2==1.20.0
pyerfa==2.0.1.4
pygame==2.6.0
Pygments==2.18.0
PyGObject==3.42.1
PyJWT==2.9.0
pymc==5.16.2
pymystem3==0.2.0
pynvjitlink-cu12==0.3.0
pyogrio==0.9.0
PyOpenGL==3.1.7
pyOpenSSL==24.2.1
pyparsing==3.1.4
pyperclip==1.9.0
pyproj==3.6.1
pyproject_hooks==1.1.0
pyshp==2.3.1
PySocks==1.7.1
pytensor==2.25.4
pytest==7.4.4
python-apt==2.4.0
python-box==7.2.0
python-dateutil==2.8.2
python-louvain==0.16
python-slugify==8.0.4
python-utils==3.8.2
pytz==2024.2
pyviz_comms==3.0.3
PyYAML==6.0.2
pyzmq==24.0.1
qdldl==0.1.7.post4
ratelim==0.1.6
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
requests-oauthlib==1.3.1
requirements-parser==0.9.0
rich==13.8.1
rmm-cu12==24.4.0
rpds-py==0.20.0
rpy2==3.4.2
rsa==4.9
safetensors==0.4.5
scikit-image==0.24.0
scikit-learn==1.5.2
scipy==1.13.1
scooby==0.10.0
scs==3.2.7
seaborn==0.13.1
SecretStorage==3.3.1
Send2Trash==1.8.3
sentencepiece==0.2.0
shapely==2.0.6
shellingham==1.5.4
simple-parsing==0.1.6
six==1.16.0
sklearn-pandas==2.2.0
smart-open==7.0.4
sniffio==1.3.1
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soundfile==0.12.1
soupsieve==2.6
soxr==0.5.0.post1
spacy==3.7.6
spacy-legacy==3.0.12
spacy-loggers==1.0.5
Sphinx==5.0.2
sphinxcontrib-applehelp==2.0.0
sphinxcontrib-devhelp==2.0.0
sphinxcontrib-htmlhelp==2.1.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==2.0.0
sphinxcontrib-serializinghtml==2.0.0
SQLAlchemy==2.0.35
sqlglot==20.11.0
sqlparse==0.5.1
srsly==2.4.8
stanio==0.5.1
statsmodels==0.14.3
StrEnum==0.4.15
sympy==1.13.3
tables==3.8.0
tabulate==0.9.0
tb-nightly==2.18.0a20240924
tbb==2021.13.1
tblib==3.0.0
tenacity==9.0.0
tensorboard==2.17.0
tensorboard-data-server==0.7.2
tensorflow-cpu==2.17.0
tensorflow-datasets==4.9.6
tensorflow-hub==0.16.1
tensorflow-io-gcs-filesystem==0.37.1
tensorflow-metadata==1.15.0
tensorflow-probability==0.24.0
tensorstore==0.1.65
termcolor==2.4.0
terminado==0.18.1
text-unidecode==1.3
textblob==0.17.1
tf-slim==1.1.0
tf_keras==2.17.0
tf_nightly==2.18.0.dev20240923
thinc==8.2.5
threadpoolctl==3.5.0
tifffile==2024.8.30
tinycss2==1.3.0
tokenizers==0.19.1
toml==0.10.2
tomli==2.0.1
toolz==0.12.1
torch==2.4.0+cpu
torch-xla==2.4.0
torchaudio==2.4.0+cpu
torchsummary==1.5.1
torchvision==0.19.0+cpu
tornado==6.3.3
tqdm==4.66.5
traitlets==5.7.1
traittypes==0.2.1
transformers==4.44.2
tweepy==4.14.0
typeguard==4.3.0
typer==0.12.5
types-pytz==2024.2.0.20240913
types-setuptools==75.1.0.20240917
typing_extensions==4.12.2
tzdata==2024.1
tzlocal==5.2
uc-micro-py==1.0.3
uritemplate==3.0.1
urllib3==2.2.3
vega-datasets==0.9.0
wadllib==1.3.6
wasabi==1.1.3
wcwidth==0.2.13
weasel==0.4.1
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
Werkzeug==3.0.4
widgetsnbextension==3.6.9
wordcloud==1.9.3
wrapt==1.16.0
xarray==2024.9.0
xarray-einstats==0.8.0
xgboost==2.1.1
xlrd==2.0.1
xyzservices==2024.9.0
yarl==1.11.1
yellowbrick==1.5
yfinance==0.2.43
zict==3.0.0
zipp==3.20.2

rfechtner (Author) commented Sep 24, 2024

If I replace torch.gather() with advanced indexing, I still get an Int64 op (Less) that seems to be introduced for the slicing:

import ai_edge_torch
import numpy as np
import torch

sample_inputs = (torch.randn(1, 3, 224, 224),)

class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, tensor):
    B, C, H, W = tensor.shape
    mode = np.empty((B, H, W), dtype=np.int32)
    np.argmax(tensor.detach().numpy(), axis=1, out=mode)
    mode = torch.from_numpy(mode) # (1, H, W)

    collected = torch.empty((B, 1, H, W), dtype=tensor.dtype, device=tensor.device)
    for b in range(B):
      collected[b, 0] = tensor[
          torch.arange(B, dtype=torch.int32).unsqueeze(-1).unsqueeze(-1),
          mode,
          torch.arange(H, dtype=torch.int32),
          torch.arange(W, dtype=torch.int32)
      ]
    return collected

model = Model().eval()
edge_model = ai_edge_torch.convert(model, sample_inputs)
result = edge_model(*sample_inputs)
print(f"Output: {result.shape}")
edge_model.export("fancy.tflite")

[Model Explorer graph of the exported model, showing the Int64 Less op]

pkgoogle removed the status:awaiting user response and status:more data needed labels Sep 27, 2024