diff --git a/keras_core/backend/jax/numpy.py b/keras_core/backend/jax/numpy.py
index eb3720a61..bd1a11c3d 100644
--- a/keras_core/backend/jax/numpy.py
+++ b/keras_core/backend/jax/numpy.py
@@ -13,10 +13,15 @@ def add(x1, x2):
 
 def bincount(x, weights=None, minlength=0):
     if len(x.shape) == 2:
-        bincounts = [
-            jnp.bincount(arr, weights=weights, minlength=minlength)
-            for arr in list(x)
-        ]
+        # Count each row separately, pairing it with its own row of weights.
+        if weights is None:
+            bincounts = [jnp.bincount(arr, minlength=minlength) for arr in x]
+        else:
+            bincounts = [
+                jnp.bincount(arr, weights=w, minlength=minlength)
+                for arr, w in zip(x, weights)
+            ]
+        # NOTE(review): stack requires every row's counts to have equal length.
         return jnp.stack(bincounts)
     return jnp.bincount(x, weights=weights, minlength=minlength)
 
diff --git a/keras_core/backend/numpy/numpy.py b/keras_core/backend/numpy/numpy.py
index 9b8dda11f..33f46d364 100644
--- a/keras_core/backend/numpy/numpy.py
+++ b/keras_core/backend/numpy/numpy.py
@@ -133,6 +133,17 @@ def average(x, axis=None, weights=None):
 
 
 def bincount(x, weights=None, minlength=0):
+    if len(x.shape) == 2:
+        # Count each row separately, pairing it with its own row of weights.
+        if weights is None:
+            bincounts = [np.bincount(arr, minlength=minlength) for arr in x]
+        else:
+            bincounts = [
+                np.bincount(arr, weights=w, minlength=minlength)
+                for arr, w in zip(x, weights)
+            ]
+        # NOTE(review): stack requires every row's counts to have equal length.
+        return np.stack(bincounts)
     return np.bincount(x, weights, minlength)
 
 
diff --git a/keras_core/backend/torch/numpy.py b/keras_core/backend/torch/numpy.py
index 4986320f6..da444be70 100644
--- a/keras_core/backend/torch/numpy.py
+++ b/keras_core/backend/torch/numpy.py
@@ -245,7 +245,19 @@ def average(x, axis=None, weights=None):
 
 def bincount(x, weights=None, minlength=0):
     x = convert_to_tensor(x, dtype=int)
-    weights = convert_to_tensor(weights)
+    if weights is not None:
+        weights = convert_to_tensor(weights)
+    if len(x.shape) == 2:
+        # Count each row separately, pairing it with its own row of weights.
+        if weights is None:
+            bincounts = [torch.bincount(arr, minlength=minlength) for arr in x]
+        else:
+            bincounts = [
+                torch.bincount(arr, weights=w, minlength=minlength)
+                for arr, w in zip(x, weights)
+            ]
+        # NOTE(review): stack requires every row's counts to have equal length.
+        return torch.stack(bincounts)
     return torch.bincount(x, weights, minlength)
 
 
diff --git a/keras_core/layers/preprocessing/discretization.py b/keras_core/layers/preprocessing/discretization.py
index 42c088c42..420e487a7 100644
--- a/keras_core/layers/preprocessing/discretization.py
+++ b/keras_core/layers/preprocessing/discretization.py
@@ -233,6 +233,7 @@ def call(self, inputs):
             output_mode=self.output_mode,
             depth=len(self.bin_boundaries) + 1,
             dtype=self.compute_dtype,
+            count_weights=None,
             backend_module=self.backend,
         )
         return outputs
diff --git a/keras_core/layers/preprocessing/discretization_test.py b/keras_core/layers/preprocessing/discretization_test.py
index 1a4cfa6f4..3af625a1b 100644
--- a/keras_core/layers/preprocessing/discretization_test.py
+++ b/keras_core/layers/preprocessing/discretization_test.py
@@ -1,7 +1,6 @@
 import os
 
 import numpy as np
