Fix Create Backup Consumption Logic #542

Merged: 1 commit, Feb 22, 2023
Empty file added karapace/backup/__init__.py
Empty file.
45 changes: 45 additions & 0 deletions karapace/backup/consumer.py
@@ -0,0 +1,45 @@
"""
Copyright (c) 2023 Aiven Ltd
See LICENSE for details
"""
from __future__ import annotations

from datetime import timedelta
from isodate import duration_isoformat, parse_duration

__all__ = ["PollTimeout"]


class PollTimeout:
"""Specifies how long a single poll attempt may take while consuming the topic.

It may be necessary to adjust this value in case the cluster is slow. The value must be given in ISO8601 duration
format (e.g. `PT1.5S` for 1,500 milliseconds) and must be at least one second. Defaults to one minute.
"""

__slots__ = ("__value",)

def __init__(self, value: str | timedelta) -> None:
self.__value = value if isinstance(value, timedelta) else parse_duration(value)
if self.__value // timedelta(seconds=1) < 1:
raise ValueError(f"Poll timeout MUST be at least one second, got: {self}")

@classmethod
def default(cls) -> PollTimeout:
return cls(timedelta(minutes=1))

@classmethod
def of(cls, minutes: int = 0, seconds: int = 0, milliseconds: int = 0) -> PollTimeout:
"""Convenience function to avoid importing ``timedelta``."""
return PollTimeout(timedelta(minutes=minutes, seconds=seconds, milliseconds=milliseconds))

def __str__(self) -> str:
"""Returns the ISO8601 formatted value of this poll timeout."""
return duration_isoformat(self.__value)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(value='{self}')"

def to_milliseconds(self) -> int:
"""Returns this poll timeout in milliseconds, anything smaller than a milliseconds is ignored (no rounding)."""
return self.__value // timedelta(milliseconds=1)
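
As a quick reference (not part of the diff), a minimal sketch of how the new `PollTimeout` type behaves; the duration values below are illustrative:

```python
from karapace.backup.consumer import PollTimeout

# Construct from an ISO8601 duration string; isodate.parse_duration does the parsing.
timeout = PollTimeout("PT30S")
print(timeout)                    # ISO8601 form, e.g. PT30S
print(timeout.to_milliseconds())  # 30000

# Convenience constructor that avoids importing timedelta at the call site.
timeout = PollTimeout.of(minutes=1, seconds=30)
print(repr(timeout))              # e.g. PollTimeout(value='PT1M30S')

# Anything below one second is rejected at construction time.
try:
    PollTimeout("PT0.5S")
except ValueError as exc:
    print(exc)                    # Poll timeout MUST be at least one second, got: ...
```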
78 changes: 78 additions & 0 deletions karapace/backup/errors.py
@@ -0,0 +1,78 @@
"""
Copyright (c) 2023 Aiven Ltd
See LICENSE for details
"""
from kafka.structs import TopicPartition
from karapace.backup.consumer import PollTimeout

__all__ = ["BackupError", "PartitionCountError", "StaleConsumerError"]


class BackupError(Exception):
"""Baseclass for all backup errors."""


class PartitionCountError(BackupError):
pass


class StaleConsumerError(BackupError, RuntimeError):
"""Raised when the backup consumer does not make any progress and has not reached the last record in the topic."""

__slots__ = ("__topic_partition", "__start_offset", "__end_offset", "__last_offset", "__poll_timeout")

def __init__(
self,
topic_partition: TopicPartition,
start_offset: int,
end_offset: int,
current_offset: int,
poll_timeout: PollTimeout,
) -> None:
super().__init__(
f"{topic_partition.topic}:{topic_partition.partition}#{current_offset:,} ({start_offset:,},{end_offset:,})"
f" after {poll_timeout}"
)
self.__topic_partition = topic_partition
self.__start_offset = start_offset
self.__end_offset = end_offset
self.__last_offset = current_offset
self.__poll_timeout = poll_timeout

@property
def topic_partition(self) -> TopicPartition:
"""Gets the topic and partition that went stale during consumption."""
return self.__topic_partition

@property
def topic(self) -> str:
"""Gets the topic that went stale during consumption."""
return self.__topic_partition.topic

@property
def partition(self) -> int:
"""Gets the partition that went stale during consumption."""
return self.__topic_partition.partition

@property
def start_offset(self) -> int:
"""Gets the start offset of the topic and partition as determined at the start of the backup creation."""
return self.__start_offset

@property
def end_offset(self) -> int:
"""Gets the end offset of the topic and partition as determined at the start of the backup creation.

This is the offset of the last written record in the topic and partition, not the high watermark.
"""
return self.__end_offset

@property
def last_offset(self) -> int:
"""Gets the last offset of the topic and partition that was successfully consumed."""
return self.__last_offset

@property
def poll_timeout(self) -> PollTimeout:
"""Gets the poll timeout with which the consumer went stale while waiting for more records."""
return self.__poll_timeout
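
For illustration only (the topic, partition, and offsets below are made up), this is roughly how a caller can construct and inspect a `StaleConsumerError`:

```python
from kafka.structs import TopicPartition
from karapace.backup.consumer import PollTimeout
from karapace.backup.errors import StaleConsumerError

try:
    raise StaleConsumerError(
        TopicPartition("_schemas", 0),  # hypothetical topic and partition
        start_offset=0,
        end_offset=99,       # offset of the last written record, not the high watermark
        current_offset=41,   # last offset that was successfully consumed
        poll_timeout=PollTimeout.default(),
    )
except StaleConsumerError as exc:
    # The read-only properties carry everything needed for a helpful error report.
    print(exc)  # e.g. _schemas:0#41 (0,99) after PT1M
    print(exc.topic, exc.partition, exc.start_offset, exc.end_offset, exc.last_offset, exc.poll_timeout)
```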
141 changes: 98 additions & 43 deletions karapace/schema_backup.py
@@ -7,10 +7,13 @@
from enum import Enum
from kafka import KafkaConsumer, KafkaProducer
from kafka.admin import KafkaAdminClient
from kafka.consumer.fetcher import ConsumerRecord
from kafka.errors import TopicAlreadyExistsError
from kafka.structs import PartitionMetadata
from kafka.structs import PartitionMetadata, TopicPartition
from karapace import constants
from karapace.anonymize_schemas import anonymize_avro
from karapace.backup.consumer import PollTimeout
from karapace.backup.errors import BackupError, PartitionCountError, StaleConsumerError
from karapace.config import Config, read_config
from karapace.key_format import KeyFormatter
from karapace.schema_reader import new_schema_topic_from_config
@@ -19,7 +22,7 @@
from pathlib import Path
from tempfile import mkstemp
from tenacity import retry, RetryCallState, stop_after_delay, wait_fixed
from typing import AbstractSet, Callable, IO, Optional, TextIO, Tuple, Union
from typing import AbstractSet, Callable, Collection, IO, Optional, TextIO, Tuple, Union

import argparse
import base64
@@ -41,14 +44,6 @@ class BackupVersion(Enum):
V2 = 2


class BackupError(Exception):
"""Backup Error"""


class PartitionCountError(BackupError):
pass


def __before_sleep(description: str) -> Callable[[RetryCallState], None]:
"""Returns a function to print a user-friendly message before going to sleep in retries.

@@ -349,25 +344,72 @@ def _restore_backup_version_2(self, producer: KafkaProducer, fp: IO) -> None:
value = base64.b16decode(hex_value.strip()).decode("utf8") if hex_value != "null" else hex_value
self._handle_restore_message(producer, (key, value))

def export(self, export_func, *, overwrite: Optional[bool] = None) -> None:
with _writer(self.backup_location, overwrite=overwrite) as fp:
with _consumer(self.config, self.topic_name) as consumer:
LOG.info("Starting schema backup read for topic: %r", self.topic_name)

topic_fully_consumed = False

fp.write(BACKUP_VERSION_2_MARKER)
while not topic_fully_consumed:
raw_msg = consumer.poll(timeout_ms=self.timeout_ms, max_records=1000)
topic_fully_consumed = len(raw_msg) == 0

for _, messages in raw_msg.items():
for message in messages:
ser = export_func(key_bytes=message.key, value_bytes=message.value)
if ser:
fp.write(ser)

LOG.info("Schema export written to %r", "stdout" if fp is sys.stdout else self.backup_location)
def create(
self,
serialize: Callable[[Optional[bytes], Optional[bytes]], str],
*,
poll_timeout: Optional[PollTimeout] = None,
overwrite: Optional[bool] = None,
) -> None:
"""Creates a backup of the configured topic.

FIXME the serialize callback is obviously dangerous as part of the public API, since it cannot be guaranteed
that it produces a string that is actually version 2 compatible. We have to introduce a version 3 anyway,
and this public API can be fixed along with its introduction.

:param serialize: callback that encodes the consumer record into the target backup format.
:param poll_timeout: specifies the maximum time to wait for receiving records; if no records are received
within that time and the target offset has not been reached, an exception is raised. Defaults to one minute.
:param overwrite: whether to overwrite the output file if it already exists.
:raises Exception: if consumption fails, concrete exception types are unknown, see Kafka implementation.
:raises FileExistsError: if ``overwrite`` is not ``True`` and the file already exists, or if the parent
directory of the file is not a directory.
:raises OSError: if writing fails or if the file already exists and is not actually a file.
:raises StaleConsumerError: if no records are received within the given ``poll_timeout`` and the target offset
has not been reached yet.
"""
if poll_timeout is None:
poll_timeout = PollTimeout.default()
poll_timeout_ms = poll_timeout.to_milliseconds()
topic = self.topic_name
with _writer(self.backup_location, overwrite=overwrite) as fp, _consumer(self.config, topic) as consumer:
(partition,) = consumer.partitions_for_topic(self.topic_name)
topic_partition = TopicPartition(self.topic_name, partition)
start_offset: int = consumer.beginning_offsets([topic_partition])[topic_partition]
end_offset: int = consumer.end_offsets([topic_partition])[topic_partition]
last_offset = start_offset
record_count = 0

fp.write(BACKUP_VERSION_2_MARKER)
if start_offset < end_offset: # non-empty topic
end_offset -= 1 # high watermark to actual end offset
print(
f"Started backup of {topic}:{partition} (offset {start_offset:,} to {end_offset:,})...",
file=sys.stderr,
)
while True:
records: Collection[ConsumerRecord] = consumer.poll(poll_timeout_ms).get(topic_partition, [])
if len(records) == 0:
raise StaleConsumerError(topic_partition, start_offset, end_offset, last_offset, poll_timeout)
record: ConsumerRecord
for record in records:
fp.write(serialize(record.key, record.value))
record_count += 1
last_offset = record.offset
if last_offset >= end_offset:
break
Comment on lines +402 to +404

Contributor (reviewer):
I'd compare the record offset to the end offset inside the loop and break immediately when it is reached. As written, this could back up more records than were decided on as the end offset.

Contributor (author):
Which is intentional: backing up more than the minimum is totally fine. Using a plain while loop would mean performing the same check twice in a row, whereas the do-while shape performs only as many checks as necessary, while allowing us to back up as much as possible, but at least what was in the topic when we started.
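
To make the trade-off concrete, here is a self-contained toy sketch (fake records and batches rather than a real Kafka consumer) of the two loop shapes being discussed:

```python
from dataclasses import dataclass

@dataclass
class FakeRecord:
    offset: int

# Two fake poll results; the target end_offset is 4, but the second batch runs past it.
batches = [[FakeRecord(0), FakeRecord(1)], [FakeRecord(o) for o in range(2, 7)]]
end_offset = 4

# Merged shape ("do-while"): write the whole batch, then check the last offset once per batch.
# Records 5 and 6 are written too, which the author considers acceptable.
written = []
for records in batches:
    for record in records:
        written.append(record.offset)
        last_offset = record.offset
    if last_offset >= end_offset:
        break
print(written)  # [0, 1, 2, 3, 4, 5, 6]

# Suggested alternative: compare every record and stop exactly at end_offset.
written = []
reached_end = False
for records in batches:
    for record in records:
        written.append(record.offset)
        if record.offset >= end_offset:
            reached_end = True
            break
    if reached_end:
        break
print(written)  # [0, 1, 2, 3, 4]
```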

print(
f"Finished backup of {topic}:{partition} to {('stdout' if fp is sys.stdout else self.backup_location)!r} "
f"(backed up {record_count:,} records).",
file=sys.stderr,
)

def encode_key(self, key: Optional[Union[JsonData, str]]) -> Optional[bytes]:
if key == "null":
@@ -424,12 +466,14 @@ def parse_args():
parser_export_anonymized_avro_schemas = subparsers.add_parser(
"export-anonymized-avro-schemas", help="Export anonymized Avro schemas into a file"
)
for p in [parser_get, parser_restore, parser_export_anonymized_avro_schemas]:
for p in (parser_get, parser_restore, parser_export_anonymized_avro_schemas):
p.add_argument("--config", help="Configuration file path", required=True)
p.add_argument("--location", default="", help="File path for the backup file")
p.add_argument("--topic", help="Kafka topic name to be used", required=False)
for p in [parser_get, parser_export_anonymized_avro_schemas]:

for p in (parser_get, parser_export_anonymized_avro_schemas):
p.add_argument("--overwrite", action="store_true", help="Overwrite --location even if it exists.")
p.add_argument("--poll-timeout", help=PollTimeout.__doc__, type=PollTimeout)

return parser.parse_args()

@@ -443,18 +487,29 @@ def main() -> None:

sb = SchemaBackup(config, args.location, args.topic)

if args.command == "get":
sb.export(serialize_record, overwrite=args.overwrite)
elif args.command == "restore":
sb.restore_backup()
elif args.command == "export-anonymized-avro-schemas":
sb.export(anonymize_avro_schema_message, overwrite=args.overwrite)
else:
# Only reachable if a new subcommand was added that is not mapped above. There are other ways with argparse
# to handle this, but all rely on the programmer doing exactly the right thing. Only switching to another
# CLI framework would provide the ability to not handle this situation manually while ensuring that it is
# not possible to add a new subcommand without also providing a handler for it.
raise SystemExit(f"Entered unreachable code, unknown command: {args.command!r}")
try:
if args.command == "get":
sb.create(serialize_record, poll_timeout=args.poll_timeout, overwrite=args.overwrite)
elif args.command == "restore":
sb.restore_backup()
elif args.command == "export-anonymized-avro-schemas":
sb.create(anonymize_avro_schema_message, poll_timeout=args.poll_timeout, overwrite=args.overwrite)
else:
# Only reachable if a new subcommand was added that is not mapped above. There are other ways with
# argparse to handle this, but all rely on the programmer doing exactly the right thing. Only switching
# to another CLI framework would provide the ability to not handle this situation manually while
# ensuring that it is not possible to add a new subcommand without also providing a handler for it.
raise SystemExit(f"Entered unreachable code, unknown command: {args.command!r}")
except StaleConsumerError as e:
print(
f"The Kafka consumer did not receive any records for partition {e.partition} of topic {e.topic!r} "
f"within the poll timeout ({e.poll_timeout} seconds) while trying to reach offset {e.end_offset:,} "
f"(start was {e.start_offset:,} and the last seen offset was {e.last_offset:,}).\n"
"\n"
"Try increasing --poll-timeout to give the broker more time.",
file=sys.stderr,
)
raise SystemExit(1) from e
except KeyboardInterrupt as e:
# Not an error -- user choice -- and thus should not end up in a Python stacktrace.
raise SystemExit(2) from e
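A small standalone sketch (not the actual CLI wiring) of what `type=PollTimeout` buys on the command line: argparse calls `PollTimeout(...)` on the raw string, so the flag accepts ISO8601 durations directly and rejects invalid ones with a usage error:

```python
import argparse

from karapace.backup.consumer import PollTimeout

parser = argparse.ArgumentParser()
# The real parser leaves the default as None and lets create() fall back to PollTimeout.default().
parser.add_argument("--poll-timeout", type=PollTimeout, default=PollTimeout.default())

args = parser.parse_args(["--poll-timeout", "PT2M"])
print(args.poll_timeout)                    # PT2M
print(args.poll_timeout.to_milliseconds())  # 120000

# A value below one second (e.g. "PT0.1S") makes PollTimeout raise ValueError,
# which argparse reports as an "invalid PollTimeout value" usage error.
```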
1 change: 1 addition & 0 deletions requirements.txt
@@ -2,6 +2,7 @@
accept-types==0.4.1
aiohttp==3.8.3
aiokafka==0.7.2
isodate==0.6.1
jsonschema==3.2.0
networkx==2.5
protobuf==3.19.5
2 changes: 1 addition & 1 deletion tests/integration/conftest.py
@@ -404,7 +404,7 @@ async def fixture_registry_async_client(
request: SubRequest,
registry_cluster: RegistryDescription,
loop: asyncio.AbstractEventLoop, # pylint: disable=unused-argument
) -> AsyncIterator[Client]:
) -> Client:

client = Client(
server_uri=registry_cluster.endpoint.to_url(),
Expand Down