From 9b68d3733499248959b049f815a17ab9af728124 Mon Sep 17 00:00:00 2001
From: Richard Fussenegger
Date: Wed, 15 Feb 2023 15:27:32 +0100
Subject: [PATCH] Fix Create Backup Consumption Logic

The consumption logic currently counts how many records it received in
a single batch returned from poll, and concludes that the backup
finished successfully as soon as a batch comes back empty. However,
there are many reasons why a batch returned by poll may be empty,
especially with timeouts applied to it. A consequence of this is that
a backup created at $t_1$ may contain more records than a backup
created at $t_2$ (without any external changes to the topic content,
e.g. compaction).

To fix this we have to use offset watermarks, which let us reliably
determine whether we have reached the end of the topic.

The patch also exposes the poll timeout so that users can increase it
in case they encounter issues, and raises the default poll timeout
from 1 second to 1 minute so that users are not going to see errors
right away.
---
 karapace/backup/__init__.py             |   0
 karapace/backup/consumer.py             |  47 ++++++
 karapace/backup/errors.py               |  78 ++++++++++
 karapace/schema_backup.py               | 141 ++++++++++++------
 requirements.txt                        |   1 +
 tests/integration/conftest.py           |   2 +-
 tests/integration/test_schema_backup.py |  34 ++++-
 .../test_schema_backup_avro_export.py   |   2 +-
 tests/unit/backup/test_consumer.py      |  30 ++++
 9 files changed, 286 insertions(+), 49 deletions(-)
 create mode 100644 karapace/backup/__init__.py
 create mode 100644 karapace/backup/consumer.py
 create mode 100644 karapace/backup/errors.py
 create mode 100644 tests/unit/backup/test_consumer.py

diff --git a/karapace/backup/__init__.py b/karapace/backup/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/karapace/backup/consumer.py b/karapace/backup/consumer.py
new file mode 100644
index 000000000..055d6cf78
--- /dev/null
+++ b/karapace/backup/consumer.py
@@ -0,0 +1,47 @@
+"""
+Copyright (c) 2023 Aiven Ltd
+See LICENSE for details
+"""
+from __future__ import annotations
+
+from datetime import timedelta
+from isodate import duration_isoformat, parse_duration
+from typing import Union
+
+__all__ = ["PollTimeout"]
+
+
+# TODO @final (Python 3.8+)
+class PollTimeout:
+    """Specifies how long a single poll attempt may take while consuming the topic.
+
+    It may be necessary to adjust this value in case the cluster is slow. The value must be given in ISO8601 duration
+    format (e.g. `PT1.5S` for 1,500 milliseconds) and must be at least one second. Defaults to one minute.
+    """
+
+    __slots__ = ("__value",)
+
+    def __init__(self, value: Union[str, timedelta]) -> None:
+        self.__value = value if isinstance(value, timedelta) else parse_duration(value)
+        if self.__value // timedelta(seconds=1) < 1:
+            raise ValueError(f"Poll timeout MUST be at least one second, got: {self}")
+
+    @classmethod
+    def default(cls) -> PollTimeout:
+        return cls(timedelta(minutes=1))
+
+    @classmethod
+    def of(cls, minutes: int = 0, seconds: int = 0, milliseconds: int = 0) -> PollTimeout:
+        """Convenience function to avoid importing ``timedelta``."""
+        return PollTimeout(timedelta(minutes=minutes, seconds=seconds, milliseconds=milliseconds))
+
+    def __str__(self) -> str:
+        """Returns the ISO8601 formatted value of this poll timeout."""
+        return duration_isoformat(self.__value)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(value='{self}')"
+
+    def to_milliseconds(self) -> int:
+        """Returns this poll timeout in milliseconds; anything smaller than a millisecond is ignored (no rounding)."""
+        return self.__value // timedelta(milliseconds=1)
diff --git a/karapace/backup/errors.py b/karapace/backup/errors.py
new file mode 100644
index 000000000..a3f7867b3
--- /dev/null
+++ b/karapace/backup/errors.py
@@ -0,0 +1,78 @@
+"""
+Copyright (c) 2023 Aiven Ltd
+See LICENSE for details
+"""
+from kafka.structs import TopicPartition
+from karapace.backup.consumer import PollTimeout
+
+__all__ = ["BackupError", "PartitionCountError", "StaleConsumerError"]
+
+
+class BackupError(Exception):
+    """Base class for all backup errors."""
+
+
+class PartitionCountError(BackupError):
+    pass
+
+
+class StaleConsumerError(BackupError, RuntimeError):
+    """Raised when the backup consumer does not make any progress and has not reached the last record in the topic."""
+
+    __slots__ = ("__topic_partition", "__start_offset", "__end_offset", "__last_offset", "__poll_timeout")
+
+    def __init__(
+        self,
+        topic_partition: TopicPartition,
+        start_offset: int,
+        end_offset: int,
+        current_offset: int,
+        poll_timeout: PollTimeout,
+    ) -> None:
+        super().__init__(
+            f"{topic_partition.topic}:{topic_partition.partition}#{current_offset:,} ({start_offset:,},{end_offset:,})"
+            f" after {poll_timeout}"
+        )
+        self.__topic_partition = topic_partition
+        self.__start_offset = start_offset
+        self.__end_offset = end_offset
+        self.__last_offset = current_offset
+        self.__poll_timeout = poll_timeout
+
+    @property
+    def topic_partition(self) -> TopicPartition:
+        """Gets the topic and partition that went stale during consumption."""
+        return self.__topic_partition
+
+    @property
+    def topic(self) -> str:
+        """Gets the topic that went stale during consumption."""
+        return self.__topic_partition.topic
+
+    @property
+    def partition(self) -> int:
+        """Gets the partition that went stale during consumption."""
+        return self.__topic_partition.partition
+
+    @property
+    def start_offset(self) -> int:
+        """Gets the start offset of the topic and partition as determined at the start of the backup creation."""
+        return self.__start_offset
+
+    @property
+    def end_offset(self) -> int:
+        """Gets the end offset of the topic and partition as determined at the start of the backup creation.
+
+        This is the offset of the last written record in the topic and partition, not the high watermark.
+        """
+        return self.__end_offset
+
+    @property
+    def last_offset(self) -> int:
+        """Gets the last offset of the topic and partition that was successfully consumed."""
+        return self.__last_offset
+
+    @property
+    def poll_timeout(self) -> PollTimeout:
+        """Gets the poll timeout with which the consumer went stale while waiting for more records."""
+        return self.__poll_timeout
diff --git a/karapace/schema_backup.py b/karapace/schema_backup.py
index ca1291983..bdfd72ecc 100644
--- a/karapace/schema_backup.py
+++ b/karapace/schema_backup.py
@@ -7,10 +7,13 @@
 from enum import Enum
 from kafka import KafkaConsumer, KafkaProducer
 from kafka.admin import KafkaAdminClient
+from kafka.consumer.fetcher import ConsumerRecord
 from kafka.errors import TopicAlreadyExistsError
-from kafka.structs import PartitionMetadata
+from kafka.structs import PartitionMetadata, TopicPartition
 from karapace import constants
 from karapace.anonymize_schemas import anonymize_avro
+from karapace.backup.consumer import PollTimeout
+from karapace.backup.errors import BackupError, PartitionCountError, StaleConsumerError
 from karapace.config import Config, read_config
 from karapace.key_format import KeyFormatter
 from karapace.schema_reader import new_schema_topic_from_config
@@ -19,7 +22,7 @@
 from pathlib import Path
 from tempfile import mkstemp
 from tenacity import retry, RetryCallState, stop_after_delay, wait_fixed
-from typing import AbstractSet, Callable, IO, Optional, TextIO, Tuple, Union
+from typing import AbstractSet, Callable, Collection, IO, Optional, TextIO, Tuple, Union
 
 import argparse
 import base64
@@ -41,14 +44,6 @@ class BackupVersion(Enum):
     V2 = 2
 
 
-class BackupError(Exception):
-    """Backup Error"""
-
-
-class PartitionCountError(BackupError):
-    pass
-
-
 def __before_sleep(description: str) -> Callable[[RetryCallState], None]:
     """Returns a function to print a user-friendly message before going to sleep in retries.
 
@@ -349,25 +344,72 @@ def _restore_backup_version_2(self, producer: KafkaProducer, fp: IO) -> None:
                 value = base64.b16decode(hex_value.strip()).decode("utf8") if hex_value != "null" else hex_value
             self._handle_restore_message(producer, (key, value))
 
-    def export(self, export_func, *, overwrite: Optional[bool] = None) -> None:
-        with _writer(self.backup_location, overwrite=overwrite) as fp:
-            with _consumer(self.config, self.topic_name) as consumer:
-                LOG.info("Starting schema backup read for topic: %r", self.topic_name)
-
-                topic_fully_consumed = False
-
-                fp.write(BACKUP_VERSION_2_MARKER)
-                while not topic_fully_consumed:
-                    raw_msg = consumer.poll(timeout_ms=self.timeout_ms, max_records=1000)
-                    topic_fully_consumed = len(raw_msg) == 0
-
-                    for _, messages in raw_msg.items():
-                        for message in messages:
-                            ser = export_func(key_bytes=message.key, value_bytes=message.value)
-                            if ser:
-                                fp.write(ser)
-
-                LOG.info("Schema export written to %r", "stdout" if fp is sys.stdout else self.backup_location)
+    def create(
+        self,
+        serialize: Callable[[Optional[bytes], Optional[bytes]], str],
+        *,
+        poll_timeout: Optional[PollTimeout] = None,
+        overwrite: Optional[bool] = None,
+    ) -> None:
+        """Creates a backup of the configured topic.
+
+        FIXME the serialize callback is obviously dangerous as part of the public API, since it cannot be guaranteed
+        that it produces a string that is actually version 2 compatible. We have to introduce a version 3 anyway,
+        and this public API can be fixed along with its introduction.
+
+        :param serialize: callback that encodes the consumer record into the target backup format.
+        :param poll_timeout: specifies the maximum time to wait for receiving records; if no records are received
+            within that time and the target offset has not been reached, an exception is raised. Defaults to one minute.
+        :param overwrite: whether to overwrite the output file if it already exists.
+        :raises Exception: if consumption fails, concrete exception types are unknown, see Kafka implementation.
+        :raises FileExistsError: if ``overwrite`` is not ``True`` and the file already exists, or if the parent
+            directory of the file is not a directory.
+        :raises OSError: if writing fails or if the file already exists and is not actually a file.
+        :raises StaleConsumerError: if no records are received within the given ``poll_timeout`` and the target offset
+            has not been reached yet.
+        """
+        if poll_timeout is None:
+            poll_timeout = PollTimeout.default()
+        poll_timeout_ms = poll_timeout.to_milliseconds()
+        topic = self.topic_name
+        with _writer(self.backup_location, overwrite=overwrite) as fp, _consumer(self.config, topic) as consumer:
+            (partition,) = consumer.partitions_for_topic(self.topic_name)
+            topic_partition = TopicPartition(self.topic_name, partition)
+            start_offset: int = consumer.beginning_offsets([topic_partition])[topic_partition]
+            end_offset: int = consumer.end_offsets([topic_partition])[topic_partition]
+            last_offset = start_offset
+            record_count = 0
+
+            fp.write(BACKUP_VERSION_2_MARKER)
+            if start_offset < end_offset:  # non-empty topic
+                end_offset -= 1  # high watermark to actual end offset
+                print(
+                    f"Started backup of {topic}:{partition} (offset {start_offset:,} to {end_offset:,})...",
+                    file=sys.stderr,
+                )
+                while True:
+                    records: Collection[ConsumerRecord] = consumer.poll(poll_timeout_ms).get(topic_partition, [])
+                    if len(records) == 0:
+                        raise StaleConsumerError(topic_partition, start_offset, end_offset, last_offset, poll_timeout)
+                    record: ConsumerRecord
+                    for record in records:
+                        fp.write(serialize(record.key, record.value))
+                        record_count += 1
+                    last_offset = record.offset
+                    if last_offset >= end_offset:
+                        break
+            print(
+                f"Finished backup of {topic}:{partition} to "
+                f"{('stdout' if fp is sys.stdout else self.backup_location)!r} (backed up {record_count:,} records).",
+                file=sys.stderr,
+            )
 
     def encode_key(self, key: Optional[Union[JsonData, str]]) -> Optional[bytes]:
         if key == "null":
@@ -424,12 +466,14 @@ def parse_args():
     parser_export_anonymized_avro_schemas = subparsers.add_parser(
         "export-anonymized-avro-schemas", help="Export anonymized Avro schemas into a file"
     )
-    for p in [parser_get, parser_restore, parser_export_anonymized_avro_schemas]:
+    for p in (parser_get, parser_restore, parser_export_anonymized_avro_schemas):
         p.add_argument("--config", help="Configuration file path", required=True)
         p.add_argument("--location", default="", help="File path for the backup file")
         p.add_argument("--topic", help="Kafka topic name to be used", required=False)
-    for p in [parser_get, parser_export_anonymized_avro_schemas]:
+
+    for p in (parser_get, parser_export_anonymized_avro_schemas):
         p.add_argument("--overwrite", action="store_true", help="Overwrite --location even if it exists.")
+        p.add_argument("--poll-timeout", help=PollTimeout.__doc__, type=PollTimeout)
 
     return parser.parse_args()
 
@@ -443,18 +487,29 @@ def main() -> None:
 
         sb = SchemaBackup(config, args.location, args.topic)
 
-        if args.command == "get":
-            sb.export(serialize_record, overwrite=args.overwrite)
-        elif args.command == "restore":
-            sb.restore_backup()
-        elif args.command == "export-anonymized-avro-schemas":
"export-anonymized-avro-schemas": - sb.export(anonymize_avro_schema_message, overwrite=args.overwrite) - else: - # Only reachable if a new subcommand was added that is not mapped above. There are other ways with argparse - # to handle this, but all rely on the programmer doing exactly the right thing. Only switching to another - # CLI framework would provide the ability to not handle this situation manually while ensuring that it is - # not possible to add a new subcommand without also providing a handler for it. - raise SystemExit(f"Entered unreachable code, unknown command: {args.command!r}") + try: + if args.command == "get": + sb.create(serialize_record, poll_timeout=args.poll_timeout, overwrite=args.overwrite) + elif args.command == "restore": + sb.restore_backup() + elif args.command == "export-anonymized-avro-schemas": + sb.create(anonymize_avro_schema_message, poll_timeout=args.poll_timeout, overwrite=args.overwrite) + else: + # Only reachable if a new subcommand was added that is not mapped above. There are other ways with + # argparse to handle this, but all rely on the programmer doing exactly the right thing. Only switching + # to another CLI framework would provide the ability to not handle this situation manually while + # ensuring that it is not possible to add a new subcommand without also providing a handler for it. + raise SystemExit(f"Entered unreachable code, unknown command: {args.command!r}") + except StaleConsumerError as e: + print( + f"The Kafka consumer did not receive any records for partition {e.partition} of topic {e.topic!r} " + f"within the poll timeout ({e.poll_timeout} seconds) while trying to reach offset {e.end_offset:,} " + f"(start was {e.start_offset:,} and the last seen offset was {e.last_offset:,}).\n" + "\n" + "Try increasing --poll-timeout to give the broker more time.", + file=sys.stderr, + ) + raise SystemExit(1) from e except KeyboardInterrupt as e: # Not an error -- user choice -- and thus should not end up in a Python stacktrace. 
         raise SystemExit(2) from e
diff --git a/requirements.txt b/requirements.txt
index 5e3d0085d..92a6a2f6d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,7 @@
 accept-types==0.4.1
 aiohttp==3.8.3
 aiokafka==0.7.2
+isodate==0.6.1
 jsonschema==3.2.0
 networkx==2.5
 protobuf==3.19.5
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 14476eab8..ab67ffb61 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -404,7 +404,7 @@ async def fixture_registry_async_client(
     request: SubRequest,
     registry_cluster: RegistryDescription,
     loop: asyncio.AbstractEventLoop,  # pylint: disable=unused-argument
-) -> AsyncIterator[Client]:
+) -> Client:
     client = Client(
         server_uri=registry_cluster.endpoint.to_url(),
diff --git a/tests/integration/test_schema_backup.py b/tests/integration/test_schema_backup.py
index 4ba960169..46c6c2e12 100644
--- a/tests/integration/test_schema_backup.py
+++ b/tests/integration/test_schema_backup.py
@@ -4,17 +4,20 @@
 Copyright (c) 2023 Aiven Ltd
 See LICENSE for details
 """
+from datetime import timedelta
 from kafka import KafkaConsumer
+from karapace.backup.errors import StaleConsumerError
 from karapace.client import Client
 from karapace.config import set_config_defaults
 from karapace.kafka_rest_apis import KafkaRestAdminClient
 from karapace.key_format import is_key_in_canonical_format
-from karapace.schema_backup import SchemaBackup, serialize_record
+from karapace.schema_backup import PollTimeout, SchemaBackup, serialize_record
 from karapace.utils import Expiration
 from pathlib import Path
 from tests.integration.utils.cluster import RegistryDescription
 from tests.integration.utils.kafka_server import KafkaServers
 from tests.utils import new_random_name
+from unittest import mock
 
 import json
 import os
@@ -36,7 +39,7 @@ async def insert_data(client: Client) -> str:
 
 
 async def test_backup_get(
-    registry_async_client,
+    registry_async_client: Client,
     kafka_servers: KafkaServers,
     tmp_path: Path,
     registry_cluster: RegistryDescription,
@@ -52,7 +55,7 @@ async def test_backup_get(
         }
     )
     sb = SchemaBackup(config, str(backup_location))
-    sb.export(serialize_record)
+    sb.create(serialize_record)
 
     # The backup file has been created
     assert os.path.exists(backup_location)
@@ -90,7 +93,7 @@ async def test_backup_restore_and_get_non_schema_topic(
     # Get the backup
     backup_location = tmp_path / "non_schemas_topic.log"
     sb = SchemaBackup(config, str(backup_location), topic_option=test_topic_name)
-    sb.export(serialize_record)
+    sb.create(serialize_record)
 
     # The backup file has been created
     assert os.path.exists(backup_location)
@@ -215,3 +218,26 @@ async def test_backup_restore(
     _assert_canonical_key_format(
         bootstrap_servers=kafka_servers.bootstrap_servers, schemas_topic=registry_cluster.schemas_topic
     )
+
+
+async def test_stale_consumer(
+    kafka_servers: KafkaServers,
+    registry_async_client: Client,
+    registry_cluster: RegistryDescription,
+    tmp_path: Path,
+) -> None:
+    await insert_data(registry_async_client)
+    config = set_config_defaults(
+        {"bootstrap_uri": kafka_servers.bootstrap_servers, "topic_name": registry_cluster.schemas_topic}
+    )
+    with pytest.raises(StaleConsumerError) as e:
+        # The proper way to test this would be with quotas by throttling our client to death while using a very short
+        # poll timeout. However, we have no way to set up quotas because all Kafka clients available to us do not
+        # implement the necessary APIs.
+ with mock.patch(f"{KafkaConsumer.__module__}.{KafkaConsumer.__qualname__}._poll_once") as poll_once_mock: + poll_once_mock.return_value = {} + SchemaBackup(config, str(tmp_path / "backup")).create( + serialize_record, + poll_timeout=PollTimeout(timedelta(seconds=1)), + ) + assert str(e.value) == f"{registry_cluster.schemas_topic}:0#0 (0,0) after PT1S" diff --git a/tests/integration/test_schema_backup_avro_export.py b/tests/integration/test_schema_backup_avro_export.py index e4692f983..6290c071d 100644 --- a/tests/integration/test_schema_backup_avro_export.py +++ b/tests/integration/test_schema_backup_avro_export.py @@ -116,7 +116,7 @@ async def test_export_anonymized_avro_schemas( } ) sb = SchemaBackup(config, str(export_location)) - sb.export(anonymize_avro_schema_message) + sb.create(anonymize_avro_schema_message) # The export file has been created assert os.path.exists(export_location) diff --git a/tests/unit/backup/test_consumer.py b/tests/unit/backup/test_consumer.py new file mode 100644 index 000000000..0d8d4ed4f --- /dev/null +++ b/tests/unit/backup/test_consumer.py @@ -0,0 +1,30 @@ +""" +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" +from datetime import timedelta +from karapace.backup.consumer import PollTimeout +from typing import Union + +import pytest + + +class TestPollTimeout: + @pytest.mark.parametrize("it", ("PT0.999S", timedelta(milliseconds=999))) + def test_min_validation(self, it: Union[str, timedelta]) -> None: + with pytest.raises(ValueError) as e: + PollTimeout(it) + assert str(e.value) == "Poll timeout MUST be at least one second, got: PT0.999S" + + # Changing the default is not a breaking change, but the documentation needs to be adjusted! + def test_default(self) -> None: + assert str(PollTimeout.default()) == "PT1M" + + def test__str__(self) -> None: + assert str(PollTimeout.of(seconds=1, milliseconds=500)) == "PT1.5S" + + def test__repr__(self) -> None: + assert repr(PollTimeout.of(seconds=1, milliseconds=500)) == "PollTimeout(value='PT1.5S')" + + def test_to_milliseconds(self) -> None: + assert PollTimeout(timedelta(milliseconds=1000.5)).to_milliseconds() == 1000
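
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the new
poll_timeout parameter might be driven from Python, mirroring the integration tests above. The
bootstrap_uri, topic_name, and backup file path are placeholder values.

    from karapace.backup.consumer import PollTimeout
    from karapace.config import set_config_defaults
    from karapace.schema_backup import SchemaBackup, serialize_record

    # Placeholder configuration values; adjust to the actual cluster and topic.
    config = set_config_defaults({"bootstrap_uri": "localhost:9092", "topic_name": "_schemas"})

    # Allow up to five minutes per poll attempt, e.g. on a slow or throttled cluster.
    # A StaleConsumerError is raised if a poll attempt returns no records before the
    # end offset recorded at the start of the backup has been reached.
    SchemaBackup(config, "schemas.log").create(
        serialize_record,
        poll_timeout=PollTimeout.of(minutes=5),
    )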