feat: fix tpu support to torch backend, unsupported data-types #28739

Open · wants to merge 5 commits into base: main
64 changes: 60 additions & 4 deletions ivy/functional/backends/paddle/random.py
@@ -6,7 +6,6 @@
import ivy.functional.backends.paddle as paddle_backend
from typing import Optional, Union, Sequence

# local
import ivy
from paddle.device import core
from ivy.functional.ivy.random import (
@@ -45,7 +44,6 @@ def random_uniform(
low = paddle.cast(low, "float32") if isinstance(low, paddle.Tensor) else low
high = paddle.cast(high, "float32") if isinstance(high, paddle.Tensor) else high
shape = _check_bounds_and_get_shape(low, high, shape).shape
# Set range and seed
rng = high - low
if seed:
_ = paddle.seed(seed)
@@ -57,7 +55,8 @@


@with_unsupported_dtypes(
{"2.6.0 and below": ("float16", "int16", "int8")}, backend_version
{"2.6.0 and below": ("float16", "int16", "int8")},
backend_version,
)
def random_normal(
*,
@@ -155,10 +154,67 @@ def shuffle(
) -> paddle.Tensor:
if seed:
_ = paddle.seed(seed)
# Use Paddle's randperm function to generate shuffled indices
indices = paddle.randperm(x.ndim, dtype="int64")
if paddle.is_complex(x):
shuffled_real = paddle.index_select(x.real(), indices, axis=axis)
shuffled_imag = paddle.index_select(x.imag(), indices, axis=axis)
return paddle.complex(shuffled_real, shuffled_imag)
return paddle.index_select(x, indices, axis=axis)


# New Random Distribution Functions
# -----------------------------------


def random_exponential(
*,
scale: Union[float, paddle.Tensor],
shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
dtype: paddle.dtype,
seed: Optional[int] = None,
) -> paddle.Tensor:
_check_valid_scale(scale)
shape = _check_bounds_and_get_shape(scale, None, shape).shape
if seed:
paddle.seed(seed)
return paddle.exponential(scale, shape).cast(dtype)


def random_poisson(
*,
lam: Union[float, paddle.Tensor],
shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
dtype: paddle.dtype,
seed: Optional[int] = None,
) -> paddle.Tensor:
shape = _check_bounds_and_get_shape(lam, None, shape).shape
if seed:
paddle.seed(seed)
return paddle.poisson(lam, shape).cast(dtype)


def random_bernoulli(
*,
p: Union[float, paddle.Tensor],
shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
dtype: paddle.dtype,
seed: Optional[int] = None,
) -> paddle.Tensor:
shape = _check_bounds_and_get_shape(p, None, shape).shape
if seed:
paddle.seed(seed)
return paddle.bernoulli(p, shape).cast(dtype)


def random_beta(
*,
alpha: Union[float, paddle.Tensor],
beta: Union[float, paddle.Tensor],
shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
dtype: paddle.dtype,
seed: Optional[int] = None,
) -> paddle.Tensor:
shape = _check_bounds_and_get_shape(alpha, beta, shape).shape
if seed:
paddle.seed(seed)
return paddle.beta(alpha, beta, shape).cast(dtype)
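
The four new helpers above all follow the same backend pattern: keyword-only arguments, shape resolution via `_check_bounds_and_get_shape`, an optional `seed`, and a final cast to the requested `dtype`. A minimal, self-contained sketch of that pattern follows, using NumPy as a stand-in generator; the helper name and the NumPy calls are illustrative assumptions for this sketch, not part of the PR.

# Illustrative sketch only: mirrors the pattern of the new paddle helpers
# above, using NumPy as a stand-in backend (assumption, not this PR's code).
from typing import Optional, Sequence, Union

import numpy as np


def random_exponential_sketch(
    *,
    scale: Union[float, np.ndarray],
    shape: Optional[Sequence[int]] = None,
    dtype: str = "float32",
    seed: Optional[int] = None,
) -> np.ndarray:
    # Resolve the output shape: an explicit `shape` wins, otherwise
    # broadcast from the parameter, mirroring _check_bounds_and_get_shape.
    if shape is None:
        shape = np.shape(scale)
    rng = np.random.default_rng(seed)
    # Draw the samples, then cast to the requested dtype, as the paddle
    # helpers do with `.cast(dtype)`.
    return rng.exponential(scale=scale, size=shape).astype(dtype)


# Usage:
# samples = random_exponential_sketch(scale=2.0, shape=(3, 4), seed=0)
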
112 changes: 23 additions & 89 deletions ivy/functional/backends/torch/__init__.py
@@ -12,29 +12,29 @@
if hasattr(torch, "_dynamo"):
torch._dynamo.config.traceable_tensor_subclasses = (ivy.Array,)

# noinspection PyUnresolvedReferences
if not ivy.is_local():
_module_in_memory = sys.modules[__name__]
else:
_module_in_memory = sys.modules[ivy.import_module_path].import_cache[__name__]
# Determine the module in memory based on whether Ivy is local or not
_module_in_memory = (
sys.modules[__name__]
if not ivy.is_local()
else sys.modules[ivy.import_module_path].import_cache[__name__]
)

use = ivy.utils.backend.ContextManager(_module_in_memory)

# Native types
NativeArray = torch.Tensor
NativeDevice = torch.device
NativeDtype = torch.dtype
NativeShape = torch.Size

# Sparse array
NativeSparseArray = torch.Tensor


# devices
# Devices
valid_devices = ("cpu", "gpu")

invalid_devices = ("tpu",)


# native data types
# Native data types
native_int8 = torch.int8
native_int16 = torch.int16
native_int32 = torch.int32
@@ -46,13 +46,9 @@
native_float64 = torch.float64
native_complex64 = torch.complex64
native_complex128 = torch.complex128
native_double = native_float64
native_bool = torch.bool

# valid data types
# ToDo: Add complex dtypes to valid_dtypes and fix all resulting failures.

# update these to add new dtypes
# Valid and invalid data types
valid_dtypes = {
"2.2 and below": (
ivy.int8,
@@ -70,68 +66,29 @@
)
}


valid_numeric_dtypes = {
"2.2 and below": (
ivy.int8,
ivy.int16,
ivy.int32,
ivy.int64,
ivy.uint8,
ivy.bfloat16,
ivy.float16,
ivy.float32,
ivy.float64,
ivy.complex64,
ivy.complex128,
)
}

valid_int_dtypes = {
"2.2 and below": (ivy.int8, ivy.int16, ivy.int32, ivy.int64, ivy.uint8)
}
valid_float_dtypes = {
"2.2 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
}
valid_uint_dtypes = {"2.2 and below": (ivy.uint8,)}
valid_complex_dtypes = {"2.2 and below": (ivy.complex64, ivy.complex128)}

# leave these untouched
# Update valid_dtypes based on backend_version
valid_dtypes = _dtype_from_version(valid_dtypes, backend_version)
valid_numeric_dtypes = _dtype_from_version(valid_numeric_dtypes, backend_version)
valid_int_dtypes = _dtype_from_version(valid_int_dtypes, backend_version)
valid_float_dtypes = _dtype_from_version(valid_float_dtypes, backend_version)
valid_uint_dtypes = _dtype_from_version(valid_uint_dtypes, backend_version)
valid_complex_dtypes = _dtype_from_version(valid_complex_dtypes, backend_version)

# invalid data types
# update these to add new dtypes

# Invalid data types
invalid_dtypes = {
"2.2 and below": (
ivy.uint16,
ivy.uint32,
ivy.uint64,
)
}
invalid_numeric_dtypes = {"2.2 and below": (ivy.uint16, ivy.uint32, ivy.uint64)}
invalid_int_dtypes = {"2.2 and below": (ivy.uint16, ivy.uint32, ivy.uint64)}
invalid_float_dtypes = {"2.2 and below": ()}
invalid_uint_dtypes = {"2.2 and below": (ivy.uint16, ivy.uint32, ivy.uint64)}
invalid_complex_dtypes = {"2.2 and below": ()}

# Update invalid_dtypes based on backend_version
invalid_dtypes = _dtype_from_version(invalid_dtypes, backend_version)

# leave these untouched
invalid_numeric_dtypes = _dtype_from_version(invalid_numeric_dtypes, backend_version)
invalid_int_dtypes = _dtype_from_version(invalid_int_dtypes, backend_version)
invalid_float_dtypes = _dtype_from_version(invalid_float_dtypes, backend_version)
invalid_uint_dtypes = _dtype_from_version(invalid_uint_dtypes, backend_version)
invalid_complex_dtypes = _dtype_from_version(invalid_complex_dtypes, backend_version)
# Unsupported devices
unsupported_devices = ("tpu",)

native_inplace_support = True

supports_gradients = True


# Closest valid dtype function
def closest_valid_dtype(type=None, /, as_native=False):
if type is None:
type = ivy.default_dtype()
@@ -145,6 +102,7 @@ def closest_valid_dtype(type=None, /, as_native=False):
backend = "torch"


# Globals getter function
def globals_getter_func(x=None):
if not x:
return globals()
@@ -153,55 +111,31 @@ def globals_getter_func(x=None):


ivy.func_wrapper.globals_getter_func = globals_getter_func
# local sub-modules

# Import sub-modules
from . import activations
from .activations import *


from . import creation
from .creation import *
from . import data_type
from .data_type import *
from . import device
from .device import *
from . import elementwise
from .elementwise import *
from . import gradients
from .gradients import *
from . import general
from .general import *
from . import layers
from .layers import *
from . import linear_algebra as linalg
from .linear_algebra import *
from . import manipulation
from .manipulation import *
from . import random
from .random import *
from . import searching
from .searching import *
from . import set
from .set import *
from . import sorting
from .sorting import *
from . import statistical
from .statistical import *
from . import utility
from .utility import *
from . import experimental
from .experimental import *
from . import control_flow_ops
from .control_flow_ops import *
from . import norms
from .norms import *
from . import module
from .module import *


# sub-backends
# Import sub-backends
from . import sub_backends
from .sub_backends import *


# Native module
NativeModule = torch.nn.Module
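
The dtype tables in this file are keyed by version specifiers such as "2.2 and below" and resolved at import time by `_dtype_from_version`. The sketch below shows one way such a lookup could behave; it is a simplified assumption for illustration, not the actual Ivy helper.

# Simplified illustration of resolving a version-keyed dtype table.
# This approximates the behaviour of _dtype_from_version; it is an
# assumption, not the real Ivy implementation.
def resolve_versioned_dict(versioned, backend_version):
    # Supports only the "<X> and below" rule used in the tables above;
    # anything else falls back to an empty tuple.
    def as_tuple(v):
        return tuple(int(p) for p in v.split(".") if p.isdigit())

    for rule, dtypes in versioned.items():
        if rule.endswith(" and below"):
            bound = rule[: -len(" and below")]
            if as_tuple(backend_version) <= as_tuple(bound):
                return dtypes
    return ()


# Usage with a table shaped like the ones above (dtype names as strings):
invalid_dtypes_table = {"2.2 and below": ("uint16", "uint32", "uint64")}
print(resolve_versioned_dict(invalid_dtypes_table, "2.1.0"))
# -> ('uint16', 'uint32', 'uint64')
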
15 changes: 12 additions & 3 deletions ivy/functional/backends/torch/device.py
@@ -18,6 +18,18 @@
Profiler as BaseProfiler,
)

# Invalid data types
invalid_dtypes = {
"2.2 and below": (
ivy.uint16,
ivy.uint32,
ivy.uint64,
)
}

# Unsupported devices
unsupported_devices = ("tpu",)

torch_scatter = None

# API #
@@ -103,7 +115,6 @@ def gpu_is_available() -> bool:
) or torch.cuda.is_available()


# noinspection PyUnresolvedReferences
def tpu_is_available() -> bool:
if importlib.util.find_spec("torch_xla") is not None:
return True
@@ -114,8 +125,6 @@ def handle_soft_device_variable(*args, fn, **kwargs):
args, kwargs, device_shifting_dev = _shift_native_arrays_on_default_device(
*args, **kwargs
)
# checking if this function accepts `device` argument
# must be handled in the backend
if "device" in inspect.signature(fn).parameters:
kwargs["device"] = device_shifting_dev
return fn(*args, **kwargs)
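
`handle_soft_device_variable` only forwards a `device` keyword when the wrapped callable actually declares one, which it detects through `inspect.signature`. Below is a minimal standalone sketch of that check; the function and variable names are illustrative assumptions, not code from this PR.

# Illustrative sketch of the signature check used above (assumed names).
import inspect


def call_with_optional_device(fn, *args, device=None, **kwargs):
    # Only forward `device` if the callable declares that parameter,
    # mirroring the check in handle_soft_device_variable.
    if device is not None and "device" in inspect.signature(fn).parameters:
        kwargs["device"] = device
    return fn(*args, **kwargs)


def make_zeros(n, device="cpu"):
    # Stand-in for a backend function that accepts a device argument.
    return [0.0] * n, device


def add_one(values):
    # Stand-in for a backend function with no device argument.
    return [v + 1 for v in values]


# `device` is forwarded to make_zeros but silently dropped for add_one.
zeros, dev = call_with_optional_device(make_zeros, 4, device="cpu")
bumped = call_with_optional_device(add_one, zeros)
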