-import pytest
 from tensorflow import data as tf_data
 
 from keras_core import backend
@@ -36,9 +35,6 @@ def test_adapt_flow(self):
         output = layer(np.array([[0.0, 0.1, 0.3]]))
         self.assertTrue(output.dtype, "int32")
 
-    @pytest.mark.skipif(
-        backend.backend() in ("torch", "numpy"), reason="TODO: fix me"
-    )
     def test_correctness(self):
         # int mode
         layer = layers.Discretization(
diff --git a/keras_core/ops/numpy_test.py b/keras_core/ops/numpy_test.py
index 5041db14e..83e333563 100644
--- a/keras_core/ops/numpy_test.py
+++ b/keras_core/ops/numpy_test.py
@@ -2892,6 +2892,23 @@ def test_bincount(self):
             knp.Bincount(weights=weights, minlength=minlength)(x),
             np.bincount(x, weights=weights, minlength=minlength),
         )
+        x = np.array([[1, 1, 2, 3, 2, 4, 4, 5]])
+        weights = np.array([[0, 0, 3, 2, 1, 1, 4, 2]])
+        expected_output = np.array([[0, 0, 4, 2, 5, 2]])
+        self.assertAllClose(
+            knp.bincount(x, weights=weights, minlength=minlength),
+            expected_output,
+        )
+        self.assertAllClose(
+            knp.Bincount(weights=weights, minlength=minlength)(x),
+            expected_output,
+        )
+        # test with weights=None
+        expected_output = np.array([[0, 2, 2, 1, 2, 1]])
+        self.assertAllClose(
+            knp.Bincount(weights=None, minlength=minlength)(x),
+            expected_output,
+        )
 
     def test_broadcast_to(self):
         x = np.array([[1, 2, 3], [3, 2, 1]])
diff --git a/keras_core/utils/numerical_utils.py b/keras_core/utils/numerical_utils.py
index 259d9262b..fa515f345 100644
--- a/keras_core/utils/numerical_utils.py
+++ b/keras_core/utils/numerical_utils.py
@@ -125,10 +125,6 @@ def encode_categorical_inputs(
     # In all cases, we should uprank scalar input to a single sample.
     if len(backend_module.shape(inputs)) == 0:
         inputs = backend_module.numpy.expand_dims(inputs, -1)
-    # One hot will unprank only if the final output dimension is not already 1.
-    if output_mode == "one_hot":
-        if backend_module.shape(inputs)[-1] != 1:
-            inputs = backend_module.numpy.expand_dims(inputs, -1)
 
     if len(backend_module.shape(inputs)) > 2:
         raise ValueError(
@@ -139,15 +135,20 @@ def encode_categorical_inputs(
     )
 
     binary_output = output_mode in ("multi_hot", "one_hot")
-    bincounts = backend_module.numpy.bincount(
-        inputs,
-        weights=count_weights,
-        minlength=depth,
-    )
     if binary_output:
-        one_hot_input = backend_module.nn.one_hot(inputs, depth)
-        bincounts = backend_module.numpy.where(
-            backend_module.numpy.any(one_hot_input, axis=-2), 1, 0
+        if output_mode == "one_hot":
+            bincounts = backend_module.nn.one_hot(inputs, depth)
+        elif output_mode == "multi_hot":
+            one_hot_input = backend_module.nn.one_hot(inputs, depth)
+            bincounts = backend_module.numpy.where(
+                backend_module.numpy.any(one_hot_input, axis=-2), 1, 0
+            )
+    else:
+        # Keep forwarding `count_weights` so weighted counts are not dropped.
+        bincounts = backend_module.numpy.bincount(
+            inputs,
+            weights=count_weights,
+            minlength=depth,
         )
     bincounts = backend_module.cast(bincounts, dtype)
 