Wfh/json schema evaluation (langchain-ai#12389)
Co-authored-by: Bagatur <[email protected]>
2 people authored and xieqihui committed Nov 21, 2023
1 parent 39a9a61 commit ff47efe
Showing 10 changed files with 289 additions and 42 deletions.
95 changes: 93 additions & 2 deletions docs/docs/guides/evaluation/string/json.ipynb
@@ -221,7 +221,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 8,
"id": "7a8f3ec5-1cde-4b0e-80cd-ac0ac290d375",
"metadata": {},
"outputs": [
@@ -261,11 +261,102 @@
"print(result)"
]
},
{
"cell_type": "markdown",
"id": "6b15d18e-9b97-434f-905c-70acd4c35aea",
"metadata": {},
"source": [
"## JsonSchemaEvaluator\n",
"\n",
"The `JsonSchemaEvaluator` validates a JSON prediction against a provided JSON schema. If the prediction conforms to the schema, it returns a score of True (indicating no errors). Otherwise, it returns a score of 0 (indicating an error).\n",
"\n",
"### Overview:\n",
"- **Requires Input?**: Yes\n",
"- **Requires Reference?**: Yes (A JSON schema)\n",
"- **Score**: True (No errors) or False (Error occurred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"id": "85afcf33-d2f4-406e-9d8f-15dc0a4772f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'score': True}\n"
]
}
],
"source": [
"from langchain.evaluation import JsonSchemaEvaluator\n",
"\n",
"evaluator = JsonSchemaEvaluator()\n",
"# Equivalently\n",
"# evaluator = load_evaluator(\"json_schema_validation\")\n",
"\n",
"result = evaluator.evaluate_strings(\n",
" prediction='{\"name\": \"John\", \"age\": 30}',\n",
" reference={\n",
" \"type\": \"object\",\n",
" \"properties\": {\"name\": {\"type\": \"string\"}, \"age\": {\"type\": \"integer\"}},\n",
" },\n",
")\n",
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "bb5b89f6-0c87-4335-9091-55fd67a0565f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'score': True}\n"
]
}
],
"source": [
"result = evaluator.evaluate_strings(\n",
" prediction='{\"name\": \"John\", \"age\": 30}',\n",
" reference='{\"type\": \"object\", \"properties\": {\"name\": {\"type\": \"string\"}, \"age\": {\"type\": \"integer\"}}}',\n",
")\n",
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "ff914d24-36bc-482a-a9ba-259cd0dd2a52",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'score': False, 'reasoning': \"<ValidationError: '30 is less than the minimum of 66'>\"}\n"
]
}
],
"source": [
"result = evaluator.evaluate_strings(\n",
" prediction='{\"name\": \"John\", \"age\": 30}',\n",
" reference='{\"type\": \"object\", \"properties\": {\"name\": {\"type\": \"string\"},'\n",
" '\"age\": {\"type\": \"integer\", \"minimum\": 66}}}',\n",
")\n",
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b073f12d-4603-481c-8081-fab1af6bfcfe",
"metadata": {},
"outputs": [],
"source": []
}
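A quick cross-check of the "Equivalently" comment in the notebook cell above: the same evaluator should be obtainable through the `load_evaluator` factory under the "json_schema_validation" alias registered in loading.py below. A minimal sketch, assuming this commit and the jsonschema package are installed:

from langchain.evaluation import load_evaluator

evaluator = load_evaluator("json_schema_validation")
result = evaluator.evaluate_strings(
    prediction='{"name": "John", "age": 30}',
    reference='{"type": "object", "properties": {"age": {"type": "integer"}}}',
)
print(result)  # expected: {'score': True}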
2 changes: 2 additions & 0 deletions libs/langchain/langchain/evaluation/__init__.py
@@ -74,6 +74,7 @@
JsonValidityEvaluator,
)
from langchain.evaluation.parsing.json_distance import JsonEditDistanceEvaluator
from langchain.evaluation.parsing.json_schema import JsonSchemaEvaluator
from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain
from langchain.evaluation.regex_match.base import RegexMatchStringEvaluator
from langchain.evaluation.schema import (
@@ -122,4 +123,5 @@
"JsonValidityEvaluator",
"JsonEqualityEvaluator",
"JsonEditDistanceEvaluator",
"JsonSchemaEvaluator",
]
2 changes: 2 additions & 0 deletions libs/langchain/langchain/evaluation/loading.py
@@ -20,6 +20,7 @@
JsonValidityEvaluator,
)
from langchain.evaluation.parsing.json_distance import JsonEditDistanceEvaluator
from langchain.evaluation.parsing.json_schema import JsonSchemaEvaluator
from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain
from langchain.evaluation.regex_match.base import RegexMatchStringEvaluator
from langchain.evaluation.schema import EvaluatorType, LLMEvalChain, StringEvaluator
@@ -88,6 +89,7 @@ def load_dataset(uri: str) -> List[Dict]:
EvaluatorType.JSON_VALIDITY: JsonValidityEvaluator,
EvaluatorType.JSON_EQUALITY: JsonEqualityEvaluator,
EvaluatorType.JSON_EDIT_DISTANCE: JsonEditDistanceEvaluator,
EvaluatorType.JSON_SCHEMA_VALIDATION: JsonSchemaEvaluator,
EvaluatorType.REGEX_MATCH: RegexMatchStringEvaluator,
EvaluatorType.EXACT_MATCH: ExactMatchStringEvaluator,
}
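With this registry entry in place, the evaluator should also resolve from the `EvaluatorType` enum rather than the string alias. A sketch under the same assumptions as above:

from langchain.evaluation import EvaluatorType, load_evaluator

evaluator = load_evaluator(EvaluatorType.JSON_SCHEMA_VALIDATION)
print(evaluator.evaluation_name)  # "json_schema_validation"
print(evaluator.requires_reference)  # True: a JSON schema must be supplied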
4 changes: 2 additions & 2 deletions libs/langchain/langchain/evaluation/parsing/base.py
@@ -51,7 +51,7 @@ def _evaluate_strings(
prediction: str,
input: Optional[str] = None,
reference: Optional[str] = None,
-        **kwargs: Any
+        **kwargs: Any,
) -> dict:
"""Evaluate the prediction string.
@@ -134,7 +134,7 @@ def _evaluate_strings(
prediction: str,
input: Optional[str] = None,
reference: Optional[str] = None,
-        **kwargs: Any
+        **kwargs: Any,
) -> dict:
"""Evaluate the prediction string.
8 changes: 5 additions & 3 deletions libs/langchain/langchain/evaluation/parsing/json_distance.py
@@ -38,7 +38,7 @@ def __init__(
self,
string_distance: Optional[Callable[[str, str], float]] = None,
canonicalize: Optional[Callable[[Any], Any]] = None,
-        **kwargs: Any
+        **kwargs: Any,
) -> None:
super().__init__()
if string_distance is not None:
@@ -58,7 +58,9 @@ def __init__(
self._canonicalize = canonicalize
else:
self._canonicalize = lambda x: json.dumps(
-                x, separators=(",", ":"), sort_keys=True  # eliminate whitespace
+                x,
+                separators=(",", ":"),
+                sort_keys=True,  # eliminate whitespace
)

@property
Expand All @@ -83,7 +85,7 @@ def _evaluate_strings(
prediction: str,
input: Optional[str] = None,
reference: Optional[str] = None,
-        **kwargs: Any
+        **kwargs: Any,
) -> dict:
parsed = self._canonicalize(self._parse_json(prediction))
label = self._canonicalize(self._parse_json(reference))
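The reformatted default above is behavior-preserving: it serializes with compact separators and sorted keys, so two JSON documents that differ only in whitespace or key order canonicalize to the same string before the edit distance is computed. A standalone sketch of that canonicalization:

import json

def canonicalize(x):
    # Compact separators and sorted keys, matching the default above.
    return json.dumps(x, separators=(",", ":"), sort_keys=True)

print(canonicalize({"b": 1, "a": 2}))  # {"a":2,"b":1}
print(canonicalize({"a": 2, "b": 1}))  # {"a":2,"b":1} (identical after canonicalization)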
95 changes: 95 additions & 0 deletions libs/langchain/langchain/evaluation/parsing/json_schema.py
@@ -0,0 +1,95 @@
from typing import Any, Union

from langchain.evaluation.schema import StringEvaluator
from langchain.output_parsers.json import parse_json_markdown


class JsonSchemaEvaluator(StringEvaluator):
"""An evaluator that validates a JSON prediction against a JSON schema reference.
This evaluator checks if a given JSON prediction conforms to the provided JSON schema.
If the prediction is valid, the score is True (no errors). Otherwise, the score is False (error occurred).
Attributes:
requires_input (bool): Whether the evaluator requires input.
requires_reference (bool): Whether the evaluator requires reference.
evaluation_name (str): The name of the evaluation.
Examples:
evaluator = JsonSchemaEvaluator()
result = evaluator.evaluate_strings(
prediction='{"name": "John", "age": 30}',
reference={
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
}
}
)
assert result["score"] is not None
""" # noqa: E501

def __init__(self, **kwargs: Any) -> None:
"""Initializes the JsonSchemaEvaluator.
Args:
**kwargs: Additional keyword arguments.
Raises:
ImportError: If the jsonschema package is not installed.
"""
super().__init__()
try:
import jsonschema # noqa: F401
except ImportError:
raise ImportError(
"The JsonSchemaEvaluator requires the jsonschema package."
" Please install it with `pip install jsonschema`."
)

@property
def requires_input(self) -> bool:
"""Returns whether the evaluator requires input."""
return False

@property
def requires_reference(self) -> bool:
"""Returns whether the evaluator requires reference."""
return True

@property
def evaluation_name(self) -> str:
"""Returns the name of the evaluation."""
return "json_schema_validation"

def _parse_json(self, node: Any) -> Union[dict, list, None, float, bool, int, str]:
if isinstance(node, str):
return parse_json_markdown(node)
elif hasattr(node, "schema") and callable(getattr(node, "schema")):
# Pydantic model
return getattr(node, "schema")()
return node

def _validate(self, prediction: Any, schema: Any) -> dict:
from jsonschema import ValidationError, validate # noqa: F401

try:
validate(instance=prediction, schema=schema)
return {
"score": True,
}
except ValidationError as e:
return {"score": False, "reasoning": repr(e)}

def _evaluate_strings(
self,
prediction: Union[str, Any],
input: Union[str, Any] = None,
reference: Union[str, Any] = None,
**kwargs: Any,
) -> dict:
parsed_prediction = self._parse_json(prediction)
schema = self._parse_json(reference)
return self._validate(parsed_prediction, schema)
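Because `_parse_json` accepts any object with a callable `schema` attribute, a Pydantic model class should work directly as the `reference`. A minimal sketch, assuming a pydantic v1-style model whose `.schema()` classmethod returns a JSON schema dict; the `Person` model here is a hypothetical example:

from pydantic import BaseModel

from langchain.evaluation import JsonSchemaEvaluator


class Person(BaseModel):  # hypothetical example model
    name: str
    age: int


evaluator = JsonSchemaEvaluator()
result = evaluator.evaluate_strings(
    prediction='{"name": "John", "age": 30}',
    reference=Person,  # _parse_json calls Person.schema() to obtain the schema
)
print(result)  # expected: {'score': True}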
16 changes: 9 additions & 7 deletions libs/langchain/langchain/evaluation/schema.py
@@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
from enum import Enum
from functools import partial
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, Optional, Sequence, Tuple, Union
from warnings import warn

from langchain.chains.base import Chain
@@ -66,6 +66,8 @@ class EvaluatorType(str, Enum):
"""Check if a prediction is equal to a reference JSON."""
JSON_EDIT_DISTANCE = "json_edit_distance"
"""Compute the edit distance between two JSON strings after canonicalization."""
JSON_SCHEMA_VALIDATION = "json_schema_validation"
"""Check if a prediction is valid JSON according to a JSON schema."""


class LLMEvalChain(Chain):
@@ -144,9 +146,9 @@ def requires_reference(self) -> bool:
def _evaluate_strings(
self,
*,
-        prediction: str,
-        reference: Optional[str] = None,
-        input: Optional[str] = None,
+        prediction: Union[str, Any],
+        reference: Optional[Union[str, Any]] = None,
+        input: Optional[Union[str, Any]] = None,
**kwargs: Any,
) -> dict:
"""Evaluate Chain or LLM output, based on optional input and label.
@@ -167,9 +169,9 @@ def _evaluate_strings(
async def _aevaluate_strings(
self,
*,
-        prediction: str,
-        reference: Optional[str] = None,
-        input: Optional[str] = None,
+        prediction: Union[str, Any],
+        reference: Optional[Union[str, Any]] = None,
+        input: Optional[Union[str, Any]] = None,
**kwargs: Any,
) -> dict:
"""Asynchronously evaluate Chain or LLM output, based on optional input and label.