Fix Create Backup Consumption Logic
The consumption logic is currently counting how many records it received in a
single batch returned from poll, and when it is empty it concludes that the
backup is successfully finished. However, there are many reasons why a batch
returned by poll is empty, especially with timeouts applied to it. A consequence
of this is that a backup created at $t_1$ may contain more records than a backup
created at a later time $t_2$ (without any external changes to the topic content, e.g.
compaction).

To fix this we have to use offset watermarks, which let us reliably determine
whether consumption is complete. The patch now exposes the poll timeout so that
users can increase it in case they encounter issues, and it uses a longer default
poll timeout to ensure that users are not going to see errors right away
(increased from 1 second to 1 minute).
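
As a rough illustration of the watermark approach, here is a minimal sketch against kafka-python's KafkaConsumer.
It is not the exact code of this patch; the bootstrap server, the `_schemas` topic name, and the consumer settings
are assumptions made for the example:

from kafka import KafkaConsumer
from kafka.structs import TopicPartition

# Assumed setup: a consumer assigned to the single partition of the schemas topic.
consumer = KafkaConsumer(bootstrap_servers="localhost:9092", enable_auto_commit=False, auto_offset_reset="earliest")
topic_partition = TopicPartition("_schemas", 0)
consumer.assign([topic_partition])

start_offset = consumer.beginning_offsets([topic_partition])[topic_partition]
# end_offsets() returns the high watermark, i.e. the offset of the *next* record; the last record is one before it.
end_offset = consumer.end_offsets([topic_partition])[topic_partition] - 1
last_offset = start_offset

while start_offset <= end_offset:  # an empty topic needs no polling at all
    records = consumer.poll(timeout_ms=60_000).get(topic_partition, [])
    if not records:
        raise RuntimeError("consumer made no progress before reaching the end offset")
    for record in records:
        last_offset = record.offset  # here each record would be written to the backup
    if last_offset >= end_offset:
        break  # every record that existed when the backup started has been consumed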
Fleshgrinder committed Feb 19, 2023
1 parent 8624bc6 commit 2d89696
Showing 9 changed files with 285 additions and 49 deletions.
Empty file added karapace/backup/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions karapace/backup/consumer.py
@@ -0,0 +1,46 @@
"""
Copyright (c) 2023 Aiven Ltd
See LICENSE for details
"""
from __future__ import annotations

from datetime import timedelta
from isodate import duration_isoformat, parse_duration

__all__ = ["PollTimeout"]


# TODO @final (Python 3.8+)
class PollTimeout:
"""Specifies how long a single poll attempt may take while consuming the topic.
It may be necessary to adjust this value in case the cluster is slow. The value must be given in ISO8601 duration
format (e.g. `PT1.5S` for 1,500 milliseconds) and must be at least one second. Defaults to one minute.
"""

__slots__ = ("__value",)

def __init__(self, value: str | timedelta) -> None:
self.__value = value if isinstance(value, timedelta) else parse_duration(value)
if self.__value // timedelta(seconds=1) < 1:
raise ValueError(f"Poll timeout MUST be at least one second, got: {self}")

@classmethod
def default(cls) -> PollTimeout:
return cls(timedelta(minutes=1))

@classmethod
def of(cls, minutes: int = 0, seconds: int = 0, milliseconds: int = 0) -> PollTimeout:
"""Convenience function to avoid importing ``timedelta``."""
return PollTimeout(timedelta(minutes=minutes, seconds=seconds, milliseconds=milliseconds))

def __str__(self) -> str:
"""Returns the ISO8601 formatted value of this poll timeout."""
return duration_isoformat(self.__value)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(value='{self}')"

def to_milliseconds(self) -> int:
"""Returns this poll timeout in milliseconds, anything smaller than a milliseconds is ignored (no rounding)."""
return self.__value // timedelta(milliseconds=1)
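
A quick usage sketch of PollTimeout as defined above (the values are chosen purely for illustration):

from karapace.backup.consumer import PollTimeout

PollTimeout.default().to_milliseconds()   # 60000, i.e. one minute
PollTimeout("PT1M30S").to_milliseconds()  # 90000, parsed from the ISO8601 string
PollTimeout.of(minutes=2, seconds=30)     # convenience constructor, repr is PollTimeout(value='PT2M30S')
PollTimeout("PT0.5S")                     # raises ValueError: below the one-second minimum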
78 changes: 78 additions & 0 deletions karapace/backup/errors.py
@@ -0,0 +1,78 @@
"""
Copyright (c) 2023 Aiven Ltd
See LICENSE for details
"""
from kafka.structs import TopicPartition
from karapace.backup.consumer import PollTimeout

__all__ = ["BackupError", "PartitionCountError", "StaleConsumerError"]


class BackupError(Exception):
"""Baseclass for all backup errors."""


class PartitionCountError(BackupError):
pass


class StaleConsumerError(BackupError, RuntimeError):
"""Raised when the backup consumer does not make any progress and has not reached the last record in the topic."""

__slots__ = ("__topic_partition", "__start_offset", "__end_offset", "__last_offset", "__poll_timeout")

def __init__(
self,
topic_partition: TopicPartition,
start_offset: int,
end_offset: int,
current_offset: int,
poll_timeout: PollTimeout,
) -> None:
super().__init__(
f"{topic_partition.topic}:{topic_partition.partition}#{current_offset:,} ({start_offset:,},{end_offset:,})"
f" after {poll_timeout}"
)
self.__topic_partition = topic_partition
self.__start_offset = start_offset
self.__end_offset = end_offset
self.__last_offset = current_offset
self.__poll_timeout = poll_timeout

@property
def topic_partition(self) -> TopicPartition:
"""Gets the topic and partition that went stale during consumption."""
return self.__topic_partition

@property
def topic(self) -> str:
"""Gets the topic that went stale during consumption."""
return self.__topic_partition.topic

@property
def partition(self) -> int:
"""Gets the partition that went stale during consumption."""
return self.__topic_partition.partition

@property
def start_offset(self) -> int:
"""Gets the start offset of the topic and partition as determined at the start of the backup creation."""
return self.__start_offset

@property
def end_offset(self) -> int:
"""Gets the end offset of the topic and partition as determined at the start of the backup creation.
This is the offset of the last written record in the topic and partition, not the high watermark.
"""
return self.__end_offset

@property
def last_offset(self) -> int:
"""Gets the last offset of the topic and partition that was successfully consumed."""
return self.__last_offset

@property
def poll_timeout(self) -> PollTimeout:
"""Gets the poll timeout with which the consumer went stale while waiting for more records."""
return self.__poll_timeout
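
For reference, the message assembled by the constructor above reads like this (the topic, partition, and offsets are
made-up example values):

from kafka.structs import TopicPartition
from karapace.backup.consumer import PollTimeout
from karapace.backup.errors import StaleConsumerError

error = StaleConsumerError(TopicPartition("_schemas", 0), 0, 1233, 1230, PollTimeout.default())
str(error)         # '_schemas:0#1,230 (0,1,233) after PT1M'
error.last_offset  # 1230, the last offset that was successfully consumed before the consumer went stale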
141 changes: 98 additions & 43 deletions karapace/schema_backup.py
@@ -7,10 +7,13 @@
from enum import Enum
from kafka import KafkaConsumer, KafkaProducer
from kafka.admin import KafkaAdminClient
from kafka.consumer.fetcher import ConsumerRecord
from kafka.errors import TopicAlreadyExistsError
from kafka.structs import PartitionMetadata
from kafka.structs import PartitionMetadata, TopicPartition
from karapace import constants
from karapace.anonymize_schemas import anonymize_avro
from karapace.backup.consumer import PollTimeout
from karapace.backup.errors import BackupError, PartitionCountError, StaleConsumerError
from karapace.config import Config, read_config
from karapace.key_format import KeyFormatter
from karapace.schema_reader import new_schema_topic_from_config
@@ -19,7 +22,7 @@
from pathlib import Path
from tempfile import mkstemp
from tenacity import retry, RetryCallState, stop_after_delay, wait_fixed
from typing import AbstractSet, Callable, IO, Optional, TextIO, Tuple, Union
from typing import AbstractSet, Callable, Collection, IO, Optional, TextIO, Tuple, Union

import argparse
import base64
@@ -41,14 +44,6 @@ class BackupVersion(Enum):
V2 = 2


class BackupError(Exception):
"""Backup Error"""


class PartitionCountError(BackupError):
pass


def __before_sleep(description: str) -> Callable[[RetryCallState], None]:
"""Returns a function to print a user-friendly message before going to sleep in retries.
@@ -349,25 +344,72 @@ def _restore_backup_version_2(self, producer: KafkaProducer, fp: IO) -> None:
value = base64.b16decode(hex_value.strip()).decode("utf8") if hex_value != "null" else hex_value
self._handle_restore_message(producer, (key, value))

def export(self, export_func, *, overwrite: Optional[bool] = None) -> None:
with _writer(self.backup_location, overwrite=overwrite) as fp:
with _consumer(self.config, self.topic_name) as consumer:
LOG.info("Starting schema backup read for topic: %r", self.topic_name)

topic_fully_consumed = False

fp.write(BACKUP_VERSION_2_MARKER)
while not topic_fully_consumed:
raw_msg = consumer.poll(timeout_ms=self.timeout_ms, max_records=1000)
topic_fully_consumed = len(raw_msg) == 0

for _, messages in raw_msg.items():
for message in messages:
ser = export_func(key_bytes=message.key, value_bytes=message.value)
if ser:
fp.write(ser)

LOG.info("Schema export written to %r", "stdout" if fp is sys.stdout else self.backup_location)
def create(
self,
serialize: Callable[[Optional[bytes], Optional[bytes]], str],
*,
poll_timeout: Optional[PollTimeout] = None,
overwrite: Optional[bool] = None,
) -> None:
"""Creates a backup of the configured topic.
FIXME the serialize callback is obviously dangerous as part of the public API, since it cannot be guaranteed
that it produces a string that is actually version 2 compatible. We have to introduce a version 3 anyway,
and this public API can be fixed along with its introduction.
:param serialize: callback that encodes the consumer record into the target backup format.
:param poll_timeout: specifies the maximum time to wait for receiving records; if no records are received
within that time and the target offset has not been reached, an exception is raised. Defaults to one minute.
:param overwrite: overwrite the output file if it already exists.
:raises Exception: if consumption fails; the concrete exception types are unknown, see the Kafka client implementation.
:raises FileExistsError: if ``overwrite`` is not ``True`` and the file already exists, or if the parent
directory of the file is not a directory.
:raises OSError: if writing fails or if the file already exists and is not actually a file.
:raises StaleConsumerError: if no records are received within the given ``poll_timeout`` and the target offset
has not been reached yet.
"""
if poll_timeout is None:
poll_timeout = PollTimeout.default()
poll_timeout_ms = poll_timeout.to_milliseconds()
topic = self.topic_name
with _writer(self.backup_location, overwrite=overwrite) as fp, _consumer(self.config, topic) as consumer:
(partition,) = consumer.partitions_for_topic(self.topic_name)
topic_partition = TopicPartition(self.topic_name, partition)
start_offset: int = consumer.beginning_offsets([topic_partition])[topic_partition]
end_offset: int = consumer.end_offsets([topic_partition])[topic_partition]
last_offset = start_offset
record_count = 0

fp.write(BACKUP_VERSION_2_MARKER)
if start_offset < end_offset: # non-empty topic
end_offset -= 1 # high watermark to actual end offset
print(
f"Started backup of {topic}:{partition} (offset {start_offset:,} to {end_offset:,})...",
file=sys.stderr,
)
while True:
records: Collection[ConsumerRecord] = consumer.poll(poll_timeout_ms).get(topic_partition, [])
if len(records) == 0:
raise StaleConsumerError(topic_partition, start_offset, end_offset, last_offset, poll_timeout)
record: ConsumerRecord
for record in records:
fp.write(serialize(record.key, record.value))
record_count += 1
last_offset = record.offset
if last_offset >= end_offset:
break
print(
f"Finished backup of {topic}:{partition} to"
f" {'stdout' if fp is sys.stdout else self.backup_location!r} (backed up {record_count:,} records).",
file=sys.stderr,
)

def encode_key(self, key: Optional[Union[JsonData, str]]) -> Optional[bytes]:
if key == "null":
@@ -424,12 +466,14 @@ def parse_args():
parser_export_anonymized_avro_schemas = subparsers.add_parser(
"export-anonymized-avro-schemas", help="Export anonymized Avro schemas into a file"
)
for p in [parser_get, parser_restore, parser_export_anonymized_avro_schemas]:
for p in (parser_get, parser_restore, parser_export_anonymized_avro_schemas):
p.add_argument("--config", help="Configuration file path", required=True)
p.add_argument("--location", default="", help="File path for the backup file")
p.add_argument("--topic", help="Kafka topic name to be used", required=False)
for p in [parser_get, parser_export_anonymized_avro_schemas]:

for p in (parser_get, parser_export_anonymized_avro_schemas):
p.add_argument("--overwrite", action="store_true", help="Overwrite --location even if it exists.")
p.add_argument("--poll-timeout", help=PollTimeout.__doc__, type=PollTimeout)

return parser.parse_args()

@@ -443,18 +487,29 @@ def main() -> None:

sb = SchemaBackup(config, args.location, args.topic)

if args.command == "get":
sb.export(serialize_record, overwrite=args.overwrite)
elif args.command == "restore":
sb.restore_backup()
elif args.command == "export-anonymized-avro-schemas":
sb.export(anonymize_avro_schema_message, overwrite=args.overwrite)
else:
# Only reachable if a new subcommand was added that is not mapped above. There are other ways with argparse
# to handle this, but all rely on the programmer doing exactly the right thing. Only switching to another
# CLI framework would provide the ability to not handle this situation manually while ensuring that it is
# not possible to add a new subcommand without also providing a handler for it.
raise SystemExit(f"Entered unreachable code, unknown command: {args.command!r}")
try:
if args.command == "get":
sb.create(serialize_record, poll_timeout=args.poll_timeout, overwrite=args.overwrite)
elif args.command == "restore":
sb.restore_backup()
elif args.command == "export-anonymized-avro-schemas":
sb.create(anonymize_avro_schema_message, poll_timeout=args.poll_timeout, overwrite=args.overwrite)
else:
# Only reachable if a new subcommand was added that is not mapped above. There are other ways with
# argparse to handle this, but all rely on the programmer doing exactly the right thing. Only switching
# to another CLI framework would provide the ability to not handle this situation manually while
# ensuring that it is not possible to add a new subcommand without also providing a handler for it.
raise SystemExit(f"Entered unreachable code, unknown command: {args.command!r}")
except StaleConsumerError as e:
print(
f"The Kafka consumer did not receive any records for partition {e.partition} of topic {e.topic!r} "
f"within the poll timeout ({e.poll_timeout} seconds) while trying to reach offset {e.end_offset:,} "
f"(start was {e.start_offset:,} and the last seen offset was {e.last_offset:,}).\n"
"\n"
"Try increasing --poll-timeout to give the broker more time.",
file=sys.stderr,
)
raise SystemExit(1) from e
except KeyboardInterrupt as e:
# Not an error -- user choice -- and thus should not end up in a Python stacktrace.
raise SystemExit(2) from e
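Because --poll-timeout is registered with type=PollTimeout, argparse constructs the value directly from the ISO8601
string and turns an invalid duration into a parse error. A minimal sketch of that pattern (the standalone parser and
the explicit default are illustrative, not the exact wiring of parse_args above):

import argparse

from karapace.backup.consumer import PollTimeout

parser = argparse.ArgumentParser()
parser.add_argument("--poll-timeout", help=PollTimeout.__doc__, type=PollTimeout, default=PollTimeout.default())
args = parser.parse_args(["--poll-timeout", "PT5M"])
print(args.poll_timeout.to_milliseconds())  # 300000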
1 change: 1 addition & 0 deletions requirements.txt
@@ -2,6 +2,7 @@
accept-types==0.4.1
aiohttp==3.8.3
aiokafka==0.7.2
isodate==0.6.1
jsonschema==3.2.0
networkx==2.5
protobuf==3.19.5
2 changes: 1 addition & 1 deletion tests/integration/conftest.py
@@ -404,7 +404,7 @@ async def fixture_registry_async_client(
request: SubRequest,
registry_cluster: RegistryDescription,
loop: asyncio.AbstractEventLoop, # pylint: disable=unused-argument
) -> AsyncIterator[Client]:
) -> Client:

client = Client(
server_uri=registry_cluster.endpoint.to_url(),