Skip to content

Commit

Permalink
device fix to allow empty string (#2338)
Browse files Browse the repository at this point in the history
  • Loading branch information
juliagsy authored Jul 27, 2022
1 parent 52496c6 commit ef379b6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
13 changes: 8 additions & 5 deletions ivy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ class NativeShape:

class Device(str):
def __new__(cls, dev_str):
assert dev_str[0:3] in ["gpu", "tpu", "cpu"]
if dev_str != "cpu":
assert dev_str[3] == ":"
assert dev_str[4:].isnumeric()
if dev_str != "":
assert dev_str[0:3] in ["gpu", "tpu", "cpu"]
if dev_str != "cpu":
assert dev_str[3] == ":"
assert dev_str[4:].isnumeric()
return str.__new__(cls, dev_str)


Expand All @@ -74,7 +75,9 @@ def __new__(cls, shape_tup):
shape_tup = (shape_tup,)
elif isinstance(shape_tup, list):
shape_tup = tuple(shape_tup)
assert builtins.all([isinstance(v, int) for v in shape_tup])
assert builtins.all(
[isinstance(v, int) or ivy.is_int_dtype(v.dtype) for v in shape_tup]
)
if ivy.shape_array_mode():
return ivy.array(shape_tup)
return tuple.__new__(cls, shape_tup)
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def dev(
x: JaxArray, as_native: bool = False
) -> Union[ivy.Device, jaxlib.xla_extension.Device]:
if isinstance(x, jax.interpreters.partial_eval.DynamicJaxprTracer):
return None
return ""
try:
dv = _to_array(x).device_buffer.device
dv = dv()
Expand Down

0 comments on commit ef379b6

Please sign in to comment.