
Add support for xla devices #494

Merged · 2 commits merged into huggingface:main on Jul 30, 2024
Conversation

BitPhinix (Contributor) commented on Jul 6, 2024

What does this PR do?

Adds support for XLA devices via torch_xla (https://pytorch.org/xla/release/2.3/index.html).

Test code:

import torch
import torch_xla.core.xla_model as xm  # importing torch_xla makes the XLA backend available
from safetensors.torch import save_file
from safetensors import safe_open

test = {"a": torch.tensor([1, 2, 3])}

save_file(test, "test.safetensors")

# Load the tensors directly onto the XLA device.
with safe_open("test.safetensors", framework="pt", device="xla") as f:
    for key in f.keys():
        print(f.get_tensor(key).device)
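
For completeness, a minimal round-trip sketch (not part of the PR's test code, and assuming torch_xla is installed with an XLA backend configured): create a tensor on the XLA device, move it to CPU before saving, then load it back onto XLA with safe_open.

import torch
import torch_xla.core.xla_model as xm
from safetensors.torch import save_file
from safetensors import safe_open

device = xm.xla_device()
tensors = {"a": torch.tensor([1, 2, 3], device=device)}

# Moving the tensors to CPU before saving is a conservative choice; it avoids
# relying on save_file handling XLA-resident tensors directly.
save_file({k: v.cpu() for k, v in tensors.items()}, "roundtrip.safetensors")

with safe_open("roundtrip.safetensors", framework="pt", device="xla") as f:
    for key in f.keys():
        print(f.get_tensor(key).device)  # expected: an XLA device, e.g. xla:0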

Narsil (Collaborator) commented on Jul 30, 2024

torch_xla has had a lot of issues, and its API differs depending on the actual accelerator.

Given the narrow scope of this PR, however, I feel this is a good addition if it enables XLA. But this will not be tested.
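
As a rough illustration of the accelerator-dependent setup referred to above (an assumption about typical torch_xla 2.x usage, not something stated in this PR): under the PJRT runtime, the backend behind the "xla" device is selected via the PJRT_DEVICE environment variable, so the same device="xla" call can end up on CPU, GPU, or TPU.

import os
os.environ.setdefault("PJRT_DEVICE", "CPU")  # or "TPU" / "CUDA", depending on the hardware

import torch_xla.core.xla_model as xm
print(xm.xla_device())  # e.g. xla:0, backed by whatever PJRT_DEVICE selects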

Narsil merged commit 2331974 into huggingface:main on Jul 30, 2024.