Skip to content

Commit

Permalink
chore: fix backend tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ariG23498 committed Jun 21, 2023
1 parent 8373103 commit 82a1e5c
Showing 1 changed file with 60 additions and 9 deletions.
69 changes: 60 additions & 9 deletions keras_core/backend/numpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,20 @@
from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.stateless_scope import StatelessScope

DYNAMIC_SHAPES_OK = False
DYNAMIC_SHAPES_OK = True

NUMPY_DTYPES = {
"float16": np.float16,
"float32": np.float32,
"float64": np.float64,
"uint8": np.uint8,
"uint16": np.uint16,
"uint32": np.uint32,
"int8": np.int8,
"int16": np.int16,
"int32": np.int32,
"int64": np.int64,
}


class Variable(KerasVariable):
Expand Down Expand Up @@ -75,19 +88,57 @@ def vectorized_map(function, elements):
def compute_output_spec(fn, *args, **kwargs):
with StatelessScope():

def convert_keras_tensor_to_numpy(x):
def has_none_shape(x):
if isinstance(x, KerasTensor):
return np.ones(x.shape, dtype=x.dtype)
return None in x.shape
return False

none_in_shape = any(map(has_none_shape, nest.flatten((args, kwargs))))

def convert_keras_tensor_to_numpy(x, fill_value=None):
if isinstance(x, KerasTensor):
shape = list(x.shape)
if fill_value:
for i, e in enumerate(shape):
if e is None:
shape[i] = fill_value
return np.empty(
shape=shape,
dtype=NUMPY_DTYPES[x.dtype],
)
return x

args, kwargs = nest.map_structure(
convert_keras_tensor_to_numpy, (args, kwargs)
args_1, kwargs_1 = nest.map_structure(
lambda x: convert_keras_tensor_to_numpy(x, fill_value=83),
(args, kwargs),
)
np_out = fn(*args, **kwargs)
outputs_1 = fn(*args_1, **kwargs_1)

outputs = outputs_1

if none_in_shape:
args_2, kwargs_2 = nest.map_structure(
lambda x: convert_keras_tensor_to_numpy(x, fill_value=89),
(args, kwargs),
)
outputs_2 = fn(*args_2, **kwargs_2)

flat_out_1 = nest.flatten(outputs_1)
flat_out_2 = nest.flatten(outputs_2)

flat_out = []
for x1, x2 in zip(flat_out_1, flat_out_2):
shape = list(x1.shape)
for i, e in enumerate(x2.shape):
if e != shape[i]:
shape[i] = None
flat_out.append(KerasTensor(shape, standardize_dtype(x1.dtype)))
outputs = nest.pack_sequence_as(outputs_1, flat_out)

def convert_numpy_to_keras_tensor(x):
if isinstance(x, np.ndarray):
return KerasTensor(x.shape, x.dtype)
if is_tensor(x):
return KerasTensor(x.shape, standardize_dtype(x.dtype))
return x

return nest.map_structure(convert_numpy_to_keras_tensor, np_out)
output_spec = nest.map_structure(convert_numpy_to_keras_tensor, outputs)
return output_spec

0 comments on commit 82a1e5c

Please sign in to comment.