From 8cd9b55fafb1a1ecbebda3ca426da6add993e191 Mon Sep 17 00:00:00 2001 From: William Barnhart Date: Mon, 25 Mar 2024 22:43:47 -0400 Subject: [PATCH 1/2] Add typing --- kafka/coordinator/assignors/abstract.py | 2 +- .../assignors/sticky/sticky_assignor.py | 1 - kafka/errors.py | 9 +- kafka/protocol/api.py | 14 +- kafka/protocol/struct.py | 13 +- kafka/record/_crc32c.py | 6 +- kafka/record/abc.py | 14 +- kafka/record/default_records.py | 87 ++-- kafka/record/legacy_records.py | 70 +-- kafka/record/memory_records.py | 35 +- kafka/record/util.py | 12 +- kafka/sasl/msk.py | 461 +++++++++--------- kafka/util.py | 11 +- 13 files changed, 372 insertions(+), 363 deletions(-) diff --git a/kafka/coordinator/assignors/abstract.py b/kafka/coordinator/assignors/abstract.py index a1fef3840..7c38907ef 100644 --- a/kafka/coordinator/assignors/abstract.py +++ b/kafka/coordinator/assignors/abstract.py @@ -12,7 +12,7 @@ class AbstractPartitionAssignor(object): partition counts which are always needed in assignors). """ - @abc.abstractproperty + @abc.abstractmethod def name(self): """.name should be a string identifying the assignor""" pass diff --git a/kafka/coordinator/assignors/sticky/sticky_assignor.py b/kafka/coordinator/assignors/sticky/sticky_assignor.py index 033642425..e75dc2561 100644 --- a/kafka/coordinator/assignors/sticky/sticky_assignor.py +++ b/kafka/coordinator/assignors/sticky/sticky_assignor.py @@ -2,7 +2,6 @@ from collections import defaultdict, namedtuple from copy import deepcopy -from kafka.cluster import ClusterMetadata from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor from kafka.coordinator.assignors.sticky.partition_movements import PartitionMovements from kafka.coordinator.assignors.sticky.sorted_set import SortedSet diff --git a/kafka/errors.py b/kafka/errors.py index cb3ff285f..d2f313c08 100644 --- a/kafka/errors.py +++ b/kafka/errors.py @@ -1,5 +1,6 @@ import inspect import sys +from typing import Any class KafkaError(RuntimeError): @@ -7,7 +8,7 @@ class KafkaError(RuntimeError): # whether metadata should be refreshed on error invalid_metadata = False - def __str__(self): + def __str__(self) -> str: if not self.args: return self.__class__.__name__ return '{}: {}'.format(self.__class__.__name__, @@ -65,7 +66,7 @@ class IncompatibleBrokerVersion(KafkaError): class CommitFailedError(KafkaError): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__( """Commit cannot be completed since the group has already rebalanced and assigned the partitions to another member. 
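Aside on the abstract.py hunk above: it swaps the deprecated abc.abstractproperty for a bare abc.abstractmethod. Where property semantics are still wanted, the documented modern equivalent stacks @property with @abc.abstractmethod; a minimal sketch of that pattern (illustrative only — the class below is trimmed to the name attribute and is not part of this patch):

    import abc

    class AbstractPartitionAssignorSketch(abc.ABC):
        # Stacked @property/@abc.abstractmethod is the documented replacement
        # for the deprecated abc.abstractproperty.
        @property
        @abc.abstractmethod
        def name(self) -> str:
            """.name should be a string identifying the assignor"""
            raise NotImplementedError

    class ExampleAssignor(AbstractPartitionAssignorSketch):
        # Subclasses may satisfy the abstract name with a regular property
        # (or, as kafka-python's assignors do, a plain class attribute).
        @property
        def name(self) -> str:
            return "example"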
@@ -92,7 +93,7 @@ class BrokerResponseError(KafkaError): message = None description = None - def __str__(self): + def __str__(self) -> str: """Add errno to standard KafkaError str""" return '[Error {}] {}'.format( self.errno, @@ -509,7 +510,7 @@ def _iter_broker_errors(): kafka_errors = {x.errno: x for x in _iter_broker_errors()} -def for_code(error_code): +def for_code(error_code: int) -> Any: return kafka_errors.get(error_code, UnknownError) diff --git a/kafka/protocol/api.py b/kafka/protocol/api.py index 24cf61a62..6d6c6edca 100644 --- a/kafka/protocol/api.py +++ b/kafka/protocol/api.py @@ -52,22 +52,22 @@ class Request(Struct): FLEXIBLE_VERSION = False - @abc.abstractproperty + @abc.abstractmethod def API_KEY(self): """Integer identifier for api request""" pass - @abc.abstractproperty + @abc.abstractmethod def API_VERSION(self): """Integer of api request version""" pass - @abc.abstractproperty + @abc.abstractmethod def SCHEMA(self): """An instance of Schema() representing the request structure""" pass - @abc.abstractproperty + @abc.abstractmethod def RESPONSE_TYPE(self): """The Response class associated with the api request""" pass @@ -93,17 +93,17 @@ def parse_response_header(self, read_buffer): class Response(Struct): __metaclass__ = abc.ABCMeta - @abc.abstractproperty + @abc.abstractmethod def API_KEY(self): """Integer identifier for api request/response""" pass - @abc.abstractproperty + @abc.abstractmethod def API_VERSION(self): """Integer of api request/response version""" pass - @abc.abstractproperty + @abc.abstractmethod def SCHEMA(self): """An instance of Schema() representing the response structure""" pass diff --git a/kafka/protocol/struct.py b/kafka/protocol/struct.py index eb08ac8ef..05189ed4d 100644 --- a/kafka/protocol/struct.py +++ b/kafka/protocol/struct.py @@ -1,4 +1,5 @@ from io import BytesIO +from typing import List, Union from kafka.protocol.abstract import AbstractType from kafka.protocol.types import Schema @@ -9,7 +10,7 @@ class Struct(AbstractType): SCHEMA = Schema() - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: if len(args) == len(self.SCHEMA.fields): for i, name in enumerate(self.SCHEMA.names): self.__dict__[name] = args[i] @@ -36,23 +37,23 @@ def encode(cls, item): # pylint: disable=E0202 bits.append(field.encode(item[i])) return b''.join(bits) - def _encode_self(self): + def _encode_self(self) -> bytes: return self.SCHEMA.encode( [self.__dict__[name] for name in self.SCHEMA.names] ) @classmethod - def decode(cls, data): + def decode(cls, data: Union[BytesIO, bytes]) -> Union['ConsumerProtocolMemberAssignment', 'ConsumerProtocolMemberMetadata', 'FetchResponse_v0', 'StickyAssignorUserDataV1']: if isinstance(data, bytes): data = BytesIO(data) return cls(*[field.decode(data) for field in cls.SCHEMA.fields]) - def get_item(self, name): + def get_item(self, name: str) -> Union[int, List[List[Union[int, str, bool, List[List[Union[int, List[int]]]]]]], str, List[List[Union[int, str]]]]: if name not in self.SCHEMA.names: raise KeyError("%s is not in the schema" % name) return self.__dict__[name] - def __repr__(self): + def __repr__(self) -> str: key_vals = [] for name, field in zip(self.SCHEMA.names, self.SCHEMA.fields): key_vals.append(f'{name}={field.repr(self.__dict__[name])}') @@ -61,7 +62,7 @@ def __repr__(self): def __hash__(self): return hash(self.encode()) - def __eq__(self, other): + def __eq__(self, other: Union['ConsumerProtocolMemberAssignment', 'ConsumerProtocolMemberMetadata', 'MetadataRequest_v0', 'Message']) -> 
bool: if self.SCHEMA != other.SCHEMA: return False for attr in self.SCHEMA.names: diff --git a/kafka/record/_crc32c.py b/kafka/record/_crc32c.py index 6642b5bbe..f7743044c 100644 --- a/kafka/record/_crc32c.py +++ b/kafka/record/_crc32c.py @@ -97,7 +97,7 @@ _MASK = 0xFFFFFFFF -def crc_update(crc, data): +def crc_update(crc: int, data: bytes) -> int: """Update CRC-32C checksum with data. Args: crc: 32-bit checksum to update as long. @@ -116,7 +116,7 @@ def crc_update(crc, data): return crc ^ _MASK -def crc_finalize(crc): +def crc_finalize(crc: int) -> int: """Finalize CRC-32C checksum. This function should be called as last step of crc calculation. Args: @@ -127,7 +127,7 @@ def crc_finalize(crc): return crc & _MASK -def crc(data): +def crc(data: bytes) -> int: """Compute CRC-32C checksum of the data. Args: data: byte array, string or iterable over bytes. diff --git a/kafka/record/abc.py b/kafka/record/abc.py index f45176051..4ce5144d9 100644 --- a/kafka/record/abc.py +++ b/kafka/record/abc.py @@ -5,38 +5,38 @@ class ABCRecord: __metaclass__ = abc.ABCMeta __slots__ = () - @abc.abstractproperty + @abc.abstractmethod def offset(self): """ Absolute offset of record """ - @abc.abstractproperty + @abc.abstractmethod def timestamp(self): """ Epoch milliseconds """ - @abc.abstractproperty + @abc.abstractmethod def timestamp_type(self): """ CREATE_TIME(0) or APPEND_TIME(1) """ - @abc.abstractproperty + @abc.abstractmethod def key(self): """ Bytes key or None """ - @abc.abstractproperty + @abc.abstractmethod def value(self): """ Bytes value or None """ - @abc.abstractproperty + @abc.abstractmethod def checksum(self): """ Prior to v2 format CRC was contained in every message. This will be the checksum for v0 and v1 and None for v2 and above. """ - @abc.abstractproperty + @abc.abstractmethod def headers(self): """ If supported by version list of key-value tuples, or empty list if not supported by format. 
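Note on the struct.py hunks above: the broad Union annotations on Struct.decode and Struct.__eq__ introduced in this first commit are narrowed to plain "Struct" by PATCH 2/2 further down. If preserving the concrete subclass type on decode matters to callers under a type checker, a TypeVar-bound classmethod is one option; a simplified, standalone sketch of that shape (not the library's actual Schema machinery):

    from io import BytesIO
    from typing import Type, TypeVar, Union

    T = TypeVar("T", bound="Struct")

    class Struct:
        SCHEMA = None  # kafka.protocol.types.Schema() in the real class

        @classmethod
        def decode(cls: Type[T], data: Union[BytesIO, bytes]) -> T:
            # Same body as in the patch; only the annotation differs, so
            # e.g. FetchResponse_v0.decode(buf) keeps its concrete type.
            if isinstance(data, bytes):
                data = BytesIO(data)
            return cls(*[field.decode(data) for field in cls.SCHEMA.fields])

        def __eq__(self, other: object) -> bool:
            # Typing-wise, object plus an isinstance check is the conventional
            # signature for __eq__; annotating other as "Struct" (as PATCH 2/2
            # does) also works in practice.
            if not isinstance(other, Struct) or self.SCHEMA != other.SCHEMA:
                return False
            return all(getattr(self, a) == getattr(other, a)
                       for a in self.SCHEMA.names)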
diff --git a/kafka/record/default_records.py b/kafka/record/default_records.py index 5045f31ee..91eb5c8a0 100644 --- a/kafka/record/default_records.py +++ b/kafka/record/default_records.py @@ -66,6 +66,7 @@ gzip_decode, snappy_decode, lz4_decode, zstd_decode ) import kafka.codec as codecs +from typing import Any, Callable, List, Optional, Tuple, Type, Union class DefaultRecordBase: @@ -105,7 +106,7 @@ class DefaultRecordBase: LOG_APPEND_TIME = 1 CREATE_TIME = 0 - def _assert_has_codec(self, compression_type): + def _assert_has_codec(self, compression_type: int) -> None: if compression_type == self.CODEC_GZIP: checker, name = codecs.has_gzip, "gzip" elif compression_type == self.CODEC_SNAPPY: @@ -124,7 +125,7 @@ class DefaultRecordBatch(DefaultRecordBase, ABCRecordBatch): __slots__ = ("_buffer", "_header_data", "_pos", "_num_records", "_next_record_index", "_decompressed") - def __init__(self, buffer): + def __init__(self, buffer: Union[memoryview, bytes]) -> None: self._buffer = bytearray(buffer) self._header_data = self.HEADER_STRUCT.unpack_from(self._buffer) self._pos = self.HEADER_STRUCT.size @@ -133,11 +134,11 @@ def __init__(self, buffer): self._decompressed = False @property - def base_offset(self): + def base_offset(self) -> int: return self._header_data[0] @property - def magic(self): + def magic(self) -> int: return self._header_data[3] @property @@ -145,7 +146,7 @@ def crc(self): return self._header_data[4] @property - def attributes(self): + def attributes(self) -> int: return self._header_data[5] @property @@ -153,15 +154,15 @@ def last_offset_delta(self): return self._header_data[6] @property - def compression_type(self): + def compression_type(self) -> int: return self.attributes & self.CODEC_MASK @property - def timestamp_type(self): + def timestamp_type(self) -> int: return int(bool(self.attributes & self.TIMESTAMP_TYPE_MASK)) @property - def is_transactional(self): + def is_transactional(self) -> bool: return bool(self.attributes & self.TRANSACTIONAL_MASK) @property @@ -169,14 +170,14 @@ def is_control_batch(self): return bool(self.attributes & self.CONTROL_MASK) @property - def first_timestamp(self): + def first_timestamp(self) -> int: return self._header_data[7] @property def max_timestamp(self): return self._header_data[8] - def _maybe_uncompress(self): + def _maybe_uncompress(self) -> None: if not self._decompressed: compression_type = self.compression_type if compression_type != self.CODEC_NONE: @@ -196,7 +197,7 @@ def _maybe_uncompress(self): def _read_msg( self, - decode_varint=decode_varint): + decode_varint: Callable=decode_varint) -> "DefaultRecord": # Record => # Length => Varint # Attributes => Int8 @@ -272,11 +273,11 @@ def _read_msg( return DefaultRecord( offset, timestamp, self.timestamp_type, key, value, headers) - def __iter__(self): + def __iter__(self) -> "DefaultRecordBatch": self._maybe_uncompress() return self - def __next__(self): + def __next__(self) -> "DefaultRecord": if self._next_record_index >= self._num_records: if self._pos != len(self._buffer): raise CorruptRecordException( @@ -309,7 +310,7 @@ class DefaultRecord(ABCRecord): __slots__ = ("_offset", "_timestamp", "_timestamp_type", "_key", "_value", "_headers") - def __init__(self, offset, timestamp, timestamp_type, key, value, headers): + def __init__(self, offset: int, timestamp: int, timestamp_type: int, key: Optional[bytes], value: bytes, headers: List[Union[Tuple[str, bytes], Any]]) -> None: self._offset = offset self._timestamp = timestamp self._timestamp_type = timestamp_type @@ 
-318,39 +319,39 @@ def __init__(self, offset, timestamp, timestamp_type, key, value, headers): self._headers = headers @property - def offset(self): + def offset(self) -> int: return self._offset @property - def timestamp(self): + def timestamp(self) -> int: """ Epoch milliseconds """ return self._timestamp @property - def timestamp_type(self): + def timestamp_type(self) -> int: """ CREATE_TIME(0) or APPEND_TIME(1) """ return self._timestamp_type @property - def key(self): + def key(self) -> Optional[bytes]: """ Bytes key or None """ return self._key @property - def value(self): + def value(self) -> bytes: """ Bytes value or None """ return self._value @property - def headers(self): + def headers(self) -> List[Union[Tuple[str, bytes], Any]]: return self._headers @property - def checksum(self): + def checksum(self) -> None: return None def __repr__(self): @@ -374,8 +375,8 @@ class DefaultRecordBatchBuilder(DefaultRecordBase, ABCRecordBatchBuilder): "_buffer") def __init__( - self, magic, compression_type, is_transactional, - producer_id, producer_epoch, base_sequence, batch_size): + self, magic: int, compression_type: int, is_transactional: Union[int, bool], + producer_id: int, producer_epoch: int, base_sequence: int, batch_size: int) -> None: assert magic >= 2 self._magic = magic self._compression_type = compression_type & self.CODEC_MASK @@ -393,7 +394,7 @@ def __init__( self._buffer = bytearray(self.HEADER_STRUCT.size) - def _get_attributes(self, include_compression_type=True): + def _get_attributes(self, include_compression_type: bool=True) -> int: attrs = 0 if include_compression_type: attrs |= self._compression_type @@ -403,13 +404,13 @@ def _get_attributes(self, include_compression_type=True): # Control batches are only created by Broker return attrs - def append(self, offset, timestamp, key, value, headers, + def append(self, offset: Union[int, str], timestamp: Optional[Union[int, str]], key: Optional[Union[str, bytes]], value: Optional[Union[str, bytes]], headers: List[Union[Tuple[str, bytes], Any, Tuple[str, None]]], # Cache for LOAD_FAST opcodes - encode_varint=encode_varint, size_of_varint=size_of_varint, - get_type=type, type_int=int, time_time=time.time, - byte_like=(bytes, bytearray, memoryview), - bytearray_type=bytearray, len_func=len, zero_len_varint=1 - ): + encode_varint: Callable=encode_varint, size_of_varint: Callable=size_of_varint, + get_type: Type[type]=type, type_int: Type[int]=int, time_time: Callable=time.time, + byte_like: Tuple[Type[bytes], Type[bytearray], Type[memoryview]]=(bytes, bytearray, memoryview), + bytearray_type: Type[bytearray]=bytearray, len_func: Callable=len, zero_len_varint: int=1 + ) -> Optional['DefaultRecordMetadata']: """ Write message to messageset buffer with MsgVersion 2 """ # Check types @@ -490,7 +491,7 @@ def append(self, offset, timestamp, key, value, headers, return DefaultRecordMetadata(offset, required_size, timestamp) - def write_header(self, use_compression_type=True): + def write_header(self, use_compression_type: bool=True) -> None: batch_len = len(self._buffer) self.HEADER_STRUCT.pack_into( self._buffer, 0, @@ -511,7 +512,7 @@ def write_header(self, use_compression_type=True): crc = calc_crc32c(self._buffer[self.ATTRIBUTES_OFFSET:]) struct.pack_into(">I", self._buffer, self.CRC_OFFSET, crc) - def _maybe_compress(self): + def _maybe_compress(self) -> bool: if self._compression_type != self.CODEC_NONE: self._assert_has_codec(self._compression_type) header_size = self.HEADER_STRUCT.size @@ -537,17 +538,17 @@ def 
_maybe_compress(self): return True return False - def build(self): + def build(self) -> bytearray: send_compressed = self._maybe_compress() self.write_header(send_compressed) return self._buffer - def size(self): + def size(self) -> int: """ Return current size of data written to buffer """ return len(self._buffer) - def size_in_bytes(self, offset, timestamp, key, value, headers): + def size_in_bytes(self, offset: int, timestamp: int, key: bytes, value: bytes, headers: List[Union[Tuple[str, bytes], Tuple[str, None]]]) -> int: if self._first_timestamp is not None: timestamp_delta = timestamp - self._first_timestamp else: @@ -561,7 +562,7 @@ def size_in_bytes(self, offset, timestamp, key, value, headers): return size_of_body + size_of_varint(size_of_body) @classmethod - def size_of(cls, key, value, headers): + def size_of(cls, key: bytes, value: bytes, headers: List[Union[Tuple[str, bytes], Tuple[str, None]]]) -> int: size = 0 # Key size if key is None: @@ -589,7 +590,7 @@ def size_of(cls, key, value, headers): return size @classmethod - def estimate_size_in_bytes(cls, key, value, headers): + def estimate_size_in_bytes(cls, key: bytes, value: bytes, headers: List[Tuple[str, bytes]]) -> int: """ Get the upper bound estimate on the size of record """ return ( @@ -602,28 +603,28 @@ class DefaultRecordMetadata: __slots__ = ("_size", "_timestamp", "_offset") - def __init__(self, offset, size, timestamp): + def __init__(self, offset: int, size: int, timestamp: int) -> None: self._offset = offset self._size = size self._timestamp = timestamp @property - def offset(self): + def offset(self) -> int: return self._offset @property - def crc(self): + def crc(self) -> None: return None @property - def size(self): + def size(self) -> int: return self._size @property - def timestamp(self): + def timestamp(self) -> int: return self._timestamp - def __repr__(self): + def __repr__(self) -> str: return ( "DefaultRecordMetadata(offset={!r}, size={!r}, timestamp={!r})" .format(self._offset, self._size, self._timestamp) diff --git a/kafka/record/legacy_records.py b/kafka/record/legacy_records.py index 9ab8873ca..b77799f4d 100644 --- a/kafka/record/legacy_records.py +++ b/kafka/record/legacy_records.py @@ -44,6 +44,7 @@ import struct import time + from kafka.record.abc import ABCRecord, ABCRecordBatch, ABCRecordBatchBuilder from kafka.record.util import calc_crc32 @@ -53,6 +54,7 @@ ) import kafka.codec as codecs from kafka.errors import CorruptRecordException, UnsupportedCodecError +from typing import Any, Iterator, List, Optional, Tuple, Union class LegacyRecordBase: @@ -115,7 +117,7 @@ class LegacyRecordBase: NO_TIMESTAMP = -1 - def _assert_has_codec(self, compression_type): + def _assert_has_codec(self, compression_type: int) -> None: if compression_type == self.CODEC_GZIP: checker, name = codecs.has_gzip, "gzip" elif compression_type == self.CODEC_SNAPPY: @@ -132,7 +134,7 @@ class LegacyRecordBatch(ABCRecordBatch, LegacyRecordBase): __slots__ = ("_buffer", "_magic", "_offset", "_crc", "_timestamp", "_attributes", "_decompressed") - def __init__(self, buffer, magic): + def __init__(self, buffer: Union[memoryview, bytes], magic: int) -> None: self._buffer = memoryview(buffer) self._magic = magic @@ -147,7 +149,7 @@ def __init__(self, buffer, magic): self._decompressed = False @property - def timestamp_type(self): + def timestamp_type(self) -> Optional[int]: """0 for CreateTime; 1 for LogAppendTime; None if unsupported. 
Value is determined by broker; produced messages should always set to 0 @@ -161,14 +163,14 @@ def timestamp_type(self): return 0 @property - def compression_type(self): + def compression_type(self) -> int: return self._attributes & self.CODEC_MASK def validate_crc(self): crc = calc_crc32(self._buffer[self.MAGIC_OFFSET:]) return self._crc == crc - def _decompress(self, key_offset): + def _decompress(self, key_offset: int) -> bytes: # Copy of `_read_key_value`, but uses memoryview pos = key_offset key_size = struct.unpack_from(">i", self._buffer, pos)[0] @@ -195,7 +197,7 @@ def _decompress(self, key_offset): uncompressed = lz4_decode(data.tobytes()) return uncompressed - def _read_header(self, pos): + def _read_header(self, pos: int) -> Union[Tuple[int, int, int, int, int, None], Tuple[int, int, int, int, int, int]]: if self._magic == 0: offset, length, crc, magic_read, attrs = \ self.HEADER_STRUCT_V0.unpack_from(self._buffer, pos) @@ -205,7 +207,7 @@ def _read_header(self, pos): self.HEADER_STRUCT_V1.unpack_from(self._buffer, pos) return offset, length, crc, magic_read, attrs, timestamp - def _read_all_headers(self): + def _read_all_headers(self) -> List[Union[Tuple[Tuple[int, int, int, int, int, int], int], Tuple[Tuple[int, int, int, int, int, None], int]]]: pos = 0 msgs = [] buffer_len = len(self._buffer) @@ -215,7 +217,7 @@ def _read_all_headers(self): pos += self.LOG_OVERHEAD + header[1] # length return msgs - def _read_key_value(self, pos): + def _read_key_value(self, pos: int) -> Union[Tuple[None, bytes], Tuple[bytes, bytes]]: key_size = struct.unpack_from(">i", self._buffer, pos)[0] pos += self.KEY_LENGTH if key_size == -1: @@ -232,7 +234,7 @@ def _read_key_value(self, pos): value = self._buffer[pos:pos + value_size].tobytes() return key, value - def __iter__(self): + def __iter__(self) -> Iterator[LegacyRecordBase]: if self._magic == 1: key_offset = self.KEY_OFFSET_V1 else: @@ -286,7 +288,7 @@ class LegacyRecord(ABCRecord): __slots__ = ("_offset", "_timestamp", "_timestamp_type", "_key", "_value", "_crc") - def __init__(self, offset, timestamp, timestamp_type, key, value, crc): + def __init__(self, offset: int, timestamp: Optional[int], timestamp_type: Optional[int], key: Optional[bytes], value: bytes, crc: int) -> None: self._offset = offset self._timestamp = timestamp self._timestamp_type = timestamp_type @@ -295,39 +297,39 @@ def __init__(self, offset, timestamp, timestamp_type, key, value, crc): self._crc = crc @property - def offset(self): + def offset(self) -> int: return self._offset @property - def timestamp(self): + def timestamp(self) -> Optional[int]: """ Epoch milliseconds """ return self._timestamp @property - def timestamp_type(self): + def timestamp_type(self) -> Optional[int]: """ CREATE_TIME(0) or APPEND_TIME(1) """ return self._timestamp_type @property - def key(self): + def key(self) -> Optional[bytes]: """ Bytes key or None """ return self._key @property - def value(self): + def value(self) -> bytes: """ Bytes value or None """ return self._value @property - def headers(self): + def headers(self) -> List[Any]: return [] @property - def checksum(self): + def checksum(self) -> int: return self._crc def __repr__(self): @@ -343,13 +345,13 @@ class LegacyRecordBatchBuilder(ABCRecordBatchBuilder, LegacyRecordBase): __slots__ = ("_magic", "_compression_type", "_batch_size", "_buffer") - def __init__(self, magic, compression_type, batch_size): + def __init__(self, magic: int, compression_type: int, batch_size: int) -> None: self._magic = magic self._compression_type = 
compression_type self._batch_size = batch_size self._buffer = bytearray() - def append(self, offset, timestamp, key, value, headers=None): + def append(self, offset: Union[int, str], timestamp: Optional[Union[int, str]], key: Optional[Union[bytes, str]], value: Optional[Union[str, bytes]], headers: None=None) -> Optional['LegacyRecordMetadata']: """ Append message to batch. """ assert not headers, "Headers not supported in v0/v1" @@ -388,8 +390,8 @@ def append(self, offset, timestamp, key, value, headers=None): return LegacyRecordMetadata(offset, crc, size, timestamp) - def _encode_msg(self, start_pos, offset, timestamp, key, value, - attributes=0): + def _encode_msg(self, start_pos: int, offset: int, timestamp: int, key: Optional[bytes], value: Optional[bytes], + attributes: int=0) -> int: """ Encode msg data into the `msg_buffer`, which should be allocated to at least the size of this message. """ @@ -437,7 +439,7 @@ def _encode_msg(self, start_pos, offset, timestamp, key, value, struct.pack_into(">I", buf, start_pos + self.CRC_OFFSET, crc) return crc - def _maybe_compress(self): + def _maybe_compress(self) -> bool: if self._compression_type: self._assert_has_codec(self._compression_type) data = bytes(self._buffer) @@ -464,19 +466,19 @@ def _maybe_compress(self): return True return False - def build(self): + def build(self) -> bytearray: """Compress batch to be ready for send""" self._maybe_compress() return self._buffer - def size(self): + def size(self) -> int: """ Return current size of data written to buffer """ return len(self._buffer) # Size calculations. Just copied Java's implementation - def size_in_bytes(self, offset, timestamp, key, value, headers=None): + def size_in_bytes(self, offset: int, timestamp: int, key: Optional[bytes], value: Optional[bytes], headers: None=None) -> int: """ Actual size of message to add """ assert not headers, "Headers not supported in v0/v1" @@ -484,7 +486,7 @@ def size_in_bytes(self, offset, timestamp, key, value, headers=None): return self.LOG_OVERHEAD + self.record_size(magic, key, value) @classmethod - def record_size(cls, magic, key, value): + def record_size(cls, magic: int, key: Optional[bytes], value: Optional[bytes]) -> int: message_size = cls.record_overhead(magic) if key is not None: message_size += len(key) @@ -493,7 +495,7 @@ def record_size(cls, magic, key, value): return message_size @classmethod - def record_overhead(cls, magic): + def record_overhead(cls, magic: int) -> int: assert magic in [0, 1], "Not supported magic" if magic == 0: return cls.RECORD_OVERHEAD_V0 @@ -501,7 +503,7 @@ def record_overhead(cls, magic): return cls.RECORD_OVERHEAD_V1 @classmethod - def estimate_size_in_bytes(cls, magic, compression_type, key, value): + def estimate_size_in_bytes(cls, magic: int, compression_type: int, key: bytes, value: bytes) -> int: """ Upper bound estimate of record size. 
""" assert magic in [0, 1], "Not supported magic" @@ -518,29 +520,29 @@ class LegacyRecordMetadata: __slots__ = ("_crc", "_size", "_timestamp", "_offset") - def __init__(self, offset, crc, size, timestamp): + def __init__(self, offset: int, crc: int, size: int, timestamp: int) -> None: self._offset = offset self._crc = crc self._size = size self._timestamp = timestamp @property - def offset(self): + def offset(self) -> int: return self._offset @property - def crc(self): + def crc(self) -> int: return self._crc @property - def size(self): + def size(self) -> int: return self._size @property - def timestamp(self): + def timestamp(self) -> int: return self._timestamp - def __repr__(self): + def __repr__(self) -> str: return ( "LegacyRecordMetadata(offset={!r}, crc={!r}, size={!r}," " timestamp={!r})".format( diff --git a/kafka/record/memory_records.py b/kafka/record/memory_records.py index 7a604887c..a915ed44f 100644 --- a/kafka/record/memory_records.py +++ b/kafka/record/memory_records.py @@ -23,8 +23,9 @@ from kafka.errors import CorruptRecordException from kafka.record.abc import ABCRecords -from kafka.record.legacy_records import LegacyRecordBatch, LegacyRecordBatchBuilder -from kafka.record.default_records import DefaultRecordBatch, DefaultRecordBatchBuilder +from kafka.record.legacy_records import LegacyRecordMetadata, LegacyRecordBatch, LegacyRecordBatchBuilder +from kafka.record.default_records import DefaultRecordMetadata, DefaultRecordBatch, DefaultRecordBatchBuilder +from typing import Any, List, Optional, Union class MemoryRecords(ABCRecords): @@ -38,7 +39,7 @@ class MemoryRecords(ABCRecords): __slots__ = ("_buffer", "_pos", "_next_slice", "_remaining_bytes") - def __init__(self, bytes_data): + def __init__(self, bytes_data: bytes) -> None: self._buffer = bytes_data self._pos = 0 # We keep one slice ahead so `has_next` will return very fast @@ -46,10 +47,10 @@ def __init__(self, bytes_data): self._remaining_bytes = None self._cache_next() - def size_in_bytes(self): + def size_in_bytes(self) -> int: return len(self._buffer) - def valid_bytes(self): + def valid_bytes(self) -> int: # We need to read the whole buffer to get the valid_bytes. 
# NOTE: in Fetcher we do the call after iteration, so should be fast if self._remaining_bytes is None: @@ -64,7 +65,7 @@ def valid_bytes(self): # NOTE: we cache offsets here as kwargs for a bit more speed, as cPython # will use LOAD_FAST opcode in this case - def _cache_next(self, len_offset=LENGTH_OFFSET, log_overhead=LOG_OVERHEAD): + def _cache_next(self, len_offset: int=LENGTH_OFFSET, log_overhead: int=LOG_OVERHEAD) -> None: buffer = self._buffer buffer_len = len(buffer) pos = self._pos @@ -88,12 +89,12 @@ def _cache_next(self, len_offset=LENGTH_OFFSET, log_overhead=LOG_OVERHEAD): self._next_slice = memoryview(buffer)[pos: slice_end] self._pos = slice_end - def has_next(self): + def has_next(self) -> bool: return self._next_slice is not None # NOTE: same cache for LOAD_FAST as above - def next_batch(self, _min_slice=MIN_SLICE, - _magic_offset=MAGIC_OFFSET): + def next_batch(self, _min_slice: int=MIN_SLICE, + _magic_offset: int=MAGIC_OFFSET) -> Optional[Union[DefaultRecordBatch, LegacyRecordBatch]]: next_slice = self._next_slice if next_slice is None: return None @@ -114,7 +115,7 @@ class MemoryRecordsBuilder: __slots__ = ("_builder", "_batch_size", "_buffer", "_next_offset", "_closed", "_bytes_written") - def __init__(self, magic, compression_type, batch_size): + def __init__(self, magic: int, compression_type: int, batch_size: int) -> None: assert magic in [0, 1, 2], "Not supported magic" assert compression_type in [0, 1, 2, 3, 4], "Not valid compression type" if magic >= 2: @@ -133,7 +134,7 @@ def __init__(self, magic, compression_type, batch_size): self._closed = False self._bytes_written = 0 - def append(self, timestamp, key, value, headers=[]): + def append(self, timestamp: Optional[int], key: Optional[Union[str, bytes]], value: Union[str, bytes], headers: List[Any]=[]) -> Optional[Union[DefaultRecordMetadata, LegacyRecordMetadata]]: """ Append a message to the buffer. Returns: RecordMetadata or None if unable to append @@ -150,7 +151,7 @@ def append(self, timestamp, key, value, headers=[]): self._next_offset += 1 return metadata - def close(self): + def close(self) -> None: # This method may be called multiple times on the same batch # i.e., on retries # we need to make sure we only close it out once @@ -162,25 +163,25 @@ def close(self): self._builder = None self._closed = True - def size_in_bytes(self): + def size_in_bytes(self) -> int: if not self._closed: return self._builder.size() else: return len(self._buffer) - def compression_rate(self): + def compression_rate(self) -> float: assert self._closed return self.size_in_bytes() / self._bytes_written - def is_full(self): + def is_full(self) -> bool: if self._closed: return True else: return self._builder.size() >= self._batch_size - def next_offset(self): + def next_offset(self) -> int: return self._next_offset - def buffer(self): + def buffer(self) -> bytes: assert self._closed return self._buffer diff --git a/kafka/record/util.py b/kafka/record/util.py index 3b712005d..d032151f1 100644 --- a/kafka/record/util.py +++ b/kafka/record/util.py @@ -1,13 +1,15 @@ import binascii from kafka.record._crc32c import crc as crc32c_py +from typing import Callable, Tuple + try: from crc32c import crc32c as crc32c_c except ImportError: crc32c_c = None -def encode_varint(value, write): +def encode_varint(value: int, write: Callable) -> int: """ Encode an integer to a varint presentation. See https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints on how those can be produced. 
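For readers of the util.py hunks: the helpers typed here implement the protobuf-style zigzag varint referenced in the docstrings (the library's encode_varint additionally takes a write callable and returns the byte count, and decode_varint returns a (value, position) tuple). A self-contained illustration of the encoding, for orientation only rather than the library's exact implementation:

    def zigzag_varint(value: int) -> bytes:
        # Zigzag-map a signed 64-bit value to unsigned, then emit 7 bits per
        # byte, setting the high bit on every byte except the last.
        value = (value << 1) ^ (value >> 63)
        out = bytearray()
        while value > 0x7F:
            out.append((value & 0x7F) | 0x80)
            value >>= 7
        out.append(value)
        return bytes(out)

    assert zigzag_varint(0) == b"\x00"
    assert zigzag_varint(-1) == b"\x01"
    assert zigzag_varint(300) == b"\xd8\x04"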
@@ -60,7 +62,7 @@ def encode_varint(value, write): return i -def size_of_varint(value): +def size_of_varint(value: int) -> int: """ Number of bytes needed to encode an integer in variable-length format. """ value = (value << 1) ^ (value >> 63) @@ -85,7 +87,7 @@ def size_of_varint(value): return 10 -def decode_varint(buffer, pos=0): +def decode_varint(buffer: bytearray, pos: int=0) -> Tuple[int, int]: """ Decode an integer from a varint presentation. See https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints on how those can be produced. @@ -122,13 +124,13 @@ def decode_varint(buffer, pos=0): _crc32c = crc32c_c -def calc_crc32c(memview, _crc32c=_crc32c): +def calc_crc32c(memview: bytearray, _crc32c: Callable=_crc32c) -> int: """ Calculate CRC-32C (Castagnoli) checksum over a memoryview of data """ return _crc32c(memview) -def calc_crc32(memview): +def calc_crc32(memview: memoryview) -> int: """ Calculate simple CRC-32 checksum over a memoryview of data """ crc = binascii.crc32(memview) & 0xffffffff diff --git a/kafka/sasl/msk.py b/kafka/sasl/msk.py index 6d1bb74fb..ebea5dc5a 100644 --- a/kafka/sasl/msk.py +++ b/kafka/sasl/msk.py @@ -1,230 +1,231 @@ -import datetime -import hashlib -import hmac -import json -import string -import struct -import logging -import urllib - -from kafka.protocol.types import Int32 -import kafka.errors as Errors - -from botocore.session import Session as BotoSession # importing it in advance is not an option apparently... - - -def try_authenticate(self, future): - - session = BotoSession() - credentials = session.get_credentials().get_frozen_credentials() - client = AwsMskIamClient( - host=self.host, - access_key=credentials.access_key, - secret_key=credentials.secret_key, - region=session.get_config_variable('region'), - token=credentials.token, - ) - - msg = client.first_message() - size = Int32.encode(len(msg)) - - err = None - close = False - with self._lock: - if not self._can_send_recv(): - err = Errors.NodeNotReadyError(str(self)) - close = False - else: - try: - self._send_bytes_blocking(size + msg) - data = self._recv_bytes_blocking(4) - data = self._recv_bytes_blocking(struct.unpack('4B', data)[-1]) - except (ConnectionError, TimeoutError) as e: - logging.exception("%s: Error receiving reply from server", self) - err = Errors.KafkaConnectionError(f"{self}: {e}") - close = True - - if err is not None: - if close: - self.close(error=err) - return future.failure(err) - - logging.info('%s: Authenticated via AWS_MSK_IAM %s', self, data.decode('utf-8')) - return future.success(True) - - -class AwsMskIamClient: - UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~' - - def __init__(self, host, access_key, secret_key, region, token=None): - """ - Arguments: - host (str): The hostname of the broker. - access_key (str): An AWS_ACCESS_KEY_ID. - secret_key (str): An AWS_SECRET_ACCESS_KEY. - region (str): An AWS_REGION. - token (Optional[str]): An AWS_SESSION_TOKEN if using temporary - credentials. 
- """ - self.algorithm = 'AWS4-HMAC-SHA256' - self.expires = '900' - self.hashfunc = hashlib.sha256 - self.headers = [ - ('host', host) - ] - self.version = '2020_10_22' - - self.service = 'kafka-cluster' - self.action = f'{self.service}:Connect' - - now = datetime.datetime.utcnow() - self.datestamp = now.strftime('%Y%m%d') - self.timestamp = now.strftime('%Y%m%dT%H%M%SZ') - - self.host = host - self.access_key = access_key - self.secret_key = secret_key - self.region = region - self.token = token - - @property - def _credential(self): - return '{0.access_key}/{0._scope}'.format(self) - - @property - def _scope(self): - return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self) - - @property - def _signed_headers(self): - """ - Returns (str): - An alphabetically sorted, semicolon-delimited list of lowercase - request header names. - """ - return ';'.join(sorted(k.lower() for k, _ in self.headers)) - - @property - def _canonical_headers(self): - """ - Returns (str): - A newline-delited list of header names and values. - Header names are lowercased. - """ - return '\n'.join(map(':'.join, self.headers)) + '\n' - - @property - def _canonical_request(self): - """ - Returns (str): - An AWS Signature Version 4 canonical request in the format: - \n - \n - \n - \n - \n - - """ - # The hashed_payload is always an empty string for MSK. - hashed_payload = self.hashfunc(b'').hexdigest() - return '\n'.join(( - 'GET', - '/', - self._canonical_querystring, - self._canonical_headers, - self._signed_headers, - hashed_payload, - )) - - @property - def _canonical_querystring(self): - """ - Returns (str): - A '&'-separated list of URI-encoded key/value pairs. - """ - params = [] - params.append(('Action', self.action)) - params.append(('X-Amz-Algorithm', self.algorithm)) - params.append(('X-Amz-Credential', self._credential)) - params.append(('X-Amz-Date', self.timestamp)) - params.append(('X-Amz-Expires', self.expires)) - if self.token: - params.append(('X-Amz-Security-Token', self.token)) - params.append(('X-Amz-SignedHeaders', self._signed_headers)) - - return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params) - - @property - def _signing_key(self): - """ - Returns (bytes): - An AWS Signature V4 signing key generated from the secret_key, date, - region, service, and request type. - """ - key = self._hmac(('AWS4' + self.secret_key).encode('utf-8'), self.datestamp) - key = self._hmac(key, self.region) - key = self._hmac(key, self.service) - key = self._hmac(key, 'aws4_request') - return key - - @property - def _signing_str(self): - """ - Returns (str): - A string used to sign the AWS Signature V4 payload in the format: - \n - \n - \n - - """ - canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest() - return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash)) - - def _uriencode(self, msg): - """ - Arguments: - msg (str): A string to URI-encode. - - Returns (str): - The URI-encoded version of the provided msg, following the encoding - rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode - """ - return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS) - - def _hmac(self, key, msg): - """ - Arguments: - key (bytes): A key to use for the HMAC digest. - msg (str): A value to include in the HMAC digest. - Returns (bytes): - An HMAC digest of the given key and msg. 
- """ - return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest() - - def first_message(self): - """ - Returns (bytes): - An encoded JSON authentication payload that can be sent to the - broker. - """ - signature = hmac.new( - self._signing_key, - self._signing_str.encode('utf-8'), - digestmod=self.hashfunc, - ).hexdigest() - msg = { - 'version': self.version, - 'host': self.host, - 'user-agent': 'kafka-python', - 'action': self.action, - 'x-amz-algorithm': self.algorithm, - 'x-amz-credential': self._credential, - 'x-amz-date': self.timestamp, - 'x-amz-signedheaders': self._signed_headers, - 'x-amz-expires': self.expires, - 'x-amz-signature': signature, - } - if self.token: - msg['x-amz-security-token'] = self.token - - return json.dumps(msg, separators=(',', ':')).encode('utf-8') +import datetime +import hashlib +import hmac +import json +import string +import struct +import logging +import urllib + +from kafka.protocol.types import Int32 +import kafka.errors as Errors + +from botocore.session import Session as BotoSession # importing it in advance is not an option apparently... +from typing import Optional + + +def try_authenticate(self, future): + + session = BotoSession() + credentials = session.get_credentials().get_frozen_credentials() + client = AwsMskIamClient( + host=self.host, + access_key=credentials.access_key, + secret_key=credentials.secret_key, + region=session.get_config_variable('region'), + token=credentials.token, + ) + + msg = client.first_message() + size = Int32.encode(len(msg)) + + err = None + close = False + with self._lock: + if not self._can_send_recv(): + err = Errors.NodeNotReadyError(str(self)) + close = False + else: + try: + self._send_bytes_blocking(size + msg) + data = self._recv_bytes_blocking(4) + data = self._recv_bytes_blocking(struct.unpack('4B', data)[-1]) + except (ConnectionError, TimeoutError) as e: + logging.exception("%s: Error receiving reply from server", self) + err = Errors.KafkaConnectionError(f"{self}: {e}") + close = True + + if err is not None: + if close: + self.close(error=err) + return future.failure(err) + + logging.info('%s: Authenticated via AWS_MSK_IAM %s', self, data.decode('utf-8')) + return future.success(True) + + +class AwsMskIamClient: + UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~' + + def __init__(self, host: str, access_key: str, secret_key: str, region: str, token: Optional[str]=None) -> None: + """ + Arguments: + host (str): The hostname of the broker. + access_key (str): An AWS_ACCESS_KEY_ID. + secret_key (str): An AWS_SECRET_ACCESS_KEY. + region (str): An AWS_REGION. + token (Optional[str]): An AWS_SESSION_TOKEN if using temporary + credentials. 
+ """ + self.algorithm = 'AWS4-HMAC-SHA256' + self.expires = '900' + self.hashfunc = hashlib.sha256 + self.headers = [ + ('host', host) + ] + self.version = '2020_10_22' + + self.service = 'kafka-cluster' + self.action = f'{self.service}:Connect' + + now = datetime.datetime.utcnow() + self.datestamp = now.strftime('%Y%m%d') + self.timestamp = now.strftime('%Y%m%dT%H%M%SZ') + + self.host = host + self.access_key = access_key + self.secret_key = secret_key + self.region = region + self.token = token + + @property + def _credential(self) -> str: + return '{0.access_key}/{0._scope}'.format(self) + + @property + def _scope(self) -> str: + return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self) + + @property + def _signed_headers(self) -> str: + """ + Returns (str): + An alphabetically sorted, semicolon-delimited list of lowercase + request header names. + """ + return ';'.join(sorted(k.lower() for k, _ in self.headers)) + + @property + def _canonical_headers(self) -> str: + """ + Returns (str): + A newline-delited list of header names and values. + Header names are lowercased. + """ + return '\n'.join(map(':'.join, self.headers)) + '\n' + + @property + def _canonical_request(self) -> str: + """ + Returns (str): + An AWS Signature Version 4 canonical request in the format: + \n + \n + \n + \n + \n + + """ + # The hashed_payload is always an empty string for MSK. + hashed_payload = self.hashfunc(b'').hexdigest() + return '\n'.join(( + 'GET', + '/', + self._canonical_querystring, + self._canonical_headers, + self._signed_headers, + hashed_payload, + )) + + @property + def _canonical_querystring(self) -> str: + """ + Returns (str): + A '&'-separated list of URI-encoded key/value pairs. + """ + params = [] + params.append(('Action', self.action)) + params.append(('X-Amz-Algorithm', self.algorithm)) + params.append(('X-Amz-Credential', self._credential)) + params.append(('X-Amz-Date', self.timestamp)) + params.append(('X-Amz-Expires', self.expires)) + if self.token: + params.append(('X-Amz-Security-Token', self.token)) + params.append(('X-Amz-SignedHeaders', self._signed_headers)) + + return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params) + + @property + def _signing_key(self) -> bytes: + """ + Returns (bytes): + An AWS Signature V4 signing key generated from the secret_key, date, + region, service, and request type. + """ + key = self._hmac(('AWS4' + self.secret_key).encode('utf-8'), self.datestamp) + key = self._hmac(key, self.region) + key = self._hmac(key, self.service) + key = self._hmac(key, 'aws4_request') + return key + + @property + def _signing_str(self) -> str: + """ + Returns (str): + A string used to sign the AWS Signature V4 payload in the format: + \n + \n + \n + + """ + canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest() + return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash)) + + def _uriencode(self, msg: str) -> str: + """ + Arguments: + msg (str): A string to URI-encode. + + Returns (str): + The URI-encoded version of the provided msg, following the encoding + rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode + """ + return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS) + + def _hmac(self, key: bytes, msg: str) -> bytes: + """ + Arguments: + key (bytes): A key to use for the HMAC digest. + msg (str): A value to include in the HMAC digest. + Returns (bytes): + An HMAC digest of the given key and msg. 
+ """ + return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest() + + def first_message(self) -> bytes: + """ + Returns (bytes): + An encoded JSON authentication payload that can be sent to the + broker. + """ + signature = hmac.new( + self._signing_key, + self._signing_str.encode('utf-8'), + digestmod=self.hashfunc, + ).hexdigest() + msg = { + 'version': self.version, + 'host': self.host, + 'user-agent': 'kafka-python', + 'action': self.action, + 'x-amz-algorithm': self.algorithm, + 'x-amz-credential': self._credential, + 'x-amz-date': self.timestamp, + 'x-amz-signedheaders': self._signed_headers, + 'x-amz-expires': self.expires, + 'x-amz-signature': signature, + } + if self.token: + msg['x-amz-security-token'] = self.token + + return json.dumps(msg, separators=(',', ':')).encode('utf-8') diff --git a/kafka/util.py b/kafka/util.py index 0c9c5ea62..968787341 100644 --- a/kafka/util.py +++ b/kafka/util.py @@ -1,11 +1,12 @@ import binascii import weakref +from typing import Callable, Optional MAX_INT = 2 ** 31 TO_SIGNED = 2 ** 32 -def crc32(data): +def crc32(data: bytes) -> int: crc = binascii.crc32(data) # py2 and py3 behave a little differently # CRC is encoded as a signed int in kafka protocol @@ -24,7 +25,7 @@ class WeakMethod: object_dot_method: A bound instance method (i.e. 'object.method'). """ - def __init__(self, object_dot_method): + def __init__(self, object_dot_method: Callable) -> None: try: self.target = weakref.ref(object_dot_method.__self__) except AttributeError: @@ -36,16 +37,16 @@ def __init__(self, object_dot_method): self.method = weakref.ref(object_dot_method.im_func) self._method_id = id(self.method()) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> Optional[bytes]: """ Calls the method on target with args and kwargs. """ return self.method()(self.target(), *args, **kwargs) - def __hash__(self): + def __hash__(self) -> int: return hash(self.target) ^ hash(self.method) - def __eq__(self, other): + def __eq__(self, other: "WeakMethod") -> bool: if not isinstance(other, WeakMethod): return False return self._target_id == other._target_id and self._method_id == other._method_id From c18e3a0579b2aa27c5887e1626985251bdac5a4a Mon Sep 17 00:00:00 2001 From: William Barnhart Date: Tue, 26 Mar 2024 09:24:01 -0400 Subject: [PATCH 2/2] define types as Struct for simplicity's sake --- kafka/protocol/struct.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/kafka/protocol/struct.py b/kafka/protocol/struct.py index 05189ed4d..65b3c8c63 100644 --- a/kafka/protocol/struct.py +++ b/kafka/protocol/struct.py @@ -4,6 +4,7 @@ from kafka.protocol.abstract import AbstractType from kafka.protocol.types import Schema + from kafka.util import WeakMethod @@ -43,7 +44,7 @@ def _encode_self(self) -> bytes: ) @classmethod - def decode(cls, data: Union[BytesIO, bytes]) -> Union['ConsumerProtocolMemberAssignment', 'ConsumerProtocolMemberMetadata', 'FetchResponse_v0', 'StickyAssignorUserDataV1']: + def decode(cls, data: Union[BytesIO, bytes]) -> "Struct": if isinstance(data, bytes): data = BytesIO(data) return cls(*[field.decode(data) for field in cls.SCHEMA.fields]) @@ -62,7 +63,7 @@ def __repr__(self) -> str: def __hash__(self): return hash(self.encode()) - def __eq__(self, other: Union['ConsumerProtocolMemberAssignment', 'ConsumerProtocolMemberMetadata', 'MetadataRequest_v0', 'Message']) -> bool: + def __eq__(self, other: "Struct") -> bool: if self.SCHEMA != other.SCHEMA: return False for attr in self.SCHEMA.names: