From 9c02e377562ecab3d2d2f3a10b0a7d8bd07a4a3f Mon Sep 17 00:00:00 2001 From: Emmanuel Evbuomwan Date: Tue, 4 Jun 2024 17:02:00 +0200 Subject: [PATCH] refactor: use new strongly typed `Version` --- karapace/dependency.py | 4 +- karapace/errors.py | 13 +- karapace/in_memory_database.py | 26 ++-- karapace/schema_models.py | 36 ++++- karapace/schema_reader.py | 4 +- karapace/schema_references.py | 12 +- karapace/schema_registry.py | 28 ++-- karapace/schema_registry_apis.py | 16 +- karapace/schema_versioning.py | 60 -------- karapace/serialization.py | 22 +-- karapace/typing.py | 28 +++- tests/unit/test_protobuf_serialization.py | 10 +- ...ma_versioning.py => test_schema_models.py} | 141 ++++++++++-------- tests/unit/test_serialization.py | 6 +- 14 files changed, 210 insertions(+), 196 deletions(-) delete mode 100644 karapace/schema_versioning.py rename tests/unit/{test_schema_versioning.py => test_schema_models.py} (66%) diff --git a/karapace/dependency.py b/karapace/dependency.py index 52b7e965e..074263af7 100644 --- a/karapace/dependency.py +++ b/karapace/dependency.py @@ -8,7 +8,7 @@ from __future__ import annotations from karapace.schema_references import Reference -from karapace.typing import JsonData, Subject +from karapace.typing import JsonData, Subject, Version from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -26,7 +26,7 @@ def __init__( self, name: str, subject: Subject, - version: int, + version: Version, target_schema: ValidatedTypedSchema, ) -> None: self.name = name diff --git a/karapace/errors.py b/karapace/errors.py index 7853f2de1..b5c3ced38 100644 --- a/karapace/errors.py +++ b/karapace/errors.py @@ -2,7 +2,14 @@ Copyright (c) 2023 Aiven Ltd See LICENSE for details """ -from karapace.schema_references import Referents + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from karapace.schema_references import Referents + from karapace.typing import Version class VersionNotFoundException(Exception): @@ -54,10 +61,10 @@ class SubjectNotSoftDeletedException(Exception): class ReferenceExistsException(Exception): - def __init__(self, referenced_by: Referents, version: int) -> None: + def __init__(self, referenced_by: Referents, version: Version) -> None: super().__init__() - self.version = version self.referenced_by = referenced_by + self.version = version class SubjectSoftDeletedException(Exception): diff --git a/karapace/in_memory_database.py b/karapace/in_memory_database.py index 3c7facc4c..81bf474d4 100644 --- a/karapace/in_memory_database.py +++ b/karapace/in_memory_database.py @@ -7,9 +7,9 @@ from __future__ import annotations from dataclasses import dataclass, field -from karapace.schema_models import SchemaVersion, TypedSchema +from karapace.schema_models import SchemaVersion, TypedSchema, Versioner from karapace.schema_references import Reference, Referents -from karapace.typing import SchemaId, Subject +from karapace.typing import SchemaId, Subject, Version from threading import Lock, RLock from typing import Iterable, Sequence @@ -20,7 +20,7 @@ @dataclass class SubjectData: - schemas: dict[int, SchemaVersion] = field(default_factory=dict) + schemas: dict[Version, SchemaVersion] = field(default_factory=dict) compatibility: str | None = None @@ -31,7 +31,7 @@ def __init__(self) -> None: self.subjects: dict[Subject, SubjectData] = {} self.schemas: dict[SchemaId, TypedSchema] = {} self.schema_lock_thread = RLock() - self.referenced_by: dict[tuple[Subject, int], Referents] = {} + self.referenced_by: dict[tuple[Subject, Version], Referents] = {} # Content based deduplication of schemas. This is used to reduce memory # usage when the same schema is produce multiple times to the same or @@ -100,15 +100,15 @@ def _delete_subject_from_schema_id_on_subject(self, *, subject: Subject) -> None def _get_from_hash_cache(self, *, typed_schema: TypedSchema) -> TypedSchema: return self._hash_to_schema.setdefault(typed_schema.fingerprint(), typed_schema) - def get_next_version(self, *, subject: Subject) -> int: - return max(self.subjects[subject].schemas) + 1 + def get_next_version(self, *, subject: Subject) -> Version: + return Versioner.V(max(self.subjects[subject].schemas) + 1) def insert_schema_version( self, *, subject: Subject, schema_id: SchemaId, - version: int, + version: Version, deleted: bool, schema: TypedSchema, references: Sequence[Reference] | None, @@ -217,19 +217,19 @@ def find_subjects(self, *, include_deleted: bool) -> list[Subject]: subject for subject in self.subjects if self.find_subject_schemas(subject=subject, include_deleted=False) ] - def find_subject_schemas(self, *, subject: Subject, include_deleted: bool) -> dict[int, SchemaVersion]: + def find_subject_schemas(self, *, subject: Subject, include_deleted: bool) -> dict[Version, SchemaVersion]: if subject not in self.subjects: return {} if include_deleted: return self.subjects[subject].schemas with self.schema_lock_thread: return { - version_id: schema_version + Versioner.V(version_id): schema_version for version_id, schema_version in self.subjects[subject].schemas.items() if schema_version.deleted is False } - def delete_subject(self, *, subject: Subject, version: int) -> None: + def delete_subject(self, *, subject: Subject, version: Version) -> None: with self.schema_lock_thread: for schema_version in self.subjects[subject].schemas.values(): if schema_version.version <= version: @@ -241,7 +241,7 @@ def delete_subject_hard(self, *, subject: Subject) -> None: del self.subjects[subject] self._delete_subject_from_schema_id_on_subject(subject=subject) - def delete_subject_schema(self, *, subject: Subject, version: int) -> None: + def delete_subject_schema(self, *, subject: Subject, version: Version) -> None: with self.schema_lock_thread: self.subjects[subject].schemas.pop(version, None) @@ -263,7 +263,7 @@ def num_schema_versions(self) -> tuple[int, int]: soft_deleted_versions += 1 return (live_versions, soft_deleted_versions) - def insert_referenced_by(self, *, subject: Subject, version: int, schema_id: SchemaId) -> None: + def insert_referenced_by(self, *, subject: Subject, version: Version, schema_id: SchemaId) -> None: with self.schema_lock_thread: referents = self.referenced_by.get((subject, version), None) if referents: @@ -271,7 +271,7 @@ def insert_referenced_by(self, *, subject: Subject, version: int, schema_id: Sch else: self.referenced_by[(subject, version)] = Referents([schema_id]) - def get_referenced_by(self, subject: Subject, version: int) -> Referents | None: + def get_referenced_by(self, subject: Subject, version: Version) -> Referents | None: with self.schema_lock_thread: return self.referenced_by.get((subject, version), None) diff --git a/karapace/schema_models.py b/karapace/schema_models.py index 86ccccbd9..d21917025 100644 --- a/karapace/schema_models.py +++ b/karapace/schema_models.py @@ -10,7 +10,7 @@ from jsonschema import Draft7Validator from jsonschema.exceptions import SchemaError from karapace.dependency import Dependency -from karapace.errors import InvalidSchema +from karapace.errors import InvalidSchema, InvalidVersion, VersionNotFoundException from karapace.protobuf.exception import ( Error as ProtobufError, IllegalArgumentException, @@ -23,7 +23,7 @@ from karapace.protobuf.schema import ProtobufSchema from karapace.schema_references import Reference from karapace.schema_type import SchemaType -from karapace.typing import JsonObject, SchemaId, Subject +from karapace.typing import JsonObject, SchemaId, Subject, Version, VersionTag from karapace.utils import assert_never, json_decode, json_encode, JSONDecodeError from typing import Any, cast, Dict, Final, final, Mapping, Sequence @@ -383,8 +383,38 @@ def parse( @dataclass class SchemaVersion: subject: Subject - version: int + version: Version deleted: bool schema_id: SchemaId schema: TypedSchema references: Sequence[Reference] | None + + +class Versioner: + @classmethod + def V(cls, tag: VersionTag) -> Version: + cls.validate_tag(tag=tag) + return Version(version=cls.resolve_tag(tag)) + + @classmethod + def validate_tag(cls, tag: VersionTag) -> None: + try: + version = cls.resolve_tag(tag=tag) + if (version < Version.MINUS_1_VERSION_TAG) or (version == 0): + raise InvalidVersion(f"Invalid version {tag}") + except ValueError as exc: + if tag != Version.LATEST_VERSION_TAG: + raise InvalidVersion(f"Invalid version {tag}") from exc + + @staticmethod + def resolve_tag(tag: VersionTag) -> int: + return Version.MINUS_1_VERSION_TAG if tag == Version.LATEST_VERSION_TAG else int(tag) + + @staticmethod + def from_schema_versions(schema_versions: Mapping[Version, SchemaVersion], version: Version) -> Version: + max_version = max(schema_versions) + if version.is_latest: + return max_version + if version in schema_versions and version <= max_version: + return version + raise VersionNotFoundException() diff --git a/karapace/schema_reader.py b/karapace/schema_reader.py index d863fbe9e..993a83368 100644 --- a/karapace/schema_reader.py +++ b/karapace/schema_reader.py @@ -39,7 +39,7 @@ from karapace.schema_models import parse_protobuf_schema_definition, SchemaType, TypedSchema, ValidatedTypedSchema from karapace.schema_references import LatestVersionReference, Reference, reference_from_mapping, Referents from karapace.statsd import StatsClient -from karapace.typing import JsonObject, SchemaId, Subject +from karapace.typing import JsonObject, SchemaId, Subject, Version from karapace.utils import json_decode, JSONDecodeError from threading import Event, Thread from typing import Final, Mapping, Sequence @@ -602,7 +602,7 @@ def remove_referenced_by( def get_referenced_by( self, subject: Subject, - version: int, + version: Version, ) -> Referents | None: return self.database.get_referenced_by(subject, version) diff --git a/karapace/schema_references.py b/karapace/schema_references.py index 746a583ba..9973b0ccb 100644 --- a/karapace/schema_references.py +++ b/karapace/schema_references.py @@ -8,7 +8,7 @@ from __future__ import annotations from karapace.dataclasses import default_dataclass -from karapace.typing import JsonData, JsonObject, SchemaId, Subject +from karapace.typing import JsonData, JsonObject, SchemaId, Subject, Version from typing import cast, List, Mapping, NewType, TypeVar Referents = NewType("Referents", List[SchemaId]) @@ -36,7 +36,7 @@ class LatestVersionReference: name: str subject: Subject - def resolve(self, version: int) -> Reference: + def resolve(self, version: Version) -> Reference: return Reference( name=self.name, subject=self.subject, @@ -48,10 +48,10 @@ def resolve(self, version: int) -> Reference: class Reference: name: str subject: Subject - version: int + version: Version def __post_init__(self) -> None: - assert self.version != -1 + assert self.version != Version.MINUS_1_VERSION_TAG def __repr__(self) -> str: return f"{{name='{self.name}', subject='{self.subject}', version={self.version}}}" @@ -68,7 +68,7 @@ def from_dict(data: JsonObject) -> Reference: return Reference( name=str(data["name"]), subject=Subject(str(data["subject"])), - version=int(cast(int, data["version"])), + version=Version(cast(int, data["version"])), ) @@ -88,6 +88,6 @@ def reference_from_mapping( else Reference( name=name, subject=subject, - version=int(version), + version=Version(version), ) ) diff --git a/karapace/schema_registry.py b/karapace/schema_registry.py index 8ab2ce26f..2ad5f3059 100644 --- a/karapace/schema_registry.py +++ b/karapace/schema_registry.py @@ -25,10 +25,10 @@ from karapace.master_coordinator import MasterCoordinator from karapace.messaging import KarapaceProducer from karapace.offset_watcher import OffsetWatcher -from karapace.schema_models import ParsedTypedSchema, SchemaType, SchemaVersion, TypedSchema, ValidatedTypedSchema, Version +from karapace.schema_models import ParsedTypedSchema, SchemaType, SchemaVersion, TypedSchema, ValidatedTypedSchema, Versioner from karapace.schema_reader import KafkaSchemaReader from karapace.schema_references import LatestVersionReference, Reference -from karapace.typing import JsonObject, Mode, SchemaId, Subject +from karapace.typing import JsonObject, Mode, SchemaId, Subject, Version from typing import Sequence import asyncio @@ -127,7 +127,7 @@ def schemas_get(self, schema_id: SchemaId, *, fetch_max_id: bool = False) -> Typ return schema - async def subject_delete_local(self, subject: Subject, permanent: bool) -> list[int]: + async def subject_delete_local(self, subject: Subject, permanent: bool) -> list[Version]: async with self.schema_lock: schema_versions = self.subject_get(subject, include_deleted=True) @@ -173,7 +173,7 @@ async def subject_delete_local(self, subject: Subject, permanent: bool) -> list[ try: schema_versions_live = self.subject_get(subject, include_deleted=False) except SchemasNotFoundException: - latest_version_id = int(0) + latest_version_id = Versioner.V(Version.MINUS_1_VERSION_TAG) version_list = [] else: version_list = list(schema_versions_live) @@ -187,7 +187,7 @@ async def subject_delete_local(self, subject: Subject, permanent: bool) -> list[ return version_list - async def subject_version_delete_local(self, subject: Subject, version: Version, permanent: bool) -> int: + async def subject_version_delete_local(self, subject: Subject, version: Version, permanent: bool) -> Version: async with self.schema_lock: schema_versions = self.subject_get(subject, include_deleted=True) if not permanent and version.is_latest: @@ -196,7 +196,7 @@ async def subject_version_delete_local(self, subject: Subject, version: Version, for version_id, schema_version in schema_versions.items() if schema_version.deleted is False } - resolved_version = version.resolve_from_schema_versions(schema_versions=schema_versions) + resolved_version = Versioner.from_schema_versions(schema_versions=schema_versions, version=version) schema_version = schema_versions.get(resolved_version, None) if not schema_version: @@ -224,7 +224,7 @@ async def subject_version_delete_local(self, subject: Subject, version: Version, self.schema_reader.remove_referenced_by(schema_version.schema_id, schema_version.references) return resolved_version - def subject_get(self, subject: Subject, include_deleted: bool = False) -> dict[int, SchemaVersion]: + def subject_get(self, subject: Subject, include_deleted: bool = False) -> dict[Version, SchemaVersion]: subject_found = self.database.find_subject(subject=subject) if not subject_found: raise SubjectNotFoundException() @@ -238,7 +238,7 @@ def subject_version_get(self, subject: Subject, version: Version, *, include_del schema_versions = self.subject_get(subject, include_deleted=include_deleted) if not schema_versions: raise SubjectNotFoundException() - resolved_version = version.resolve_from_schema_versions(schema_versions=schema_versions) + resolved_version = Versioner.from_schema_versions(schema_versions=schema_versions, version=version) schema_data: SchemaVersion | None = schema_versions.get(resolved_version, None) if not schema_data: @@ -269,7 +269,7 @@ async def subject_version_referencedby_get( schema_versions = self.subject_get(subject, include_deleted=include_deleted) if not schema_versions: raise SubjectNotFoundException() - resolved_version = version.resolve_from_schema_versions(schema_versions=schema_versions) + resolved_version = Versioner.from_schema_versions(schema_versions=schema_versions, version=version) schema_data: SchemaVersion | None = schema_versions.get(resolved_version, None) if not schema_data: raise VersionNotFoundException() @@ -311,7 +311,7 @@ async def write_new_schema_local( all_schema_versions = self.database.find_subject_schemas(subject=subject, include_deleted=True) if not all_schema_versions: - version = int(1) + version = Version(1) schema_id = self.database.get_schema_id(new_schema) LOG.debug( "Registering new subject: %r, id: %r with version: %r with schema %r, schema_id: %r", @@ -400,8 +400,8 @@ async def write_new_schema_local( def get_subject_versions_for_schema( self, schema_id: SchemaId, *, include_deleted: bool = False - ) -> list[dict[str, Subject | int]]: - subject_versions: list[dict[str, Subject | int]] = [] + ) -> list[dict[str, Subject | Version]]: + subject_versions: list[dict[str, Subject | Version]] = [] schema_versions = self.database.find_schema_versions_by_schema_id( schema_id=schema_id, include_deleted=include_deleted, @@ -423,7 +423,7 @@ def send_schema_message( subject: Subject, schema: TypedSchema | None, schema_id: int, - version: int, + version: Version, deleted: bool, references: Sequence[Reference] | None, ) -> None: @@ -459,7 +459,7 @@ def resolve_references( ) -> tuple[Sequence[Reference], dict[str, Dependency]] | tuple[None, None]: return self.schema_reader.resolve_references(references) if references else (None, None) - def send_delete_subject_message(self, subject: Subject, version: int) -> None: + def send_delete_subject_message(self, subject: Subject, version: Version) -> None: key = {"subject": subject, "magic": 0, "keytype": "DELETE_SUBJECT"} value = {"subject": subject, "version": version} self.producer.send_message(key=key, value=value) diff --git a/karapace/schema_registry_apis.py b/karapace/schema_registry_apis.py index 4d8fd1f7b..99d942bdd 100644 --- a/karapace/schema_registry_apis.py +++ b/karapace/schema_registry_apis.py @@ -31,10 +31,10 @@ from karapace.karapace import KarapaceBase from karapace.protobuf.exception import ProtobufUnresolvedDependencyException from karapace.rapu import HTTPRequest, JSON_CONTENT_TYPE, SERVER_NAME -from karapace.schema_models import ParsedTypedSchema, SchemaType, SchemaVersion, TypedSchema, ValidatedTypedSchema, Version +from karapace.schema_models import ParsedTypedSchema, SchemaType, SchemaVersion, TypedSchema, ValidatedTypedSchema, Versioner from karapace.schema_references import LatestVersionReference, Reference, reference_from_mapping from karapace.schema_registry import KarapaceSchemaRegistry -from karapace.typing import JsonData, JsonObject, SchemaId, Subject +from karapace.typing import JsonData, JsonObject, SchemaId, Subject, Version from karapace.utils import JSONDecodeError from typing import Any @@ -333,7 +333,7 @@ async def close(self) -> None: if self._auth is not None: stack.push_async_callback(self._auth.close) - def _subject_get(self, subject: str, content_type: str, include_deleted: bool = False) -> dict[int, SchemaVersion]: + def _subject_get(self, subject: str, content_type: str, include_deleted: bool = False) -> dict[Version, SchemaVersion]: try: schema_versions = self.schema_registry.subject_get(subject, include_deleted) except SubjectNotFoundException: @@ -381,7 +381,7 @@ async def compatibility_check( schema_type = self._validate_schema_type(content_type=content_type, data=body) references = self._validate_references(content_type, schema_type, body) try: - version = Version(version) + version = Versioner.V(version) references, new_schema_dependencies = self.schema_registry.resolve_references(references) new_schema = ValidatedTypedSchema.parse( schema_type=schema_type, @@ -784,7 +784,7 @@ async def subject_version_get( deleted = request.query.get("deleted", "false").lower() == "true" try: - version = Version(version) + version = Versioner.V(version) subject_data = self.schema_registry.subject_version_get(subject, version, include_deleted=deleted) if "compatibility" in subject_data: del subject_data["compatibility"] @@ -819,7 +819,7 @@ async def subject_version_delete( are_we_master, master_url = await self.schema_registry.get_master() if are_we_master: try: - version = Version(version) + version = Versioner.V(version) resolved_version = await self.schema_registry.subject_version_delete_local(subject, version, permanent) self.r(str(resolved_version), content_type, status=HTTPStatus.OK) except (SubjectNotFoundException, SchemasNotFoundException): @@ -890,7 +890,7 @@ async def subject_version_schema_get( self._check_authorization(user, Operation.Read, f"Subject:{subject}") try: - version = Version(version) + version = Versioner.V(version) subject_data = self.schema_registry.subject_version_get(subject, version) self.r(subject_data["schema"], content_type) except InvalidVersion: @@ -918,7 +918,7 @@ async def subject_version_referencedby_get(self, content_type, *, subject, versi self._check_authorization(user, Operation.Read, f"Subject:{subject}") try: - version = Version(version) + version = Versioner.V(version) referenced_by = await self.schema_registry.subject_version_referencedby_get(subject, version) except (SubjectNotFoundException, SchemasNotFoundException): self.r( diff --git a/karapace/schema_versioning.py b/karapace/schema_versioning.py deleted file mode 100644 index 265fb25aa..000000000 --- a/karapace/schema_versioning.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Copyright (c) 2023 Aiven Ltd -See LICENSE for details -""" -from __future__ import annotations - -from karapace.errors import InvalidVersion, VersionNotFoundException -from karapace.schema_models import SchemaVersion -from typing import ClassVar, Mapping, Union - -VersionTag = Union[str, int] - - -class Version(int): - LATEST_VERSION_TAG: ClassVar[str] = "latest" - MINUS_1_VERSION_TAG: ClassVar[int] = -1 - - @property - def is_latest(self) -> bool: - return self == self.MINUS_1_VERSION_TAG - - def from_schema_versions(self, schema_versions: Mapping[Version, SchemaVersion]) -> Version: - max_version = max(schema_versions) - if self.is_latest: - return max_version - if self <= max_version and self in schema_versions: - return self - raise VersionNotFoundException() - - @classmethod - def resolve_tag(cls, tag: VersionTag) -> int: - return cls.MINUS_1_VERSION_TAG if tag == cls.LATEST_VERSION_TAG else int(tag) - - @classmethod - def V(cls, tag: VersionTag) -> Version: - cls.validate_tag(tag=tag) - return Version(version=Version.resolve_tag(tag)) - - @classmethod - def validate_tag(cls, tag: VersionTag) -> None: - try: - version = cls.resolve_tag(tag=tag) - if (version < cls.MINUS_1_VERSION_TAG) or (version == 0): - raise InvalidVersion(f"Invalid version {tag}") - except ValueError as exc: - if tag != cls.LATEST_VERSION_TAG: - raise InvalidVersion(f"Invalid version {tag}") from exc - - def __new__(cls, version: int) -> Version: - if not isinstance(version, int): - raise InvalidVersion(f"Invalid version {version}") - if (version < cls.MINUS_1_VERSION_TAG) or (version == 0): - raise InvalidVersion(f"Invalid version {version}") - return super().__new__(cls, version) - - def __str__(self) -> str: - return f"{int(self)}" - - def __repr__(self) -> str: - return f"Version={int(self)}" diff --git a/karapace/serialization.py b/karapace/serialization.py index 8797dd780..81c51cabc 100644 --- a/karapace/serialization.py +++ b/karapace/serialization.py @@ -16,9 +16,9 @@ from karapace.protobuf.exception import ProtobufTypeException from karapace.protobuf.io import ProtobufDatumReader, ProtobufDatumWriter from karapace.protobuf.schema import ProtobufSchema -from karapace.schema_models import InvalidSchema, ParsedTypedSchema, SchemaType, TypedSchema, ValidatedTypedSchema +from karapace.schema_models import InvalidSchema, ParsedTypedSchema, SchemaType, TypedSchema, ValidatedTypedSchema, Versioner from karapace.schema_references import LatestVersionReference, Reference, reference_from_mapping -from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType +from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType, Version from karapace.utils import json_decode, json_encode from typing import Any, Callable, MutableMapping from urllib.parse import quote @@ -131,9 +131,9 @@ async def post_new_schema( async def _get_schema_recursive( self, subject: Subject, - explored_schemas: set[tuple[Subject, int | None]], - version: int | None = None, - ) -> tuple[SchemaId, ValidatedTypedSchema, int]: + explored_schemas: set[tuple[Subject, Version | None]], + version: Version | None = None, + ) -> tuple[SchemaId, ValidatedTypedSchema, Version]: if (subject, version) in explored_schemas: raise InvalidSchema( f"The schema has at least a cycle in dependencies, " @@ -174,7 +174,7 @@ async def _get_schema_recursive( references=references, dependencies=dependencies, ), - int(json_result["version"]), + Versioner.V(json_result["version"]), ) except InvalidSchema as e: raise SchemaRetrievalError(f"Failed to parse schema string from response: {json_result}") from e @@ -183,21 +183,21 @@ async def _get_schema_recursive( async def get_schema( self, subject: Subject, - version: int | None = None, - ) -> tuple[SchemaId, ValidatedTypedSchema, int]: + version: Version | None = None, + ) -> tuple[SchemaId, ValidatedTypedSchema, Version]: """ Retrieves the schema and its dependencies for the specified subject. Args: subject (Subject): The subject for which to retrieve the schema. - version (Optional[int]): The specific version of the schema to retrieve. + version (Optional[Version]): The specific version of the schema to retrieve. If None, the latest available schema will be returned. Returns: - Tuple[SchemaId, ValidatedTypedSchema, int]: A tuple containing: + Tuple[SchemaId, ValidatedTypedSchema, Version]: A tuple containing: - SchemaId: The ID of the retrieved schema. - ValidatedTypedSchema: The retrieved schema, validated and typed. - - int: The version of the schema that was retrieved. + - Version: The version of the schema that was retrieved. """ return await self._get_schema_recursive(subject, set(), version) diff --git a/karapace/typing.py b/karapace/typing.py index 922601056..40b29fa2d 100644 --- a/karapace/typing.py +++ b/karapace/typing.py @@ -2,8 +2,11 @@ Copyright (c) 2023 Aiven Ltd See LICENSE for details """ +from __future__ import annotations + from enum import Enum, unique -from typing import Dict, List, Mapping, NewType, Sequence, Union +from karapace.errors import InvalidVersion +from typing import ClassVar, Dict, List, Mapping, NewType, Sequence, Union from typing_extensions import TypeAlias JsonArray: TypeAlias = List["JsonData"] @@ -17,6 +20,7 @@ ArgJsonData: TypeAlias = Union[JsonScalar, ArgJsonObject, ArgJsonArray] Subject = NewType("Subject", str) +VersionTag = Union[str, int] # note: the SchemaID is a unique id among all the schemas (and each version should be assigned to a different id) # basically the same SchemaID refer always to the same TypedSchema. @@ -53,3 +57,25 @@ class SubjectType(StrEnum, Enum): @unique class Mode(StrEnum): readwrite = "READWRITE" + + +class Version(int): + LATEST_VERSION_TAG: ClassVar[str] = "latest" + MINUS_1_VERSION_TAG: ClassVar[int] = -1 + + def __new__(cls, version: int) -> Version: + if not isinstance(version, int): + raise InvalidVersion(f"Invalid version {version}") + if (version < cls.MINUS_1_VERSION_TAG) or (version == 0): + raise InvalidVersion(f"Invalid version {version}") + return super().__new__(cls, version) + + def __str__(self) -> str: + return f"{int(self)}" + + def __repr__(self) -> str: + return f"Version={int(self)}" + + @property + def is_latest(self) -> bool: + return self == self.MINUS_1_VERSION_TAG diff --git a/tests/unit/test_protobuf_serialization.py b/tests/unit/test_protobuf_serialization.py index b7b84839e..ee2586d63 100644 --- a/tests/unit/test_protobuf_serialization.py +++ b/tests/unit/test_protobuf_serialization.py @@ -5,7 +5,7 @@ from karapace.config import read_config from karapace.dependency import Dependency from karapace.protobuf.kotlin_wrapper import trim_margin -from karapace.schema_models import ParsedTypedSchema, SchemaType +from karapace.schema_models import ParsedTypedSchema, SchemaType, Versioner from karapace.schema_references import Reference from karapace.serialization import ( InvalidMessageHeader, @@ -45,7 +45,7 @@ async def test_happy_flow(default_config_path: Path): mock_protobuf_registry_client.get_schema_for_id.return_value = schema_for_id_one_future get_latest_schema_future = asyncio.Future() get_latest_schema_future.set_result( - (1, ParsedTypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf)), int(1)) + (1, ParsedTypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf)), Versioner.V(1)) ) mock_protobuf_registry_client.get_schema.return_value = get_latest_schema_future @@ -114,7 +114,7 @@ async def test_happy_flow_references(default_config_path: Path): schema_for_id_one_future.set_result((ref_schema, [Subject("stub")])) mock_protobuf_registry_client.get_schema_for_id.return_value = schema_for_id_one_future get_latest_schema_future = asyncio.Future() - get_latest_schema_future.set_result((1, ref_schema, int(1))) + get_latest_schema_future.set_result((1, ref_schema, Versioner.V(1))) mock_protobuf_registry_client.get_schema.return_value = get_latest_schema_future serializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client) @@ -201,7 +201,7 @@ async def test_happy_flow_references_two(default_config_path: Path): schema_for_id_one_future.set_result((ref_schema_two, [Subject("mock")])) mock_protobuf_registry_client.get_schema_for_id.return_value = schema_for_id_one_future get_latest_schema_future = asyncio.Future() - get_latest_schema_future.set_result((1, ref_schema_two, int(1))) + get_latest_schema_future.set_result((1, ref_schema_two, Versioner.V(1))) mock_protobuf_registry_client.get_schema.return_value = get_latest_schema_future serializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client) @@ -221,7 +221,7 @@ async def test_serialization_fails(default_config_path: Path): mock_protobuf_registry_client = Mock() get_latest_schema_future = asyncio.Future() get_latest_schema_future.set_result( - (1, ParsedTypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf)), int(1)) + (1, ParsedTypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf)), Versioner.V(1)) ) mock_protobuf_registry_client.get_schema.return_value = get_latest_schema_future diff --git a/tests/unit/test_schema_versioning.py b/tests/unit/test_schema_models.py similarity index 66% rename from tests/unit/test_schema_versioning.py rename to tests/unit/test_schema_models.py index a59c6a3df..fc1590da9 100644 --- a/tests/unit/test_schema_versioning.py +++ b/tests/unit/test_schema_models.py @@ -7,9 +7,9 @@ from avro.schema import Schema as AvroSchema from karapace.errors import InvalidVersion, VersionNotFoundException -from karapace.schema_models import parse_avro_schema_definition, SchemaVersion, TypedSchema +from karapace.schema_models import parse_avro_schema_definition, SchemaVersion, TypedSchema, Versioner from karapace.schema_type import SchemaType -from karapace.schema_versioning import Version, VersionTag +from karapace.typing import Version, VersionTag from typing import Any, Callable, Dict, Optional import operator @@ -22,39 +22,7 @@ class TestVersion: @pytest.fixture def version(self): - return Version(1) - - @pytest.fixture - def avro_schema(self) -> str: - return '{"type":"record","name":"testRecord","fields":[{"type":"string","name":"test"}]}' - - @pytest.fixture - def avro_schema_parsed(self, avro_schema: str) -> AvroSchema: - return parse_avro_schema_definition(avro_schema) - - @pytest.fixture - def schema_versions_factory( - self, - avro_schema: str, - avro_schema_parsed: AvroSchema, - ) -> Callable[[Version, Dict[str, Any]], Dict[Version, SchemaVersion]]: - def schema_versions(version: Version, schema_version_data: Optional[Dict[str, Any]] = None): - schema_version_data = schema_version_data or dict() - base_schema_version_data = dict( - subject="test-topic", - version=version, - deleted=False, - schema_id=1, - schema=TypedSchema( - schema_type=SchemaType.AVRO, - schema_str=avro_schema, - schema=avro_schema_parsed, - ), - references=None, - ) - return {version: SchemaVersion(**{**base_schema_version_data, **schema_version_data})} - - return schema_versions + return Versioner.V(1) def test_version(self, version: Version): assert version == 1 @@ -68,19 +36,15 @@ def test_tags(self, version: Version): @pytest.mark.parametrize("invalid_version", ["string", -10, 0]) def test_invalid_version(self, invalid_version: VersionTag): with pytest.raises(InvalidVersion): - Version(invalid_version) + Versioner.V(invalid_version) @pytest.mark.parametrize( "version, is_latest", - [(Version(-1), True), (Version(1), False)], + [(Versioner.V(-1), True), (Versioner.V(1), False)], ) def test_is_latest(self, version: Version, is_latest: bool): assert version.is_latest is is_latest - @pytest.mark.parametrize("tag, resolved", [("latest", -1), (10, 10), ("20", 20)]) - def test_resolve_tag(self, tag: VersionTag, resolved: int): - assert Version.resolve_tag(tag=tag) == resolved - def test_text_formating(self, version: Version): assert f"{version}" == "1" assert f"{version!r}" == "Version=1" @@ -88,14 +52,14 @@ def test_text_formating(self, version: Version): @pytest.mark.parametrize( "version, to_compare, comparer, valid", [ - (Version(1), Version(1), operator.eq, True), - (Version(1), Version(2), operator.eq, False), - (Version(2), Version(1), operator.gt, True), - (Version(2), Version(1), operator.lt, False), - (Version(2), Version(2), operator.ge, True), - (Version(2), Version(1), operator.ge, True), - (Version(1), Version(1), operator.le, True), - (Version(1), Version(2), operator.le, True), + (Versioner.V(1), Versioner.V(1), operator.eq, True), + (Versioner.V(1), Versioner.V(2), operator.eq, False), + (Versioner.V(2), Versioner.V(1), operator.gt, True), + (Versioner.V(2), Versioner.V(1), operator.lt, False), + (Versioner.V(2), Versioner.V(2), operator.ge, True), + (Versioner.V(2), Versioner.V(1), operator.ge, True), + (Versioner.V(1), Versioner.V(1), operator.le, True), + (Versioner.V(1), Versioner.V(2), operator.le, True), ], ) def test_comparisons( @@ -107,12 +71,50 @@ def test_comparisons( ): assert comparer(version, to_compare) is valid + +class TestVersioner: + @pytest.fixture + def avro_schema(self) -> str: + return '{"type":"record","name":"testRecord","fields":[{"type":"string","name":"test"}]}' + + @pytest.fixture + def avro_schema_parsed(self, avro_schema: str) -> AvroSchema: + return parse_avro_schema_definition(avro_schema) + + @pytest.fixture + def schema_versions_factory( + self, + avro_schema: str, + avro_schema_parsed: AvroSchema, + ) -> Callable[[Version, Dict[str, Any]], Dict[Version, SchemaVersion]]: + def schema_versions(version: Version, schema_version_data: Optional[Dict[str, Any]] = None): + schema_version_data = schema_version_data or dict() + base_schema_version_data = dict( + subject="test-topic", + version=version, + deleted=False, + schema_id=1, + schema=TypedSchema( + schema_type=SchemaType.AVRO, + schema_str=avro_schema, + schema=avro_schema_parsed, + ), + references=None, + ) + return {version: SchemaVersion(**{**base_schema_version_data, **schema_version_data})} + + return schema_versions + + @pytest.mark.parametrize("tag, resolved", [("latest", -1), (10, 10), ("20", 20)]) + def test_resolve_tag(self, tag: VersionTag, resolved: int): + assert Versioner.resolve_tag(tag=tag) == resolved + @pytest.mark.parametrize( "version, resolved_version", [ - (Version(-1), Version(10)), - (Version(1), Version(1)), - (Version(10), Version(10)), + (Versioner.V(-1), Versioner.V(10)), + (Versioner.V(1), Versioner.V(1)), + (Versioner.V(10), Versioner.V(10)), ], ) def test_from_schema_versions( @@ -122,33 +124,42 @@ def test_from_schema_versions( schema_versions_factory: SVFCallable, ): schema_versions = dict() - schema_versions.update(schema_versions_factory(Version(1))) - schema_versions.update(schema_versions_factory(Version(2))) - schema_versions.update(schema_versions_factory(Version(10))) - assert version.from_schema_versions(schema_versions) == resolved_version + schema_versions.update(schema_versions_factory(Versioner.V(1))) + schema_versions.update(schema_versions_factory(Versioner.V(2))) + schema_versions.update(schema_versions_factory(Versioner.V(10))) + assert Versioner.from_schema_versions(schema_versions, version) == resolved_version - @pytest.mark.parametrize("nonexisting_version", [Version(100), Version(2000)]) + @pytest.mark.parametrize("nonexisting_version", [Versioner.V(100), Versioner.V(2000)]) def test_from_schema_versions_nonexisting( self, nonexisting_version: Version, schema_versions_factory: SVFCallable, ): schema_versions = dict() - schema_versions.update(schema_versions_factory(Version(1))) + schema_versions.update(schema_versions_factory(Versioner.V(1))) with pytest.raises(VersionNotFoundException): - nonexisting_version.from_schema_versions(schema_versions) + Versioner.from_schema_versions(schema_versions, nonexisting_version) - @pytest.mark.parametrize("tag, resolved", [("latest", -1), (10, 10), ("20", 20), (-1, -1), ("-1", -1)]) + @pytest.mark.parametrize( + "tag, resolved", + [ + ("latest", Versioner.V(-1)), + (10, Versioner.V(10)), + ("20", Versioner.V(20)), + (-1, Versioner.V(-1)), + ("-1", Versioner.V(-1)), + ], + ) def test_factory_V(self, tag: VersionTag, resolved: int): - version = Version.V(tag=tag) + version = Versioner.V(tag=tag) assert version == resolved assert isinstance(version, Version) @pytest.mark.parametrize("tag", ["latest", 10, -1, "-1"]) - def test_validate(self, tag: VersionTag, version: Version): - version.validate_tag(tag=tag) + def test_validate(self, tag: VersionTag): + Versioner.validate_tag(tag=tag) @pytest.mark.parametrize("tag", ["invalid_version", 0, -20, "0"]) - def test_validate_invalid(self, tag: VersionTag, version: Version): + def test_validate_invalid(self, tag: VersionTag): with pytest.raises(InvalidVersion): - version.validate_tag(tag=tag) + Versioner.validate_tag(tag=tag) diff --git a/tests/unit/test_serialization.py b/tests/unit/test_serialization.py index d8c6a698f..a21d3bc00 100644 --- a/tests/unit/test_serialization.py +++ b/tests/unit/test_serialization.py @@ -4,7 +4,7 @@ """ from karapace.client import Path from karapace.config import DEFAULTS, read_config -from karapace.schema_models import SchemaType, ValidatedTypedSchema +from karapace.schema_models import SchemaType, ValidatedTypedSchema, Versioner from karapace.serialization import ( flatten_unions, get_subject_name, @@ -121,7 +121,7 @@ async def make_ser_deser(config_path: str, mock_client) -> SchemaRegistrySeriali async def test_happy_flow(default_config_path: Path): mock_registry_client = Mock() get_latest_schema_future = asyncio.Future() - get_latest_schema_future.set_result((1, ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json), int(1))) + get_latest_schema_future.set_result((1, ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json), Versioner.V(1))) mock_registry_client.get_schema.return_value = get_latest_schema_future schema_for_id_one_future = asyncio.Future() schema_for_id_one_future.set_result((ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json), [Subject("stub")])) @@ -313,7 +313,7 @@ def test_avro_json_write_accepts_json_encoded_data_without_tagged_unions() -> No async def test_serialization_fails(default_config_path: Path): mock_registry_client = Mock() get_latest_schema_future = asyncio.Future() - get_latest_schema_future.set_result((1, ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json), int(1))) + get_latest_schema_future.set_result((1, ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json), Versioner.V(1))) mock_registry_client.get_schema.return_value = get_latest_schema_future serializer = await make_ser_deser(default_config_path, mock_registry_client)