Skip to content

Commit

Permalink
Fix for keras.metrics.Mean when the sum of weights is negative. (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
hertschuh authored Feb 9, 2024
1 parent 830bea6 commit 3cd7be1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
9 changes: 5 additions & 4 deletions keras/metrics/reduction_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,11 @@ def reset_state(self):
self.count.assign(0)

def result(self):
return self.total / (
ops.maximum(
ops.cast(self.count, dtype=self.dtype), backend.epsilon()
)
count = ops.cast(self.count, dtype=self.dtype)
return (
ops.sign(count)
* self.total
/ ops.maximum(ops.abs(count), backend.epsilon())
)


Expand Down
6 changes: 6 additions & 0 deletions keras/metrics/reduction_metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def test_weighted(self):
result = mean_obj.result()
self.assertAllClose(result, 2.0, atol=1e-3)

def test_weighted_negative_weigts(self):
mean_obj = reduction_metrics.Mean(name="mean", dtype="float32")
mean_obj.update_state([1, 3, 5, 7], sample_weight=[-1, -1, 0, 0])
result = mean_obj.result()
self.assertAllClose(result, 2.0, atol=1e-3)

def test_weighted_nd(self):
mean_obj = reduction_metrics.Mean(name="mean", dtype="float32")
mean_obj.update_state([[1, 3], [5, 7]], sample_weight=[[1, 1], [1, 0]])
Expand Down

0 comments on commit 3cd7be1

Please sign in to comment.