-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add calculate_metrics, MetricsResult, Exact Match * Add additional tests for metric calculation * Add release notes * Add docstring for Exact Match metric * Remove Exact Match Implementation * Update release notes * Remove unnecessary metrics implementation * Simplify logic to run supported metrics * Add some evaluation tests * Fix linting --------- Co-authored-by: Silvano Cerza <[email protected]> Co-authored-by: Silvano Cerza <[email protected]>
- Loading branch information
1 parent
e6d6ce1
commit 374a937
Showing
7 changed files
with
165 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from haystack.evaluation.eval import EvaluationResult, eval | ||
from haystack.evaluation.metrics import Metric, MetricsResult | ||
|
||
__all__ = ["eval", "EvaluationResult"] | ||
__all__ = ["eval", "EvaluationResult", "Metric", "MetricsResult"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import json | ||
from enum import Enum | ||
from pathlib import Path | ||
from typing import Union | ||
|
||
|
||
class Metric(Enum):
    """
    Enumeration of the standard evaluation metrics.

    Each member's value is the human-readable name of the metric.
    """

    # Metrics declared first: recall, MRR and MAP.
    RECALL = "Recall"
    MRR = "Mean Reciprocal Rank"
    MAP = "Mean Average Precision"

    # Metrics declared last: F1, exact match and SAS.
    F1 = "F1"
    EM = "Exact Match"
    SAS = "Semantic Answer Similarity"
|
||
|
||
class MetricsResult(dict):
    """
    Stores the metric values computed during the evaluation.

    This is a plain ``dict`` subclass, so the stored metrics are read and
    written with normal mapping operations.
    """

    def save(self, file: Union[str, Path]) -> None:
        """
        Save the metrics stored in the MetricsResult to a json file.

        The file is created or overwritten and the content is written as JSON
        indented by 4 spaces.

        :param file: The file path or file name to save the data.
        """
        # Pin the encoding so the output does not depend on the platform's
        # locale default.
        with open(file, "w", encoding="utf-8") as outfile:
            json.dump(self, outfile, indent=4)
6 changes: 6 additions & 0 deletions
6
releasenotes/notes/add-calculate-metrics-metricsresults-03bf27ce8b16cff5.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
--- | ||
features: | ||
- | | ||
Adds `calculate_metrics()` method to `EvaluationResult` for computing evaluation metrics. | ||
Adds `Metric` enum listing the available metrics. | ||
Adds `MetricsResult` class to store the metric values computed during the evaluation. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from unittest.mock import MagicMock | ||
|
||
from haystack.core.pipeline import Pipeline | ||
from haystack.evaluation.eval import EvaluationResult | ||
from haystack.evaluation.metrics import Metric | ||
|
||
|
||
class TestEvaluationResult:
    """Unit tests for ``EvaluationResult`` construction and metric dispatch."""

    def test_init(self):
        """The constructor stores its arguments on the instance unchanged."""
        runnable = Pipeline()
        result = EvaluationResult(runnable=runnable, inputs=[], outputs=[], expected_outputs=[])

        assert result.runnable == runnable
        assert result.inputs == []
        assert result.outputs == []
        assert result.expected_outputs == []

    def test_supported_metrics_contains_all_metrics(self):
        """Every ``Metric`` member has an entry in ``_supported_metrics``."""
        runnable = Pipeline()
        result = EvaluationResult(runnable=runnable, inputs=[], outputs=[], expected_outputs=[])

        supported_metrics = [m.name for m in result._supported_metrics.keys()]
        all_metric_names = [m.name for m in Metric]
        assert supported_metrics == all_metric_names

    def test_calculate_metrics_with_supported_metric(self):
        """A ``Metric`` member is dispatched to its registered implementation."""
        runnable = Pipeline()
        result = EvaluationResult(runnable=runnable, inputs=[], outputs=[], expected_outputs=[])
        recall_mock = MagicMock()
        result._supported_metrics[Metric.RECALL] = recall_mock
        result.calculate_metrics(metric=Metric.RECALL)

        # `assert mock.called_once_with()` is always truthy: the attribute is
        # auto-created as a child mock. Use the real assertion method instead.
        recall_mock.assert_called_once()

    def test_calculate_metrics_with_non_supported_metric(self):
        """A custom callable metric is invoked with the forwarded keyword arguments."""
        runnable = Pipeline()
        result = EvaluationResult(runnable=runnable, inputs=[], outputs=[], expected_outputs=[])

        unsupported_metric = MagicMock()

        result.calculate_metrics(metric=unsupported_metric, some_argument="some_value")

        unsupported_metric.assert_called_once()
        # Only the forwarded keyword argument is checked; the positional
        # arguments are left unpinned on purpose.
        _, call_kwargs = unsupported_metric.call_args
        assert call_kwargs == {"some_argument": "some_value"}