-
Notifications
You must be signed in to change notification settings - Fork 4.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
CDK: Add schema inferrer class (#20941)
* fix stuff * Update schema_inferrer.py * Update schema_inferrer.py * bump version * review comments * code style * fix formatting * improve tests
- Loading branch information
Joe Reuter
authored
Jan 6, 2023
1 parent
fc05f65
commit e2547ff
Showing
5 changed files
with
165 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
# | ||
# Copyright (c) 2022 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
from .schema_inferrer import SchemaInferrer | ||
from .traced_exception import AirbyteTracedException | ||
|
||
__all__ = ["AirbyteTracedException"] | ||
__all__ = ["AirbyteTracedException", "SchemaInferrer"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# | ||
# Copyright (c) 2022 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
from collections import defaultdict | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from airbyte_cdk.models import AirbyteRecordMessage | ||
from genson import SchemaBuilder | ||
from genson.schema.strategies.object import Object | ||
|
||
|
||
class NoRequiredObj(Object): | ||
""" | ||
This class has Object behaviour, but it does not generate "required[]" fields | ||
every time it parses object. So we dont add unnecessary extra field. | ||
""" | ||
|
||
def to_schema(self): | ||
schema = super(NoRequiredObj, self).to_schema() | ||
schema.pop("required", None) | ||
return schema | ||
|
||
|
||
class NoRequiredSchemaBuilder(SchemaBuilder): | ||
EXTRA_STRATEGIES = (NoRequiredObj,) | ||
|
||
|
||
# This type is inferred from the genson lib, but there is no alias provided for it - creating it here for type safety | ||
InferredSchema = Dict[str, Union[str, Any, List, List[Dict[str, Union[Any, List]]]]] | ||
|
||
|
||
class SchemaInferrer: | ||
""" | ||
This class is used to infer a JSON schema which fits all the records passed into it | ||
throughout its lifecycle via the accumulate method. | ||
Instances of this class are stateful, meaning they build their inferred schemas | ||
from every record passed into the accumulate method. | ||
""" | ||
|
||
stream_to_builder: Dict[str, SchemaBuilder] | ||
|
||
def __init__(self): | ||
self.stream_to_builder = defaultdict(NoRequiredSchemaBuilder) | ||
|
||
def accumulate(self, record: AirbyteRecordMessage): | ||
"""Uses the input record to add to the inferred schemas maintained by this object""" | ||
self.stream_to_builder[record.stream].add_object(record.data) | ||
|
||
def get_inferred_schemas(self) -> Dict[str, InferredSchema]: | ||
""" | ||
Returns the JSON schemas for all encountered streams inferred by inspecting all records | ||
passed via the accumulate method | ||
""" | ||
schemas = {} | ||
for stream_name, builder in self.stream_to_builder.items(): | ||
schemas[stream_name] = builder.to_schema() | ||
return schemas | ||
|
||
def get_stream_schema(self, stream_name: str) -> Optional[InferredSchema]: | ||
""" | ||
Returns the inferred JSON schema for the specified stream. Might be `None` if there were no records for the given stream name. | ||
""" | ||
return self.stream_to_builder[stream_name].to_schema() if stream_name in self.stream_to_builder else None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
91 changes: 91 additions & 0 deletions
91
airbyte-cdk/python/unit_tests/utils/test_schema_inferrer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# | ||
# Copyright (c) 2022 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
from typing import List, Mapping | ||
|
||
import pytest | ||
from airbyte_cdk.models.airbyte_protocol import AirbyteRecordMessage | ||
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer | ||
|
||
NOW = 1234567 | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_records,expected_schemas", | ||
[ | ||
pytest.param( | ||
[ | ||
{"stream": "my_stream", "data": {"field_A": "abc"}}, | ||
{"stream": "my_stream", "data": {"field_A": "def"}}, | ||
], | ||
{"my_stream": {"field_A": {"type": "string"}}}, | ||
id="test_basic", | ||
), | ||
pytest.param( | ||
[ | ||
{"stream": "my_stream", "data": {"field_A": 1.0}}, | ||
{"stream": "my_stream", "data": {"field_A": "abc"}}, | ||
], | ||
{"my_stream": {"field_A": {"type": ["number", "string"]}}}, | ||
id="test_deriving_schema_refine", | ||
), | ||
pytest.param( | ||
[ | ||
{"stream": "my_stream", "data": {"obj": {"data": [1.0, 2.0, 3.0]}}}, | ||
{"stream": "my_stream", "data": {"obj": {"other_key": "xyz"}}}, | ||
], | ||
{ | ||
"my_stream": { | ||
"obj": { | ||
"type": "object", | ||
"properties": { | ||
"data": {"type": "array", "items": {"type": "number"}}, | ||
"other_key": {"type": "string"}, | ||
}, | ||
} | ||
} | ||
}, | ||
id="test_derive_schema_for_nested_structures", | ||
), | ||
], | ||
) | ||
def test_schema_derivation(input_records: List, expected_schemas: Mapping): | ||
inferrer = SchemaInferrer() | ||
for record in input_records: | ||
inferrer.accumulate(AirbyteRecordMessage(stream=record["stream"], data=record["data"], emitted_at=NOW)) | ||
|
||
for stream_name, expected_schema in expected_schemas.items(): | ||
assert inferrer.get_inferred_schemas()[stream_name] == { | ||
"$schema": "http://json-schema.org/schema#", | ||
"type": "object", | ||
"properties": expected_schema, | ||
} | ||
|
||
|
||
def test_deriving_schema_multiple_streams(): | ||
inferrer = SchemaInferrer() | ||
inferrer.accumulate(AirbyteRecordMessage(stream="my_stream", data={"field_A": 1.0}, emitted_at=NOW)) | ||
inferrer.accumulate(AirbyteRecordMessage(stream="my_stream2", data={"field_A": "abc"}, emitted_at=NOW)) | ||
inferred_schemas = inferrer.get_inferred_schemas() | ||
assert inferred_schemas["my_stream"] == { | ||
"$schema": "http://json-schema.org/schema#", | ||
"type": "object", | ||
"properties": {"field_A": {"type": "number"}}, | ||
} | ||
assert inferred_schemas["my_stream2"] == { | ||
"$schema": "http://json-schema.org/schema#", | ||
"type": "object", | ||
"properties": {"field_A": {"type": "string"}}, | ||
} | ||
|
||
|
||
def test_get_individual_schema(): | ||
inferrer = SchemaInferrer() | ||
inferrer.accumulate(AirbyteRecordMessage(stream="my_stream", data={"field_A": 1.0}, emitted_at=NOW)) | ||
assert inferrer.get_stream_schema("my_stream") == { | ||
"$schema": "http://json-schema.org/schema#", | ||
"type": "object", | ||
"properties": {"field_A": {"type": "number"}}, | ||
} | ||
assert inferrer.get_stream_schema("another_stream") is None |