Adding support for Ascend NPU (#372)
* Adding support for Ascend NPU

* remove the unnecessary hack code

* test more dtypes

* NPU doesn't support calling torch.allclose with bf16 for now
statelesshz authored Nov 20, 2023
1 parent 829bfa8 commit 094e676
Showing 2 changed files with 35 additions and 0 deletions.
14 changes: 14 additions & 0 deletions bindings/python/src/lib.rs
@@ -220,6 +220,7 @@ enum Device {
Cpu,
Cuda(usize),
Mps,
Npu(usize),
}

impl<'source> FromPyObject<'source> for Device {
@@ -229,6 +230,7 @@ impl<'source> FromPyObject<'source> for Device {
"cpu" => Ok(Device::Cpu),
"cuda" => Ok(Device::Cuda(0)),
"mps" => Ok(Device::Mps),
"npu" => Ok(Device::Npu(0)),
name if name.starts_with("cuda:") => {
let tokens: Vec<_> = name.split(':').collect();
if tokens.len() == 2 {
@@ -240,6 +242,17 @@
)))
}
}
name if name.starts_with("npu:") => {
let tokens: Vec<_> = name.split(':').collect();
if tokens.len() == 2 {
let device: usize = tokens[1].parse()?;
Ok(Device::Npu(device))
} else {
Err(SafetensorError::new_err(format!(
"device {name} is invalid"
)))
}
}
name => Err(SafetensorError::new_err(format!(
"device {name} is invalid"
))),
@@ -258,6 +271,7 @@ impl IntoPy<PyObject> for Device {
Device::Cpu => "cpu".into_py(py),
Device::Cuda(n) => format!("cuda:{n}").into_py(py),
Device::Mps => "mps".into_py(py),
Device::Npu(n) => format!("npu:{n}").into_py(py),
}
}
}
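With the Rust side in place, the Python API accepts the new device strings anywhere a `device` argument is taken, mirroring the existing `cuda`/`cuda:<n>` handling. A minimal usage sketch, not part of the commit (assumptions: `torch_npu` is installed, an Ascend device is visible, and the file path is illustrative):

    import torch
    import torch_npu  # noqa: F401 -- registers the "npu" device type with PyTorch

    from safetensors.torch import load_file, save_file

    # Save on CPU, then materialize the tensors directly on the NPU when loading.
    tensors = {"weight": torch.ones((2, 2), dtype=torch.float16)}
    save_file(tensors, "model.safetensors")

    # Per the parsing above, "npu" is shorthand for device 0,
    # and "npu:<n>" selects an explicit device index.
    on_npu = load_file("model.safetensors", device="npu:0")
    print(on_npu["weight"].device)  # expected: npu:0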
21 changes: 21 additions & 0 deletions bindings/python/tests/test_pt_comparison.py
@@ -7,6 +7,14 @@
from safetensors.torch import load, load_file, save, save_file


try:
import torch_npu # noqa

npu_present = True
except Exception:
npu_present = False


class TorchTestCase(unittest.TestCase):
def test_serialization(self):
data = torch.zeros((2, 2), dtype=torch.int32)
@@ -119,6 +127,19 @@ def test_gpu(self):
reloaded = load_file(local)
self.assertTrue(torch.equal(torch.arange(4).view((2, 2)), reloaded["test"]))

@unittest.skipIf(not npu_present, "Npu is not available")
def test_npu(self):
data = {
"test1": torch.zeros((2, 2), dtype=torch.float32).to("npu:0"),
"test2": torch.zeros((2, 2), dtype=torch.float16).to("npu:0"),
}
local = "./tests/data/out_safe_pt_mmap_small_npu.safetensors"
save_file(data, local)

reloaded = load_file(local, device="npu:0")
for k, v in reloaded.items():
self.assertTrue(torch.allclose(data[k], reloaded[k]))

def test_sparse(self):
data = {"test": torch.sparse_coo_tensor(size=(2, 3))}
local = "./tests/data/out_safe_pt_sparse.safetensors"
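The last commit bullet explains why the test stops at float32/float16: torch.allclose cannot yet be called with bf16 tensors on NPU. A hedged workaround sketch (my own suggestion, not part of this commit): upcast through float32 before comparing.

    import torch

    def allclose_any_dtype(a: torch.Tensor, b: torch.Tensor) -> bool:
        """allclose that tolerates bf16 on backends whose allclose rejects it."""
        if torch.bfloat16 in (a.dtype, b.dtype):
            # Upcast to float32 and compare on CPU, since NPU's allclose
            # does not accept bf16 inputs at the time of this commit.
            return torch.allclose(a.float().cpu(), b.float().cpu())
        return torch.allclose(a, b)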
