Skip to content

Commit

Permalink
fix torch error for bincount (#927)
Browse files Browse the repository at this point in the history
* fix torch error for bincount

* fix numpy tests

* add suppirt for 2D arrays

* update numpy test coverage

* update bincount implementation and reenable tests
  • Loading branch information
divyashreepathihalli authored Sep 20, 2023
1 parent a92593f commit 4c3697f
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 21 deletions.
19 changes: 15 additions & 4 deletions keras_core/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,21 @@ def add(x1, x2):

def bincount(x, weights=None, minlength=0):
if len(x.shape) == 2:
bincounts = [
jnp.bincount(arr, weights=weights, minlength=minlength)
for arr in list(x)
]
if weights is None:

def bincount_fn(arr):
return jnp.bincount(arr, minlength=minlength)

bincounts = list(map(bincount_fn, x))
else:

def bincount_fn(arr_w):
return jnp.bincount(
arr_w[0], weights=arr_w[1], minlength=minlength
)

bincounts = list(map(bincount_fn, zip(x, weights)))

return jnp.stack(bincounts)
return jnp.bincount(x, weights=weights, minlength=minlength)

Expand Down
17 changes: 17 additions & 0 deletions keras_core/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,23 @@ def average(x, axis=None, weights=None):


def bincount(x, weights=None, minlength=0):
if len(x.shape) == 2:
if weights is None:

def bincount_fn(arr):
return np.bincount(arr, minlength=minlength)

bincounts = list(map(bincount_fn, x))
else:

def bincount_fn(arr_w):
return np.bincount(
arr_w[0], weights=arr_w[1], minlength=minlength
)

bincounts = list(map(bincount_fn, zip(x, weights)))

return np.stack(bincounts)
return np.bincount(x, weights, minlength)


Expand Down
20 changes: 19 additions & 1 deletion keras_core/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,25 @@ def average(x, axis=None, weights=None):

def bincount(x, weights=None, minlength=0):
x = convert_to_tensor(x, dtype=int)
weights = convert_to_tensor(weights)
if weights is not None:
weights = convert_to_tensor(weights)
if len(x.shape) == 2:
if weights is None:

def bincount_fn(arr):
return torch.bincount(arr, minlength=minlength)

bincounts = list(map(bincount_fn, x))
else:

def bincount_fn(arr_w):
return torch.bincount(
arr_w[0], weights=arr_w[1], minlength=minlength
)

bincounts = list(map(bincount_fn, zip(x, weights)))

return torch.stack(bincounts)
return torch.bincount(x, weights, minlength)


Expand Down
1 change: 1 addition & 0 deletions keras_core/layers/preprocessing/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def call(self, inputs):
output_mode=self.output_mode,
depth=len(self.bin_boundaries) + 1,
dtype=self.compute_dtype,
count_weights=None,
backend_module=self.backend,
)
return outputs
Expand Down
4 changes: 0 additions & 4 deletions keras_core/layers/preprocessing/discretization_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

import numpy as np
import pytest
from tensorflow import data as tf_data

from keras_core import backend
Expand Down Expand Up @@ -36,9 +35,6 @@ def test_adapt_flow(self):
output = layer(np.array([[0.0, 0.1, 0.3]]))
self.assertTrue(output.dtype, "int32")

@pytest.mark.skipif(
backend.backend() in ("torch", "numpy"), reason="TODO: fix me"
)
def test_correctness(self):
# int mode
layer = layers.Discretization(
Expand Down
17 changes: 17 additions & 0 deletions keras_core/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2892,6 +2892,23 @@ def test_bincount(self):
knp.Bincount(weights=weights, minlength=minlength)(x),
np.bincount(x, weights=weights, minlength=minlength),
)
x = np.array([[1, 1, 2, 3, 2, 4, 4, 5]])
weights = np.array([[0, 0, 3, 2, 1, 1, 4, 2]])
expected_output = np.array([[0, 0, 4, 2, 5, 2]])
self.assertAllClose(
knp.bincount(x, weights=weights, minlength=minlength),
expected_output,
)
self.assertAllClose(
knp.Bincount(weights=weights, minlength=minlength)(x),
expected_output,
)
# test with weights=None
expected_output = np.array([[0, 2, 2, 1, 2, 1]])
self.assertAllClose(
knp.Bincount(weights=None, minlength=minlength)(x),
expected_output,
)

def test_broadcast_to(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
Expand Down
23 changes: 11 additions & 12 deletions keras_core/utils/numerical_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,6 @@ def encode_categorical_inputs(
# In all cases, we should uprank scalar input to a single sample.
if len(backend_module.shape(inputs)) == 0:
inputs = backend_module.numpy.expand_dims(inputs, -1)
# One hot will unprank only if the final output dimension is not already 1.
if output_mode == "one_hot":
if backend_module.shape(inputs)[-1] != 1:
inputs = backend_module.numpy.expand_dims(inputs, -1)

if len(backend_module.shape(inputs)) > 2:
raise ValueError(
Expand All @@ -139,15 +135,18 @@ def encode_categorical_inputs(
)

binary_output = output_mode in ("multi_hot", "one_hot")
bincounts = backend_module.numpy.bincount(
inputs,
weights=count_weights,
minlength=depth,
)
if binary_output:
one_hot_input = backend_module.nn.one_hot(inputs, depth)
bincounts = backend_module.numpy.where(
backend_module.numpy.any(one_hot_input, axis=-2), 1, 0
if output_mode == "one_hot":
bincounts = backend_module.nn.one_hot(inputs, depth)
elif output_mode == "multi_hot":
one_hot_input = backend_module.nn.one_hot(inputs, depth)
bincounts = backend_module.numpy.where(
backend_module.numpy.any(one_hot_input, axis=-2), 1, 0
)
else:
bincounts = backend_module.numpy.bincount(
inputs,
minlength=depth,
)
bincounts = backend_module.cast(bincounts, dtype)

Expand Down

0 comments on commit 4c3697f

Please sign in to comment.