Skip to content

Commit

Permalink
add target_context as strings instead of list of strings
Browse files Browse the repository at this point in the history
  • Loading branch information
oyangz committed May 20, 2024
1 parent 48a7bf1 commit 079bc9b
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 115 deletions.
71 changes: 22 additions & 49 deletions src/fmeval/data_loaders/json_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,12 @@ def _parse_column(args: ColumnParseArguments) -> Optional[Union[Any, List[Any]]]
return result

@staticmethod
def _validate_jmespath_result(result: Union[Any, List[Any], List[List[Any]]], args: ColumnParseArguments) -> None:
def _validate_jmespath_result(result: Union[Any, List[Any]], args: ColumnParseArguments) -> None:
"""Validates that the JMESPath result is as expected.
For dataset column TARGET_CONTEXT, if `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is
expected to be a 2D array. If MIME_TYPE_JSON_LINES, then `result` is expected to be a 1D array (list).
For all other dataset columns, if `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is expected to be
a 1D array (list). If MIME_TYPE_JSON_LINES, then `result` is expected to be a single scalar value.
If `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is expected
to be a 1D array (list). If MIME_TYPE_JSON_LINES, then `result` is expected
to be a single scalar value.
:param result: JMESPath query result to be validated.
:param args: See ColumnParseArguments docstring.
Expand All @@ -177,38 +175,24 @@ def _validate_jmespath_result(result: Union[Any, List[Any], List[List[Any]]], ar
f"the {args.column.value.name} column of dataset `{args.dataset_name}`, but found at least "
"one value that is None.",
)
if args.column.value.name == DatasetColumns.TARGET_CONTEXT.value.name:
require(
all(isinstance(x, list) for x in result),
f"Expected a 2D array using JMESPath '{args.jmespath_parser.expression}' on dataset "
f"`{args.dataset_name}` but found at least one non-list object.",
)
else:
require(
all(not isinstance(x, list) for x in result),
f"Expected a 1D array using JMESPath '{args.jmespath_parser.expression}' on dataset "
f"`{args.dataset_name}`, where each element of the array is a sample's {args.column.value.name}, "
f"but found at least one nested array.",
)
require(
all(not isinstance(x, list) for x in result),
f"Expected a 1D array using JMESPath '{args.jmespath_parser.expression}' on dataset "
f"`{args.dataset_name}`, where each element of the array is a sample's {args.column.value.name}, "
f"but found at least one nested array.",
)
elif args.dataset_mime_type == MIME_TYPE_JSONLINES:
require(
result is not None,
f"Found no values using {args.column.value.name} JMESPath '{args.jmespath_parser.expression}' "
f"on dataset `{args.dataset_name}`.",
)
if args.column.value.name == DatasetColumns.TARGET_CONTEXT.value.name:
require(
isinstance(result, list),
f"Expected to find a List using JMESPath '{args.jmespath_parser.expression}' on a dataset line in "
f"`{args.dataset_name}`, but found a non-list object instead.",
)
else:
require(
not isinstance(result, list),
f"Expected to find a single value using {args.column.value.name} JMESPath "
f"'{args.jmespath_parser.expression}' on a dataset line in "
f"dataset `{args.dataset_name}`, but found a list instead.",
)
require(
not isinstance(result, list),
f"Expected to find a single value using {args.column.value.name} JMESPath "
f"'{args.jmespath_parser.expression}' on a dataset line in "
f"dataset `{args.dataset_name}`, but found a list instead.",
)
else: # pragma: no cover
raise EvalAlgorithmInternalError(
f"args.dataset_mime_type is {args.dataset_mime_type}, but only JSON " "and JSON Lines are supported."
Expand All @@ -233,37 +217,26 @@ def _validate_parsed_columns_lengths(parsed_columns_dict: Dict[str, List[Any]]):
)

@staticmethod
def _cast_to_string(
result: Union[Any, List[Any], List[List[Any]]], args: ColumnParseArguments
) -> Union[str, List[str], List[List[str]]]:
def _cast_to_string(result: Union[Any, List[Any]], args: ColumnParseArguments) -> Union[str, List[str]]:
"""
Casts the contents of `result` to string(s), raising an error if casting fails.
It is extremely unlikely that the str() operation should fail; this basically
only happens if the object has explicitly overwritten the __str__ method to raise
an exception.
For dataset column TARGET_CONTEXT, if `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is
expected to be a 2D array. If MIME_TYPE_JSON_LINES, then `result` is expected to be a 1D array (list).
For all other dataset columns, if `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is expected to be
a 1D array (list). If MIME_TYPE_JSON_LINES, then `result` is expected to be a single scalar value.
If `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is expected
to be a 1D array (list) of objects. If MIME_TYPE_JSON_LINES, then `result`
is expected to be a single object.
:param result: JMESPath query result to be casted.
:param args: See ColumnParseArguments docstring.
:returns: `result` casted to a string or list of strings.
"""
try:
if args.dataset_mime_type == MIME_TYPE_JSON:
if args.column.value.name == DatasetColumns.TARGET_CONTEXT.value.name:
return [[str(x) for x in sample] for sample in result]
else:
return [str(x) for x in result]
return [str(x) for x in result]
elif args.dataset_mime_type == MIME_TYPE_JSONLINES:
return (
[str(x) for x in result]
if args.column.value.name == DatasetColumns.TARGET_CONTEXT.value.name
else str(result)
)
return str(result)
else:
raise EvalAlgorithmInternalError( # pragma: no cover
f"args.dataset_mime_type is {args.dataset_mime_type}, but only JSON and JSON Lines are supported."
Expand Down
61 changes: 20 additions & 41 deletions test/unit/data_loaders/test_json_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
CustomJSONDatasource,
)
from fmeval.data_loaders.util import DataConfig
from typing import Any, Dict, List, NamedTuple, Optional, Union
from typing import Any, Dict, List, NamedTuple, Optional
from fmeval.constants import (
DatasetColumns,
MIME_TYPE_JSON,
Expand Down Expand Up @@ -45,7 +45,7 @@ def create_temp_jsonlines_data_file_from_input_dataset(path: pathlib.Path, input

class TestJsonDataLoader:
class TestCaseReadDataset(NamedTuple):
input_dataset: Union[Dict[str, Any], List[Dict[str, Any]]]
input_dataset: Dict[str, Any]
expected_dataset: List[Dict[str, Any]]
dataset_mime_type: str
model_input_jmespath: Optional[str] = None
Expand All @@ -72,88 +72,67 @@ class TestCaseReadDataset(NamedTuple):
# containing heterogeneous lists.
TestCaseReadDataset(
input_dataset={
"row_1": ["a", True, False, 0],
"row_2": ["b", False, False, 1],
"row_3": ["c", False, True, 2],
"row_1": ["a", True, False, 0, "context_a"],
"row_2": ["b", False, False, 1, "context_b"],
"row_3": ["c", False, True, 2, "context_c"],
},
expected_dataset=[
{
DatasetColumns.MODEL_INPUT.value.name: "a",
DatasetColumns.MODEL_OUTPUT.value.name: "True",
DatasetColumns.TARGET_OUTPUT.value.name: "False",
DatasetColumns.CATEGORY.value.name: "0",
DatasetColumns.TARGET_CONTEXT.value.name: "context_a",
},
{
DatasetColumns.MODEL_INPUT.value.name: "b",
DatasetColumns.MODEL_OUTPUT.value.name: "False",
DatasetColumns.TARGET_OUTPUT.value.name: "False",
DatasetColumns.CATEGORY.value.name: "1",
DatasetColumns.TARGET_CONTEXT.value.name: "context_b",
},
{
DatasetColumns.MODEL_INPUT.value.name: "c",
DatasetColumns.MODEL_OUTPUT.value.name: "False",
DatasetColumns.TARGET_OUTPUT.value.name: "True",
DatasetColumns.CATEGORY.value.name: "2",
DatasetColumns.TARGET_CONTEXT.value.name: "context_c",
},
],
dataset_mime_type=MIME_TYPE_JSON,
model_input_jmespath="[row_1[0], row_2[0], row_3[0]]",
model_output_jmespath="[row_1[1], row_2[1], row_3[1]]",
target_output_jmespath="[row_1[2], row_2[2], row_3[2]]",
category_jmespath="[row_1[3], row_2[3], row_3[3]]",
),
TestCaseReadDataset(
input_dataset={
"model_input_col": ["a", "b", "c"],
"target_context": [["a", "b"], ["c", "d"], ["e", "f"]],
},
expected_dataset=[
{DatasetColumns.MODEL_INPUT.value.name: "a", DatasetColumns.TARGET_CONTEXT.value.name: ["a", "b"]},
{DatasetColumns.MODEL_INPUT.value.name: "b", DatasetColumns.TARGET_CONTEXT.value.name: ["c", "d"]},
{DatasetColumns.MODEL_INPUT.value.name: "c", DatasetColumns.TARGET_CONTEXT.value.name: ["e", "f"]},
],
dataset_mime_type=MIME_TYPE_JSON,
model_input_jmespath="model_input_col",
target_context_jmespath="target_context",
),
TestCaseReadDataset(
input_dataset=[
{"input": "a", "output": 3.14},
{"input": "c", "output": 2.718},
{"input": "e", "output": 1.00},
],
expected_dataset=[
{DatasetColumns.MODEL_INPUT.value.name: "a", DatasetColumns.MODEL_OUTPUT.value.name: "3.14"},
{DatasetColumns.MODEL_INPUT.value.name: "c", DatasetColumns.MODEL_OUTPUT.value.name: "2.718"},
{DatasetColumns.MODEL_INPUT.value.name: "e", DatasetColumns.MODEL_OUTPUT.value.name: "1.0"},
],
dataset_mime_type=MIME_TYPE_JSONLINES,
model_input_jmespath="input",
model_output_jmespath="output",
target_context_jmespath="[row_1[4], row_2[4], row_3[4]]",
),
TestCaseReadDataset(
input_dataset=[
{"input": "a", "target_context": ["context 1", "context 2"]},
{"input": "c", "target_context": ["context 3"]},
{"input": "e", "target_context": ["context 4"]},
{"input": "a", "output": 3.14, "context": "1"},
{"input": "c", "output": 2.718, "context": "2"},
{"input": "e", "output": 1.00, "context": "3"},
],
expected_dataset=[
{
DatasetColumns.MODEL_INPUT.value.name: "a",
DatasetColumns.TARGET_CONTEXT.value.name: ["context 1", "context 2"],
DatasetColumns.MODEL_OUTPUT.value.name: "3.14",
DatasetColumns.TARGET_CONTEXT.value.name: "1",
},
{
DatasetColumns.MODEL_INPUT.value.name: "c",
DatasetColumns.TARGET_CONTEXT.value.name: ["context 3"],
DatasetColumns.MODEL_OUTPUT.value.name: "2.718",
DatasetColumns.TARGET_CONTEXT.value.name: "2",
},
{
DatasetColumns.MODEL_INPUT.value.name: "e",
DatasetColumns.TARGET_CONTEXT.value.name: ["context 4"],
DatasetColumns.MODEL_OUTPUT.value.name: "1.0",
DatasetColumns.TARGET_CONTEXT.value.name: "3",
},
],
dataset_mime_type=MIME_TYPE_JSONLINES,
model_input_jmespath="input",
target_context_jmespath="target_context",
model_output_jmespath="output",
target_context_jmespath="context",
),
],
)
Expand Down
44 changes: 19 additions & 25 deletions test/unit/data_loaders/test_json_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,30 +90,25 @@ def test_init_failure(self):
class TestCaseParseColumnFailure(NamedTuple):
result: List[Any]
error_message: str
column: DatasetColumns

@pytest.mark.parametrize(
"result, error_message, column",
"result, error_message",
[
TestCaseParseColumnFailure(
result="not a list",
error_message="Expected to find a non-empty list of samples",
column=DatasetColumns.MODEL_INPUT,
),
TestCaseParseColumnFailure(
result=[1, 2, None],
error_message="Expected an array of non-null values",
column=DatasetColumns.MODEL_INPUT,
),
TestCaseParseColumnFailure(
result=[1, 2, [3], 4], error_message="Expected a 1D array", column=DatasetColumns.MODEL_INPUT
),
TestCaseParseColumnFailure(
result=[[1], 2], error_message="Expected a 2D array", column=DatasetColumns.TARGET_CONTEXT
result=[1, 2, [3], 4],
error_message="Expected a 1D array",
),
],
)
def test_validation_failure_json(self, result, error_message, column):
def test_validation_failure_json(self, result, error_message):
"""
GIVEN a malformed `result` argument (obtained from a JSON dataset)
WHEN _validate_jmespath_result is called
Expand All @@ -123,28 +118,27 @@ def test_validation_failure_json(self, result, error_message, column):
with pytest.raises(EvalAlgorithmClientError, match=error_message):
args = ColumnParseArguments(
jmespath_parser=Mock(),
column=column,
column=Mock(),
dataset={},
dataset_mime_type=MIME_TYPE_JSON,
dataset_name="dataset",
)
JsonParser._validate_jmespath_result(result, args)

@pytest.mark.parametrize(
"result, error_message, column",
"result, error_message",
[
TestCaseParseColumnFailure(
result=None, error_message="Found no values using", column=DatasetColumns.MODEL_INPUT
),
TestCaseParseColumnFailure(
result=[1, 2, 3], error_message="Expected to find a single value", column=DatasetColumns.MODEL_INPUT
result=None,
error_message="Found no values using",
),
TestCaseParseColumnFailure(
result="Not a list", error_message="Expected to find a List", column=DatasetColumns.TARGET_CONTEXT
result=[1, 2, 3],
error_message="Expected to find a single value",
),
],
)
def test_validation_failure_jsonlines(self, result, error_message, column):
def test_validation_failure_jsonlines(self, result, error_message):
"""
GIVEN a malformed `result` argument (obtained from a JSON Lines dataset line)
WHEN _validate_jmespath_result is called
Expand All @@ -154,7 +148,7 @@ def test_validation_failure_jsonlines(self, result, error_message, column):
with pytest.raises(EvalAlgorithmClientError, match=error_message):
args = ColumnParseArguments(
jmespath_parser=Mock(),
column=column,
column=Mock(),
dataset={},
dataset_mime_type=MIME_TYPE_JSONLINES,
dataset_name="dataset",
Expand Down Expand Up @@ -191,7 +185,7 @@ class TestCaseJsonParseDatasetColumns(NamedTuple):
],
},
"model_output": {"sample_1": "positive", "sample_2": "negative"},
"target_context": [["a", "b"], ["c", "d"]],
"target_context": ["a", "b"],
"category": ["category_0", "category_1"],
},
),
Expand All @@ -215,14 +209,14 @@ class TestCaseJsonParseDatasetColumns(NamedTuple):
"model_output_col": "positive",
"target_output_col": "negative",
"category_col": "category_0",
"target_context": ["a", "b"],
"target_context": "a",
},
{
"model_input_col": "B",
"model_output_col": "negative",
"target_output_col": "positive",
"category_col": "category_1",
"target_context": ["c", "d"],
"target_context": "b",
},
],
),
Expand All @@ -232,7 +226,7 @@ class TestCaseJsonParseDatasetColumns(NamedTuple):
def test_json_parse_dataset_columns_success_json(self, mock_logger, config, dataset):
"""
GIVEN valid JMESPath queries that extract model inputs, model outputs,
target outputs, and categories, and a JSON dataset that is represented
target outputs, categories, and target context, and a JSON dataset that is represented
by either a dict or list
WHEN parse_dataset_columns is called
THEN parse_dataset_columns returns correct results
Expand All @@ -241,7 +235,7 @@ def test_json_parse_dataset_columns_success_json(self, mock_logger, config, data
expected_model_outputs = ["positive", "negative"]
expected_target_outputs = ["negative", "positive"]
expected_categories = ["category_0", "category_1"]
expected_target_context = [["a", "b"], ["c", "d"]]
expected_target_context = ["a", "b"]

parser = JsonParser(config)
cols = parser.parse_dataset_columns(dataset=dataset, dataset_mime_type=MIME_TYPE_JSON, dataset_name="dataset")
Expand Down Expand Up @@ -294,14 +288,14 @@ def test_parse_dataset_columns_success_jsonlines(self, mock_logger):
expected_model_output = "positive"
expected_target_output = "negative"
expected_category = "Red"
expected_target_context = ["context 1", "context 2"]
expected_target_context = "context"

dataset_line = {
"input": "A",
"output": "positive",
"target": "negative",
"category": "Red",
"target_context": ["context 1", "context 2"],
"target_context": "context",
}
cols = parser.parse_dataset_columns(
dataset=dataset_line, dataset_mime_type=MIME_TYPE_JSONLINES, dataset_name="dataset_line"
Expand Down

0 comments on commit 079bc9b

Please sign in to comment.