Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: keras 3 backend for ivy #28794

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ivy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,8 @@ def func(x):
target_backend is not None
and ivy.backend != ""
and ivy.current_backend_str() != target_backend.backend
# keras supports inputs instantiated with different backends
and ivy.current_backend_str() != "keras"
):
raise ivy.utils.exceptions.IvyInvalidBackendException(
"Operation not allowed. Array was instantiated with backend"
Expand Down
294 changes: 294 additions & 0 deletions ivy/functional/backends/keras/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
# global
import os
import sys
import keras


# Identifier string for this backend module.
backend = "keras"
# Installed Keras version wrapped in a dict; it is the second argument of every
# ``_dtype_from_version`` call further down in this file.
backend_version = {"version": keras.__version__}
# Framework Keras runs on: read from the ``KERAS_BACKEND`` environment variable
# (lower-cased), falling back to "tensorflow" when the variable is not set.
keras_backend = os.getenv("KERAS_BACKEND", default="tensorflow").lower()

# local
import ivy
from ivy.func_wrapper import _dtype_from_version


# Locate the module object for this backend: the import cache kept on
# ``sys.modules[ivy.import_module_path]`` when ``ivy.is_local()`` is true,
# otherwise the regular ``sys.modules`` entry.
# noinspection PyUnresolvedReferences
_module_in_memory = (
    sys.modules[ivy.import_module_path].import_cache[__name__]
    if ivy.is_local()
    else sys.modules[__name__]
)

# Context manager built around this backend module.
use = ivy.utils.backend.ContextManager(_module_in_memory)


# wrap dunder methods of native tensors to return NotImplemented to prioritize Ivy array methods.
def dunder_wrapper(func):
    """Wrap a native-tensor dunder method so that Ivy arrays take priority.

    Parameters
    ----------
    func
        The original method to wrap.

    Returns
    -------
    ret
        A function that returns ``NotImplemented`` when any positional
        argument satisfies ``ivy.is_ivy_array`` and otherwise forwards all
        arguments to ``func`` unchanged. The wrapper keeps the name and
        docstring of ``func``.
    """
    # Imported locally so this block does not rely on a module-level import.
    import functools

    @functools.wraps(func)
    def rep_method(*args, **kwargs):
        # Only positional arguments are inspected; keyword arguments are
        # forwarded without being checked.
        if any(ivy.is_ivy_array(arg) for arg in args):
            return NotImplemented
        return func(*args, **kwargs)

    return rep_method


# check for previously imported tensorflow modules
# NOTE(review): in the code visible in this file, tensorflow is only imported
# further down (and only for the tensorflow Keras backend), so this scan of
# the module globals appears to find nothing at this point - confirm whether
# the patching is meant to run after that import.
modules_to_patch = []
tensors_to_patch = []
tmp_globals = dict(globals())
for name, value in tmp_globals.items():
    if value == "tensorflow.python.framework.ops.Tensor":
        tensors_to_patch.append(name)
    # Globals without a ``__name__`` attribute are simply skipped.
    if getattr(value, "__name__", None) == "tensorflow":
        modules_to_patch.append(name)

# Dunder methods that get wrapped with ``dunder_wrapper``.
methods_to_patch = [
    "__add__",
    "__sub__",
    "__mul__",
    "__div__",
    "__truediv__",
    "__floordiv__",
    "__mod__",
    "__lt__",
    "__le__",
    "__gt__",
    "__ge__",
    "__ne__",
    "__eq__",
    "__and__",
    "__or__",
    "__xor__",
    "__pow__",
    "__matmul__",
]

# Rebind each listed method on ``<module>.Tensor`` to its wrapped version.
# Plain attribute access replaces the previous string-built ``exec`` calls.
for module in modules_to_patch:
    for method in methods_to_patch:
        setattr(
            tmp_globals[module].Tensor,
            method,
            dunder_wrapper(getattr(tmp_globals[module].Tensor, method)),
        )

# Same rebinding for globals collected in ``tensors_to_patch``.
for tensor in tensors_to_patch:
    for method in methods_to_patch:
        setattr(
            tmp_globals[tensor],
            method,
            dunder_wrapper(getattr(tmp_globals[tensor], method)),
        )



# Select the native array / device / dtype / shape types that belong to the
# framework Keras is configured to run on (``keras_backend``).
if keras_backend == "jax":
    import re

    import jax
    import jax.numpy as jnp
    import jaxlib

    # Compare the version numerically: a plain string comparison would order
    # e.g. "0.10.0" before "0.4.1" and pick the wrong branch.
    _jax_version = tuple(
        int(part) for part in re.findall(r"\d+", jax.__version__)[:3]
    )
    if _jax_version >= (0, 4, 1):
        JaxArray = jax.Array
        NativeArray = jax.Array
    else:
        JaxArray = jaxlib.xla_extension.DeviceArray
        NativeArray = jaxlib.xla_extension.DeviceArray

    # noinspection PyUnresolvedReferences,PyProtectedMember
    NativeDevice = jaxlib.xla_extension.Device
    NativeDtype = jnp.dtype
    NativeShape = tuple

    NativeSparseArray = None
elif keras_backend == "torch":
    import torch

    NativeArray = torch.Tensor
    NativeDevice = torch.device
    NativeDtype = torch.dtype
    NativeShape = torch.Size
    NativeSparseArray = torch.Tensor
else:
    # Any other value of ``keras_backend`` is treated as tensorflow.
    import tensorflow as tf
    from tensorflow.python.framework.dtypes import DType
    from tensorflow.python.framework.tensor_shape import TensorShape
    from tensorflow.python.types.core import Tensor

    NativeArray = Tensor
    NativeDevice = str
    NativeDtype = DType
    NativeShape = TensorShape
    NativeSparseArray = tf.SparseTensor


# devices
valid_devices = ("cpu", "gpu", "tpu")


# native data types
# The dtype objects must come from the framework Keras is running on. ``jnp``,
# ``torch`` and ``tf`` are each imported earlier in this file only for their
# own ``keras_backend`` value, so referencing ``tf`` unconditionally raises
# ``NameError`` for the jax and torch backends.
if keras_backend == "jax":

    def _native_dtype(name):
        """Return the ``jnp.dtype`` instance for the dtype called ``name``."""
        return jnp.dtype(name)

elif keras_backend == "torch":

    def _native_dtype(name):
        """Return ``torch.<name>``, or ``None`` when torch does not define it."""
        # NOTE(review): uint16/uint32/uint64 are believed to be missing from
        # older torch releases; ``None`` keeps the module importable there -
        # confirm against the supported torch versions.
        return getattr(torch, name, None)

else:

    def _native_dtype(name):
        """Return ``tf.<name>`` for the dtype called ``name``."""
        return getattr(tf, name)


native_int8 = _native_dtype("int8")
native_int16 = _native_dtype("int16")
native_int32 = _native_dtype("int32")
native_int64 = _native_dtype("int64")
native_uint8 = _native_dtype("uint8")
native_uint16 = _native_dtype("uint16")
native_uint32 = _native_dtype("uint32")
native_uint64 = _native_dtype("uint64")
native_bfloat16 = _native_dtype("bfloat16")
native_float16 = _native_dtype("float16")
native_float32 = _native_dtype("float32")
native_float64 = _native_dtype("float64")
native_complex64 = _native_dtype("complex64")
native_complex128 = _native_dtype("complex128")
native_double = native_float64
native_bool = _native_dtype("bool")

# valid data types
# NOTE(review): the complex dtypes are already listed in ``valid_dtypes`` and
# ``valid_numeric_dtypes`` below, so the earlier ToDo about adding them looks
# stale - confirm and drop it.

# update these to add new dtypes
# Each table maps a Keras version range (string key) to a tuple of Ivy dtypes.
valid_dtypes = {
    "3.4.1 and below": (
        ivy.int8,
        ivy.int16,
        ivy.int32,
        ivy.int64,
        ivy.uint8,
        ivy.uint16,
        ivy.uint32,
        ivy.uint64,
        ivy.bfloat16,
        ivy.float16,
        ivy.float32,
        ivy.float64,
        ivy.complex64,
        ivy.complex128,
        ivy.bool,
    )
}
valid_numeric_dtypes = {
    "3.4.1 and below": (
        ivy.int8,
        ivy.int16,
        ivy.int32,
        ivy.int64,
        ivy.uint8,
        ivy.uint16,
        ivy.uint32,
        ivy.uint64,
        ivy.bfloat16,
        ivy.float16,
        ivy.float32,
        ivy.float64,
        ivy.complex64,
        ivy.complex128,
    )
}
valid_int_dtypes = {
    "3.4.1 and below": (
        ivy.int8,
        ivy.int16,
        ivy.int32,
        ivy.int64,
        ivy.uint8,
        ivy.uint16,
        ivy.uint32,
        ivy.uint64,
    )
}
valid_float_dtypes = {
    "3.4.1 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
}
valid_uint_dtypes = {
    "3.4.1 and below": (ivy.uint8, ivy.uint16, ivy.uint32, ivy.uint64)
}
# complex64 is listed as valid in ``valid_dtypes`` and ``valid_numeric_dtypes``,
# so it is included here too to keep the tables consistent.
valid_complex_dtypes = {"3.4.1 and below": (ivy.complex64, ivy.complex128)}

# leave these untouched
# Each version-keyed table defined above is passed to ``_dtype_from_version``
# together with ``backend_version`` and the result is bound to the same name.
# NOTE(review): presumably this picks the entry matching the installed Keras
# version - confirm against ``ivy.func_wrapper._dtype_from_version``.
valid_dtypes = _dtype_from_version(valid_dtypes, backend_version)
valid_numeric_dtypes = _dtype_from_version(valid_numeric_dtypes, backend_version)
valid_int_dtypes = _dtype_from_version(valid_int_dtypes, backend_version)
valid_float_dtypes = _dtype_from_version(valid_float_dtypes, backend_version)
valid_uint_dtypes = _dtype_from_version(valid_uint_dtypes, backend_version)
valid_complex_dtypes = _dtype_from_version(valid_complex_dtypes, backend_version)

# invalid data types
# update these to add new dtypes
# Same layout as the valid tables (version-range key -> tuple); every tuple is
# currently empty.
invalid_dtypes = {"3.4.1 and below": ()}
invalid_numeric_dtypes = {"3.4.1 and below": ()}
invalid_int_dtypes = {"3.4.1 and below": ()}
invalid_float_dtypes = {"3.4.1 and below": ()}
invalid_uint_dtypes = {"3.4.1 and below": ()}
invalid_complex_dtypes = {"3.4.1 and below": ()}

# leave these untouched
# The invalid tables go through the same ``_dtype_from_version`` call as the
# valid ones and are rebound in place.
invalid_dtypes = _dtype_from_version(invalid_dtypes, backend_version)
invalid_numeric_dtypes = _dtype_from_version(invalid_numeric_dtypes, backend_version)
invalid_int_dtypes = _dtype_from_version(invalid_int_dtypes, backend_version)
invalid_float_dtypes = _dtype_from_version(invalid_float_dtypes, backend_version)
invalid_uint_dtypes = _dtype_from_version(invalid_uint_dtypes, backend_version)
invalid_complex_dtypes = _dtype_from_version(invalid_complex_dtypes, backend_version)

# Backend capability flags; the code that reads them is outside this file.
# NOTE(review): the same values are used for every ``keras_backend`` - confirm
# that is intended for the jax and torch cases.
native_inplace_support = False

supports_gradients = True


def closest_valid_dtype(type=None, /, as_native=False):
    """Convert a dtype to its Ivy or native form.

    Parameters
    ----------
    type
        The dtype to convert. When ``None``, ``ivy.default_dtype()`` is used.
    as_native
        If true, convert with ``ivy.as_native_dtype``; otherwise convert with
        ``ivy.as_ivy_dtype``. Default is ``False``.

    Returns
    -------
    ret
        The result of the selected conversion function.
    """
    requested = ivy.default_dtype() if type is None else type
    if as_native:
        return ivy.as_native_dtype(requested)
    return ivy.as_ivy_dtype(requested)


# local sub-modules
from . import activations
from .activations import *
from . import creation
from .creation import *
from . import data_type
from .data_type import *
from . import device
from .device import *
from . import elementwise
from .elementwise import *
from . import general
from .general import *
from . import gradients
from .gradients import *
from . import layers
from .layers import *
from . import linear_algebra as linalg
from .linear_algebra import *
from . import manipulation
from .manipulation import *
from . import random
from .random import *
from . import searching
from .searching import *
from . import set
from .set import *
from . import sorting
from .sorting import *
from . import statistical
from .statistical import *
from . import utility
from .utility import *
from . import experimental
from .experimental import *
from . import control_flow_ops
from .control_flow_ops import *


# sub-backends
from . import sub_backends
from .sub_backends import *

from . import module
# from .module import Model


# NativeModule = Model
Empty file.
Empty file.
Loading
Loading