Replace Int64 with Int32 for edge #246
Comments
Hi @rfechtner, I was actually not able to replicate this issue if I use the latest code in main, i.e.:

```
# navigate to the ai-edge-torch repo
git switch main  # if not already on the main branch
git pull         # update to the latest code
pip install -e .
pip install tensorflow-cpu  # there was an import conflict; the latest code works better with torch-XLA this way
# run your script
```

Can you give that a try? Let me know what goes wrong if you do. I also recommend using a fresh venv/conda environment to ensure there are no weird conflicts. I should note I'm using Python=3.11, if that makes a difference.
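A fresh environment as suggested above could be created like this (the environment name is an arbitrary choice, not from the thread):

```
conda create -n ai-edge-clean python=3.11
conda activate ai-edge-clean
```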
Hi @pkgoogle, thanks for the swift reply. I've created a clean env with your suggestions. Same behaviour: I can convert the PyTorch model just fine, but the exported model will contain Int64 tensors, and I want to avoid Int64 ops. I was trying to replace the torch function with TensorFlow ops, where I can specify the output dtype (e.g. tf.math.argmax(.., output_type=tf.int32)), which yields the error mentioned above.
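As a standalone illustration (not the original snippet from this thread, which isn't preserved), tf.math.argmax can emit int32 indices directly:

```python
import tensorflow as tf

x = tf.random.uniform((2, 5))
# Unlike torch.argmax, tf.math.argmax lets the caller pick the index dtype.
idx32 = tf.math.argmax(x, axis=-1, output_type=tf.int32)
print(idx32.dtype)  # <dtype: 'int32'>
```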
Note: I can replace the argmax itself, but torch.gather() and np.take_along_axis() (the latter is converted to the former) will keep requiring a Long tensor input...
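A quick demonstration of that constraint (a sketch with made-up shapes):

```python
import torch

x = torch.randn(2, 5)
idx = torch.zeros(2, 1, dtype=torch.int32)  # deliberately int32

# torch.gather insists on an int64 (Long) index tensor and rejects int32.
try:
    torch.gather(x, 1, idx)
except RuntimeError as e:
    print(e)  # e.g. "gather(): Expected dtype int64 for index"
```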
Using the … allows me to create an index tensor of dtype …

Environment: pip freeze
If I replace the …
Description of the bug:
Hi,
I am trying to convert a PyTorch model to TFLite which uses torch.argmax(..).indices and torch.gather(..), hence creating LongTensors (Int64). As my targeted runtime delegate does not support any int64 ops (including cast int64 -> int32), I am seeking to replace the int64 ops with corresponding int32 ones. Min rep. example (sketched below):
In the past I have been dong this via intermediate ONNX model representation where I modified the relevant nodes and then converted ONNX to TFLite, but with this new framework I’d hoped to get rid of the onnx.
I have tried to replace the torch.argmax() with a tf.math.argmax(.., output_type=tf.int32) or the numpy equivalent, which supports specifying the output type of the array, but that fails during torch.export().

One remaining avenue I can think of is post-processing the resulting flatbuffer representation and replacing the int64 ops there, but that seems quite brittle and overly complicated.

Any other suggestions? Or is there a way to dynamically replace functions?
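If the flatbuffer post-processing route were attempted, a sensible first step would be auditing which tensors are actually int64; a sketch using the TFLite interpreter (the model path is hypothetical):

```python
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="minimal_repro.tflite")
interpreter.allocate_tensors()

# List every tensor whose dtype is int64 before attempting any surgery.
for detail in interpreter.get_tensor_details():
    if detail["dtype"] == np.int64:
        print(detail["index"], detail["name"], detail["shape"])
```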
Note: I had to pin tf-nightly==2.18.0.dev20240722, otherwise the export fails.
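For reference, the pin can be applied with:

```
pip install tf-nightly==2.18.0.dev20240722
```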