diff --git a/lmformatenforcer/jsonschemaparser.py b/lmformatenforcer/jsonschemaparser.py index d42f0f4..c31f9d1 100644 --- a/lmformatenforcer/jsonschemaparser.py +++ b/lmformatenforcer/jsonschemaparser.py @@ -172,6 +172,11 @@ def get_parser( if value_schema.anyOf: parsers = [get_parser(parsing_state, schema) for schema in value_schema.anyOf] return UnionParser(parsers) + if value_schema.allOf: + merged_schema = value_schema.allOf[0] + for schema in value_schema.allOf[1:]: + merged_schema = _merge_object_schemas(merged_schema, schema) + return get_parser(parsing_state, merged_schema) if value_schema.extras and 'const' in value_schema.extras: allowed_value = value_schema.extras['const'] is_string = type(allowed_value) == str diff --git a/tests/test_jsonschemaparser.py b/tests/test_jsonschemaparser.py index 3357809..78d6a8d 100644 --- a/tests/test_jsonschemaparser.py +++ b/tests/test_jsonschemaparser.py @@ -243,6 +243,53 @@ def test_any_json_object(): _test_json_schema_parsing_with_string('{"a": 1, "b": 2.2, "c": "c", "d": [1,2,3, null], "e": {"ee": 2}}', None, True) _test_json_schema_parsing_with_string("true", None, True) _test_json_schema_parsing_with_string('"str"', None, True) + + +def test_allof(): + # Define a schema that includes allOf + allof_schema = { + "type": "object", + "allOf": [ + { + "type": "object", + "properties": { + "num": { + "type": "number" + } + }, + "required": ["num"] + }, + { + "type": "object", + "properties": { + "str": { + "type": "string" + } + }, + "required": ["str"] + } + ] + } + + # Valid cases + valid_test_strings = [ + '{"num": 123, "str": "test"}', + '{"num": 0, "str": ""}' + ] + + # Invalid cases + invalid_test_strings = [ + '{"num": 123}', # Missing 'str' + '{"str": "test"}', # Missing 'num' + '{"num": "123", "str": "test"}', # Invalid type for 'num' + '{"num": 123, "str": 456}' # Invalid type for 'str' + ] + + for test_string in valid_test_strings: + _test_json_schema_parsing_with_string(test_string, allof_schema, True) + + for test_string in invalid_test_strings: + _test_json_schema_parsing_with_string(test_string, allof_schema, False) def test_long_json_object():