Skip to content

Commit

Permalink
docs: review docstrings in haystack.components.validators (#7238)
Browse files Browse the repository at this point in the history
* chore: make private

* docs: review and normalize docstrings

* docs: fix format and unused import
  • Loading branch information
wochinge authored Feb 28, 2024
1 parent c4b54bc commit e5f0e24
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 28 deletions.
58 changes: 32 additions & 26 deletions haystack/components/validators/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
@component
class JsonSchemaValidator:
"""
JsonSchemaValidator validates JSON content of ChatMessage against a specified JSON schema.
Validates JSON content of `ChatMessage` against a specified [JSON Schema](https://json-schema.org/).
If JSON content of a message conforms to the provided schema, the message is passed along the "validated" output.
If the JSON content does not conform to the schema, the message is passed along the "validation_error" output.
In the latter case, the error message is constructed using the provided error_template or a default template.
In the latter case, the error message is constructed using the provided `error_template` or a default template.
These error ChatMessages can be used by LLMs in Haystack 2.x recovery loops.
Here is a small example of how to use this component in a pipeline implementing schema validation recovery loop:
Usage example:
```python
from typing import List
Expand Down Expand Up @@ -52,15 +52,13 @@ def run(self, messages: List[ChatMessage]) -> dict:
p.connect("llm.replies", "schema_validator.messages")
p.connect("schema_validator.validation_error", "mx_for_llm")
result = p.run(
data={"message_producer": {"messages":[ChatMessage.from_user("Generate JSON for person with name 'John' and age 30")]},
"schema_validator": {"json_schema": {"type": "object",
"properties": {"name": {"type": "string"},
"age": {"type": "integer"}}}}})
print(result)
>> {'schema_validator': {'validated': [ChatMessage(content='\n{\n "name": "John",\n "age": 30\n}',
>> {'schema_validator': {'validated': [ChatMessage(content='\\n{\\n "name": "John",\\n "age": 30\\n}',
role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-1106-preview', 'index': 0,
'finish_reason': 'stop', 'usage': {'completion_tokens': 17, 'prompt_tokens': 20, 'total_tokens': 37}})]}}
```
Expand All @@ -79,9 +77,8 @@ def run(self, messages: List[ChatMessage]) -> dict:

def __init__(self, json_schema: Optional[Dict[str, Any]] = None, error_template: Optional[str] = None):
"""
Initializes a new JsonSchemaValidator instance.
:param json_schema: A dictionary representing the JSON schema against which the messages' content is validated.
:param json_schema: A dictionary representing the [JSON schema](https://json-schema.org/) against which
the messages' content is validated.
:param error_template: A custom template string for formatting the error message in case of validation failure.
"""
jsonschema_import.check()
Expand All @@ -94,16 +91,25 @@ def run(
messages: List[ChatMessage],
json_schema: Optional[Dict[str, Any]] = None,
error_template: Optional[str] = None,
):
) -> Dict[str, List[ChatMessage]]:
"""
Checks if the last message and its content field conforms to json_schema. If it does, the message is passed
along the "validated" output. If it does not, the message is passed along the "validation_error" output.
Validates the last of the provided messages against the specified JSON schema.
If the message conforms to the schema, it is passed along the "validated" output. If it does not, it is passed
along the "validation_error" output.
:param messages: A list of ChatMessage instances to be validated. The last message in this list is the one
that is validated.
:param json_schema: A dictionary representing the JSON schema against which the messages' content is validated.
:param error_template: A custom template string for formatting the error message in case of validation
failure, by default None.
that is validated.
:param json_schema: A dictionary representing the [JSON schema](https://json-schema.org/)
against which the messages' content is validated. If not provided, the schema from the component init
is used.
:param error_template: A custom template string for formatting the error message in case of validation
failure. If not provided, the `error_template` from the component init is used.
:return: A dictionary with the following keys:
- "validated": A list of messages if the last message is valid.
- "validation_error": A list of messages if the last message is invalid.
:raises ValueError: If no JSON schema is provided or if the message content is not a dictionary or a list of
dictionaries.
"""
last_message = messages[-1]
last_message_content = json.loads(last_message.content)
Expand All @@ -116,8 +122,8 @@ def run(

# fc payload is json object but subtree `parameters` is string - we need to convert to json object
# we need complete json to validate it against schema
last_message_json = self.recursive_json_to_object(last_message_content)
using_openai_schema: bool = self.is_openai_function_calling_schema(json_schema)
last_message_json = self._recursive_json_to_object(last_message_content)
using_openai_schema: bool = self._is_openai_function_calling_schema(json_schema)
if using_openai_schema:
validation_schema = json_schema["parameters"]
else:
Expand All @@ -137,13 +143,13 @@ def run(

error_template = error_template or self.default_error_template

recovery_prompt = self.construct_error_recovery_message(
recovery_prompt = self._construct_error_recovery_message(
error_template, str(e), error_path, error_schema_path, validation_schema
)
complete_message_list = [ChatMessage.from_user(recovery_prompt)] + messages
return {"validation_error": complete_message_list}

def construct_error_recovery_message(
def _construct_error_recovery_message(
self,
error_template: str,
error_message: str,
Expand All @@ -169,16 +175,16 @@ def construct_error_recovery_message(
json_schema=json_schema,
)

def is_openai_function_calling_schema(self, json_schema: Dict[str, Any]) -> bool:
def _is_openai_function_calling_schema(self, json_schema: Dict[str, Any]) -> bool:
"""
Checks if the provided schema is a valid OpenAI function calling schema.
:param json_schema: The JSON schema to check
:return: True if the schema is a valid OpenAI function calling schema; otherwise, False.
:return: `True` if the schema is a valid OpenAI function calling schema; otherwise, `False`.
"""
return all(key in json_schema for key in ["name", "description", "parameters"])

def recursive_json_to_object(self, data: Any) -> Any:
def _recursive_json_to_object(self, data: Any) -> Any:
"""
Recursively traverses a data structure (dictionary or list), converting any string values
that are valid JSON objects into dictionary objects, and returns a new data structure.
Expand All @@ -187,7 +193,7 @@ def recursive_json_to_object(self, data: Any) -> Any:
:return: A new data structure with JSON strings converted to dictionary objects.
"""
if isinstance(data, list):
return [self.recursive_json_to_object(item) for item in data]
return [self._recursive_json_to_object(item) for item in data]

if isinstance(data, dict):
new_dict = {}
Expand All @@ -196,14 +202,14 @@ def recursive_json_to_object(self, data: Any) -> Any:
try:
json_value = json.loads(value)
new_dict[key] = (
self.recursive_json_to_object(json_value)
self._recursive_json_to_object(json_value)
if isinstance(json_value, (dict, list))
else json_value
)
except json.JSONDecodeError:
new_dict[key] = value
elif isinstance(value, dict):
new_dict[key] = self.recursive_json_to_object(value)
new_dict[key] = self._recursive_json_to_object(value)
else:
new_dict[key] = value
return new_dict
Expand Down
4 changes: 2 additions & 2 deletions test/components/validators/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_recursive_json_to_object(self, genuine_fc_message):

# but _recursive_json_to_object converts the string to a json object
validator = JsonSchemaValidator()
result = validator.recursive_json_to_object({"key": genuine_fc_message})
result = validator._recursive_json_to_object({"key": genuine_fc_message})

# we need this recursive json conversion to validate the message
assert result["key"][0]["function"]["arguments"]["basehead"] == "main...amzn_chat"
Expand Down Expand Up @@ -154,7 +154,7 @@ def test_construct_custom_error_recovery_message(self):
"{json_schema}\n"
)

recovery_message = validator.construct_error_recovery_message(
recovery_message = validator._construct_error_recovery_message(
new_error_template, "Error message", "Error path", "Error schema path", {"type": "object"}
)

Expand Down

0 comments on commit e5f0e24

Please sign in to comment.