From 23319746e483012e6d3a7f2e4048cdbe7c458fa5 Mon Sep 17 00:00:00 2001 From: Eric Meier Date: Tue, 30 Jul 2024 10:29:22 -0700 Subject: [PATCH] Add support for xla devices (#494) * Add support for xla devices * :broom: --- bindings/python/src/lib.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index e0a52524..2c05ad56 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -266,6 +266,7 @@ enum Device { Mps, Npu(usize), Xpu(usize), + Xla(usize), } impl<'source> FromPyObject<'source> for Device { @@ -277,6 +278,7 @@ impl<'source> FromPyObject<'source> for Device { "mps" => Ok(Device::Mps), "npu" => Ok(Device::Npu(0)), "xpu" => Ok(Device::Xpu(0)), + "xla" => Ok(Device::Xla(0)), name if name.starts_with("cuda:") => { let tokens: Vec<_> = name.split(':').collect(); if tokens.len() == 2 { @@ -310,6 +312,17 @@ impl<'source> FromPyObject<'source> for Device { ))) } } + name if name.starts_with("xla:") => { + let tokens: Vec<_> = name.split(':').collect(); + if tokens.len() == 2 { + let device: usize = tokens[1].parse()?; + Ok(Device::Xla(device)) + } else { + Err(SafetensorError::new_err(format!( + "device {name} is invalid" + ))) + } + } name => Err(SafetensorError::new_err(format!( "device {name} is invalid" ))), @@ -330,6 +343,7 @@ impl IntoPy for Device { Device::Mps => "mps".into_py(py), Device::Npu(n) => format!("npu:{n}").into_py(py), Device::Xpu(n) => format!("xpu:{n}").into_py(py), + Device::Xla(n) => format!("xla:{n}").into_py(py), } } }