Skip to content

Commit

Permalink
Fix MetricCollection when input are metrics that return dicts with …
Browse files Browse the repository at this point in the history
…same keywords (#2027)

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
(cherry picked from commit 58fc9c6)
  • Loading branch information
SkafteNicki authored and Borda committed Aug 28, 2023
1 parent 379e13f commit 53c65d5
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 17 deletions.
9 changes: 6 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed bug in `PearsonCorrCoef` is updated on single samples at a time ([#2019](https://github.com/Lightning-AI/torchmetrics/pull/2019)
- Fixed bug in `PearsonCorrCoef` is updated on single samples at a time ([#2019](https://github.com/Lightning-AI/torchmetrics/pull/2019))


- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017)
- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017))


- Fixed bug in `MetricCollection` when used with multiple metrics that return dicts with same keys ([#2027](https://github.com/Lightning-AI/torchmetrics/pull/2027))


- Fixed bug in detection intersection metrics when `class_metrics=True` resulting in wrong values ([#1924](https://github.com/Lightning-AI/torchmetrics/pull/1924))


- Fixed missing attributes `higher_is_better`, `is_differentiable` for some metrics ([#2028](https://github.com/Lightning-AI/torchmetrics/pull/2028)
- Fixed missing attributes `higher_is_better`, `is_differentiable` for some metrics ([#2028](https://github.com/Lightning-AI/torchmetrics/pull/2028))


## [1.1.0] - 2023-08-22
Expand Down
18 changes: 14 additions & 4 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import allclose
from torchmetrics.utilities.data import _flatten_dict, allclose
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

Expand Down Expand Up @@ -334,17 +334,27 @@ def _compute_and_reduce(
res = m(*args, **m._filter_kwargs(**kwargs))
else:
raise ValueError("method_name should be either 'compute' or 'forward', but got {method_name}")
result[k] = res

_, duplicates = _flatten_dict(result)

flattened_results = {}
for k, res in result.items():
if isinstance(res, dict):
for key, v in res.items():
# if duplicates of keys we need to add unique prefix to each key
if duplicates:
stripped_k = k.replace(getattr(m, "prefix", ""), "")
stripped_k = stripped_k.replace(getattr(m, "postfix", ""), "")
key = f"{stripped_k}_{key}"
if hasattr(m, "prefix") and m.prefix is not None:
key = f"{m.prefix}{key}"
if hasattr(m, "postfix") and m.postfix is not None:
key = f"{key}{m.postfix}"
result[key] = v
flattened_results[key] = v
else:
result[k] = res
return {self._set_name(k): v for k, v in result.items()}
flattened_results[k] = res
return {self._set_name(k): v for k, v in flattened_results.items()}

def reset(self) -> None:
"""Call reset for each metric sequentially."""
Expand Down
13 changes: 9 additions & 4 deletions src/torchmetrics/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
from lightning_utilities import apply_to_collection
Expand Down Expand Up @@ -60,16 +60,21 @@ def _flatten(x: Sequence) -> list:
return [item for sublist in x for item in sublist]


def _flatten_dict(x: Dict) -> Dict:
"""Flatten dict of dicts into single dict."""
def _flatten_dict(x: Dict) -> Tuple[Dict, bool]:
"""Flatten dict of dicts into single dict and checking for duplicates in keys along the way."""
new_dict = {}
duplicates = False
for key, value in x.items():
if isinstance(value, dict):
for k, v in value.items():
if k in new_dict:
duplicates = True
new_dict[k] = v
else:
if key in new_dict:
duplicates = True
new_dict[key] = value
return new_dict
return new_dict, duplicates


def to_onehot(
Expand Down
30 changes: 26 additions & 4 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,11 +614,33 @@ def test_nested_collections(input_collections):
assert "valmetrics/micro_MulticlassPrecision" in val


def test_double_nested_collections():
@pytest.mark.parametrize(
("base_metrics", "expected"),
[
(
DummyMetricMultiOutputDict(),
(
"prefix2_prefix1_output1_postfix1_postfix2",
"prefix2_prefix1_output2_postfix1_postfix2",
),
),
(
{"metric1": DummyMetricMultiOutputDict(), "metric2": DummyMetricMultiOutputDict()},
(
"prefix2_prefix1_metric1_output1_postfix1_postfix2",
"prefix2_prefix1_metric1_output2_postfix1_postfix2",
"prefix2_prefix1_metric2_output1_postfix1_postfix2",
"prefix2_prefix1_metric2_output2_postfix1_postfix2",
),
),
],
)
def test_double_nested_collections(base_metrics, expected):
"""Test that double nested collections gets flattened to a single collection."""
collection1 = MetricCollection([DummyMetricMultiOutputDict()], prefix="prefix1_", postfix="_postfix1")
collection1 = MetricCollection(base_metrics, prefix="prefix1_", postfix="_postfix1")
collection2 = MetricCollection([collection1], prefix="prefix2_", postfix="_postfix2")
x = torch.randn(10).sum()
val = collection2(x)
assert "prefix2_prefix1_output1_postfix1_postfix2" in val
assert "prefix2_prefix1_output2_postfix1_postfix2" in val

for key in val:
assert key in expected
5 changes: 3 additions & 2 deletions tests/unittests/utilities/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ def test_flatten_list():
def test_flatten_dict():
"""Check that _flatten_dict utility function works as expected."""
inp = {"a": {"b": 1, "c": 2}, "d": 3}
out = _flatten_dict(inp)
assert out == {"b": 1, "c": 2, "d": 3}
out_dict, out_dup = _flatten_dict(inp)
assert out_dict == {"b": 1, "c": 2, "d": 3}
assert out_dup is False


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")
Expand Down

0 comments on commit 53c65d5

Please sign in to comment.