diff --git a/cdc_kafka/avro.py b/cdc_kafka/avro.py deleted file mode 100644 index 23f0d63..0000000 --- a/cdc_kafka/avro.py +++ /dev/null @@ -1,166 +0,0 @@ -import json -from typing import Dict, Sequence, List, Any, Tuple, Optional, Callable - -from avro.schema import Schema -import confluent_kafka.avro - -from . import constants - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from .tracked_tables import TrackedField - - -class AvroSchemaGenerator(object): - def __init__(self, always_use_avro_longs: bool, - avro_type_spec_overrides: Dict[str, str | Dict[str, str | int]]) -> None: - self.always_use_avro_longs: bool = always_use_avro_longs - self.normalized_avro_type_overrides: Dict[Tuple[str, str, str], str | Dict[str, str | int]] = {} - for k, v in avro_type_spec_overrides.items(): - if k.count('.') != 2: - raise Exception(f'Avro type spec override "{k}" was incorrectly specified. Please key this config in ' - 'the form ..') - schema, table, column = k.split('.') - self.normalized_avro_type_overrides[(schema.lower(), table.lower(), column.lower())] = v - - def generate_key_schema(self, db_schema_name: str, db_table_name: str, - key_fields: Sequence['TrackedField']) -> Schema: - key_schema_fields = [self.get_record_field_schema( - db_schema_name, db_table_name, kf.name, kf.sql_type_name, kf.decimal_precision, kf.decimal_scale, False - ) for kf in key_fields] - schema_json = { - "name": f"{db_schema_name}_{db_table_name}_cdc__key", - "namespace": constants.AVRO_SCHEMA_NAMESPACE, - "type": "record", - "fields": key_schema_fields - } - return confluent_kafka.avro.loads(json.dumps(schema_json)) - - def generate_value_schema(self, db_schema_name: str, db_table_name: str, - value_fields: Sequence['TrackedField']) -> Schema: - # In CDC tables, all columns are nullable so that if the column is dropped from the source table, the capture - # instance need not be updated. We align with that by making the Avro value schema for all captured fields - # nullable (which also helps with maintaining future Avro schema compatibility). 
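For readers skimming the deleted module, the nullable wrapping described in the comment above is what the make_nullable branch of get_record_field_schema produces. A rough sketch of the resulting Avro field spec, using a hypothetical SMALLINT column named order_qty:

```python
# Illustrative only: mirrors the make_nullable branch of get_record_field_schema below;
# "order_qty" is a made-up column name, not something from this repository.
avro_type = "int"  # what a SQL smallint maps to when always_use_avro_longs is False
nullable_field = {
    "name": "order_qty",
    "type": ["null", avro_type],  # union with null, matching CDC's all-columns-nullable convention
    "default": None,              # a null default keeps future schema changes compatible
}
```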
- value_schema_fields = [self.get_record_field_schema( - db_schema_name, db_table_name, vf.name, vf.sql_type_name, vf.decimal_precision, vf.decimal_scale, True - ) for vf in value_fields] - value_field_names = [f.name for f in value_fields] - value_fields_plus_metadata_fields = AvroSchemaGenerator.get_cdc_metadata_fields_avro_schemas( - db_schema_name, db_table_name, value_field_names) + value_schema_fields - schema_json = { - "name": f"{db_schema_name}_{db_table_name}_cdc__value", - "namespace": constants.AVRO_SCHEMA_NAMESPACE, - "type": "record", - "fields": value_fields_plus_metadata_fields - } - return confluent_kafka.avro.loads(json.dumps(schema_json)) - - def get_record_field_schema(self, db_schema_name: str, db_table_name: str, field_name: str, sql_type_name: str, - decimal_precision: int, decimal_scale: int, make_nullable: bool) -> Dict[str, Any]: - override_type = self.normalized_avro_type_overrides.get( - (db_schema_name.lower(), db_table_name.lower(), field_name.lower())) - if override_type: - avro_type = override_type - else: - if sql_type_name in ('decimal', 'numeric', 'money', 'smallmoney'): - if (not decimal_precision) or decimal_scale is None: - raise Exception(f"Field '{field_name}': For SQL decimal, money, or numeric types, the scale and " - f"precision must be provided.") - avro_type = { - "type": "bytes", - "logicalType": "decimal", - "precision": decimal_precision, - "scale": decimal_scale - } - elif sql_type_name == 'bigint': - avro_type = "long" - elif sql_type_name == 'bit': - avro_type = "boolean" - elif sql_type_name == 'float': - avro_type = "double" - elif sql_type_name == 'real': - avro_type = "float" - elif sql_type_name in ('int', 'smallint', 'tinyint'): - avro_type = "long" if self.always_use_avro_longs else "int" - # For date and time we don't respect always_use_avro_longs since the underlying type being `int` for these - # logical types is spelled out in the Avro spec: - elif sql_type_name == 'date': - avro_type = {"type": "int", "logicalType": "date"} - elif sql_type_name == 'time': - avro_type = {"type": "int", "logicalType": "time-millis"} - elif sql_type_name in ('datetime', 'datetime2', 'datetimeoffset', 'smalldatetime', - 'xml') + constants.SQL_STRING_TYPES: - avro_type = "string" - elif sql_type_name == 'uniqueidentifier': - avro_type = {"type": "string", "logicalType": "uuid"} - elif sql_type_name in ('binary', 'image', 'varbinary', 'rowversion'): - avro_type = "bytes" - else: - raise Exception(f"Field '{field_name}': I am unsure how to convert SQL type {sql_type_name} to Avro") - - if make_nullable: - return { - "name": field_name, - "type": [ - "null", - avro_type - ], - "default": None - } - else: - return { - "name": field_name, - "type": avro_type - } - - @staticmethod - # These fields are common to all change/snapshot data messages published to Kafka by this process - def get_cdc_metadata_fields_avro_schemas(db_schema_name: str, db_table_name: str, - source_field_names: List[str]) -> List[Dict[str, Any]]: - return [ - { - "name": constants.OPERATION_NAME, - "type": { - "type": "enum", - "name": f'{db_schema_name}_{db_table_name}{constants.OPERATION_NAME}', - "symbols": list(constants.CDC_OPERATION_NAME_TO_ID.keys()) - } - }, - { - # as ISO 8601 timestamp... 
either the change's tran_end_time OR the time the snapshot row was read: - "name": constants.EVENT_TIME_NAME, - "type": "string" - }, - { - "name": constants.LSN_NAME, - "type": ["null", "string"] - }, - { - "name": constants.SEQVAL_NAME, - "type": ["null", "string"] - }, - { - # Messages will list the names of all fields that were updated in the event (for snapshots or CDC insert - # records this will be all rows): - "name": constants.UPDATED_FIELDS_NAME, - "type": { - "type": "array", - "items": { - "type": "enum", - "name": f'{db_schema_name}_{db_table_name}{constants.UPDATED_FIELDS_NAME}', - "default": constants.UNRECOGNIZED_COLUMN_DEFAULT_NAME, - "symbols": [constants.UNRECOGNIZED_COLUMN_DEFAULT_NAME] + source_field_names - } - } - } - ] - - -def avro_transform_fn_from_sql_type(sql_type_name: str) -> Optional[Callable[[Any], Any]]: - if sql_type_name in ('datetime', 'datetime2', 'datetimeoffset', 'smalldatetime'): - # We have chosen to represent datetime values as ISO8601 strings rather than using the usual Avro convention of - # an int type + 'timestamp-millis' logical type that captures them as ms since the Unix epoch. This is because - # the latter presumes the time is in UTC, whereas we do not always know the TZ of datetimes we pull from the - # DB. It seems more 'faithful' to represent them exactly as they exist in the DB. - return lambda x: x and x.isoformat() - return None diff --git a/cdc_kafka/build_startup_state.py b/cdc_kafka/build_startup_state.py index 4addaa4..f5f2d35 100644 --- a/cdc_kafka/build_startup_state.py +++ b/cdc_kafka/build_startup_state.py @@ -1,4 +1,5 @@ import collections +import copy import datetime import logging import re @@ -8,7 +9,8 @@ from tabulate import tabulate from . import sql_query_subprocess, tracked_tables, sql_queries, kafka, progress_tracking, change_index, \ - constants, helpers, options, avro + constants, helpers, options +from .serializers.avro import AvroSchemaGenerator from .metric_reporting import accumulator logger = logging.getLogger(__name__) @@ -18,8 +20,7 @@ def build_tracked_tables_from_cdc_metadata( db_conn: pyodbc.Connection, metrics_accumulator: accumulator.Accumulator, topic_name_template: str, snapshot_table_include_config: str, snapshot_table_exclude_config: str, truncate_fields: Dict[str, int], capture_instance_names: List[str], db_row_batch_size: int, - sql_query_processor: sql_query_subprocess.SQLQueryProcessor, - schema_generator: avro.AvroSchemaGenerator, progress_tracker: progress_tracking.ProgressTracker + sql_query_processor: sql_query_subprocess.SQLQueryProcessor, progress_tracker: progress_tracking.ProgressTracker ) -> List[tracked_tables.TrackedTable]: result: List[tracked_tables.TrackedTable] = [] @@ -57,8 +58,8 @@ def build_tracked_tables_from_cdc_metadata( schema_name=schema_name, table_name=table_name, capture_instance_name=capture_instance_name) tracked_table = tracked_tables.TrackedTable( - db_conn, metrics_accumulator, sql_query_processor, schema_generator, schema_name, table_name, - capture_instance_name, topic_name, min_lsn, can_snapshot, db_row_batch_size, progress_tracker) + db_conn, metrics_accumulator, sql_query_processor, schema_name, table_name, capture_instance_name, + topic_name, min_lsn, can_snapshot, db_row_batch_size, progress_tracker) for (change_table_ordinal, column_name, sql_type_name, _, primary_key_ordinal, decimal_precision, decimal_scale, _) in fields: @@ -74,11 +75,10 @@ def build_tracked_tables_from_cdc_metadata( def determine_start_points_and_finalize_tables( kafka_client: 
kafka.KafkaClient, db_conn: pyodbc.Connection, tables: Iterable[tracked_tables.TrackedTable], - schema_generator: avro.AvroSchemaGenerator, progress_tracker: progress_tracking.ProgressTracker, - lsn_gap_handling: str, partition_count: int, replication_factor: int, - extra_topic_config: Dict[str, str | int], validation_mode: bool = False, - redo_snapshot_for_new_instance: bool = False, publish_duplicate_changes_from_new_instance: bool = False, - report_progress_only: bool = False + progress_tracker: progress_tracking.ProgressTracker, lsn_gap_handling: str, new_follow_start_point: str, + partition_count: int, replication_factor: int, extra_topic_config: Dict[str, str | int], + validation_mode: bool = False, redo_snapshot_for_new_instance: bool = False, + publish_duplicate_changes_from_new_instance: bool = False, report_progress_only: bool = False ) -> None: if validation_mode: for table in tables: @@ -94,6 +94,11 @@ def determine_start_points_and_finalize_tables( snapshot_progress: Optional[progress_tracking.ProgressEntry] changes_progress: Optional[progress_tracking.ProgressEntry] + with db_conn.cursor() as cursor: + q, _ = sql_queries.get_max_lsn() + cursor.execute(q) + db_max_lsn = cursor.fetchval() + for table in tables: kafka_client.begin_transaction() snapshot_progress, changes_progress = None, None @@ -126,8 +131,8 @@ def determine_start_points_and_finalize_tables( if redo_snapshot_for_new_instance: old_capture_instance_name = helpers.get_capture_instance_name(snapshot_progress.change_table_name) new_capture_instance_name = helpers.get_capture_instance_name(fq_change_table_name) - if ddl_change_requires_new_snapshot(db_conn, schema_generator, old_capture_instance_name, - new_capture_instance_name, table.fq_name): + if ddl_change_requires_new_snapshot(db_conn, old_capture_instance_name, new_capture_instance_name, + table.fq_name): logger.info('Will start new snapshot.') snapshot_progress = None else: @@ -160,8 +165,13 @@ def determine_start_points_and_finalize_tables( else: logger.info('Will NOT republish any change rows duplicated by the new capture instance.') + new_table_starting_index = copy.copy(change_index.LOWEST_CHANGE_INDEX) + if new_follow_start_point == options.NEW_FOLLOW_START_POINT_LATEST: + new_table_starting_index.lsn = db_max_lsn + if not (changes_progress and changes_progress.change_index): + logger.info('Beginning follow of new table %s from LSN %s', table.fq_name, new_table_starting_index) starting_change_index: change_index.ChangeIndex = \ - (changes_progress and changes_progress.change_index) or change_index.LOWEST_CHANGE_INDEX + (changes_progress and changes_progress.change_index) or new_table_starting_index starting_snapshot_index: Optional[Mapping[str, str | int]] = None if snapshot_progress: starting_snapshot_index = snapshot_progress.snapshot_index @@ -171,7 +181,7 @@ def determine_start_points_and_finalize_tables( options.LSN_GAP_HANDLING_IGNORE) else: table.finalize_table(starting_change_index, prior_change_table_max_index, starting_snapshot_index, - lsn_gap_handling, kafka_client.register_schemas, allow_progress_writes=True) + lsn_gap_handling, allow_progress_writes=True) if not table.snapshot_allowed: snapshot_state = '' @@ -194,9 +204,9 @@ def determine_start_points_and_finalize_tables( 'low key column values):\n%s\n%s tables total.', display_table, len(prior_progress_log_table_data)) -def ddl_change_requires_new_snapshot(db_conn: pyodbc.Connection, schema_generator: avro.AvroSchemaGenerator, - old_capture_instance_name: str, new_capture_instance_name: 
str, - source_table_fq_name: str, resnapshot_for_column_drops: bool = True) -> bool: +def ddl_change_requires_new_snapshot(db_conn: pyodbc.Connection, old_capture_instance_name: str, + new_capture_instance_name: str, source_table_fq_name: str, + resnapshot_for_column_drops: bool = True) -> bool: with db_conn.cursor() as cursor: cursor.execute(f'SELECT TOP 1 1 FROM [{constants.CDC_DB_SCHEMA_NAME}].[change_tables] ' f'WHERE capture_instance = ?', old_capture_instance_name) @@ -239,22 +249,45 @@ def ddl_change_requires_new_snapshot(db_conn: pyodbc.Connection, schema_generato for changed_col_name in changed_col_names: old_col = old_cols[changed_col_name] new_col = new_cols[changed_col_name] - # Even if the DB col type changed, a resnapshot is really only needed if the corresponding Avro type - # changes. An example would be a column "upgrading" from SMALLINT to INT: - db_schema, db_table = source_table_fq_name.split('.') - old_avro_type = schema_generator.get_record_field_schema( - db_schema, db_table, changed_col_name, old_col['sql_type_name'], old_col['decimal_precision'], - old_col['decimal_scale'], True) - new_avro_type = schema_generator.get_record_field_schema( - db_schema, db_table, changed_col_name, new_col['sql_type_name'], new_col['decimal_precision'], - new_col['decimal_scale'], True) - if old_col['is_computed'] != new_col['is_computed'] or old_avro_type != new_avro_type: - logger.info('Requiring re-snapshot for %s due to a data type change for column %s (type: %s, ' + + if old_col['is_computed'] != new_col['is_computed']: + logger.info('Requiring re-snapshot for %s due to an is_computed change for column %s (type: %s, ' 'is_computed: %s --> type: %s, is_computed: %s).', source_table_fq_name, changed_col_name, old_col['sql_type_name'], old_col['is_computed'], new_col['sql_type_name'], new_col['is_computed']) return True + # Even if the DB col type changed, a resnapshot is really only needed if the corresponding serialization + # schema changes. An example where we can skip a re-snapshot would be a column "upgrading" from SMALLINT + # to INT: + + # noinspection PyProtectedMember + if AvroSchemaGenerator._instance: # Will only exist if process was configured for Avro serialization + db_schema, db_table = source_table_fq_name.split('.') + # noinspection PyProtectedMember + old_avro_type = AvroSchemaGenerator._instance.get_record_field_schema( + db_schema, db_table, changed_col_name, old_col['sql_type_name'], old_col['decimal_precision'], + old_col['decimal_scale'], True) + # noinspection PyProtectedMember + new_avro_type = AvroSchemaGenerator._instance.get_record_field_schema( + db_schema, db_table, changed_col_name, new_col['sql_type_name'], new_col['decimal_precision'], + new_col['decimal_scale'], True) + if old_avro_type != new_avro_type: + logger.info('Requiring re-snapshot for %s due to an Avro schema change for column %s (type: %s, ' + 'is_computed: %s --> type: %s, is_computed: %s).', source_table_fq_name, + changed_col_name, old_col['sql_type_name'], old_col['is_computed'], + new_col['sql_type_name'], new_col['is_computed']) + return True + else: + # TODO - not yet supporting all the situations we could in this non-Avro case. Add nuance! 
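The non-Avro branch that follows this TODO only skips a re-snapshot when both the old and new SQL types are integer-family types. A standalone sketch of that rule (an illustration under that assumption, not the project's own helper):

```python
# Sketch of the integer-widening allowance applied in the non-Avro branch below; the function
# name and examples are illustrative only.
def type_change_needs_resnapshot(old_sql_type: str, new_sql_type: str) -> bool:
    both_int_family = (old_sql_type.lower().endswith('int')
                       and new_sql_type.lower().endswith('int'))
    return not both_int_family

assert type_change_needs_resnapshot('smallint', 'int') is False  # e.g. a column "upgraded" SMALLINT -> INT
assert type_change_needs_resnapshot('int', 'nvarchar') is True   # any other type change forces a re-snapshot
```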
+ if not (old_col['sql_type_name'].lower().endswith('int') and + new_col['sql_type_name'].lower().endswith('int')): + logger.info('Requiring re-snapshot for %s due to a data type change for column %s (type: %s, ' + 'is_computed: %s --> type: %s, is_computed: %s).', source_table_fq_name, + changed_col_name, old_col['sql_type_name'], old_col['is_computed'], + new_col['sql_type_name'], new_col['is_computed']) + return True + for added_col_name in added_col_names: col_info = new_cols[added_col_name] if not col_info['is_nullable']: @@ -271,7 +304,7 @@ def ddl_change_requires_new_snapshot(db_conn: pyodbc.Connection, schema_generato # Gets the names of columns that appear in the first position of one or more unfiltered, non-disabled indexes: q, p = sql_queries.get_indexed_cols() - cursor.setinputsizes(p) + cursor.setinputsizes(p) # type: ignore[arg-type] cursor.execute(q, source_table_fq_name) indexed_cols: Set[str] = {row[0] for row in cursor.fetchall()} recently_added_cols: Optional[Set[str]] = None @@ -294,7 +327,7 @@ def ddl_change_requires_new_snapshot(db_conn: pyodbc.Connection, schema_generato cols_with_too_old_changes: Set[str] = set() cols_with_new_enough_changes: Set[str] = set() q, p = sql_queries.get_ddl_history_for_capture_table() - cursor.setinputsizes(p) + cursor.setinputsizes(p) # type: ignore[arg-type] cursor.execute(q, helpers.get_fq_change_table_name(old_capture_instance_name)) alter_re = re.compile( r'\W*alter\s+table\s+(?P
[\w\.\[\]]+)\s+add\s+(?P[\w\.\[\]]+)\s+(?P.*)', @@ -302,7 +335,7 @@ def ddl_change_requires_new_snapshot(db_conn: pyodbc.Connection, schema_generato for (ddl_command, age_seconds) in cursor.fetchall(): match = alter_re.match(ddl_command) if match and match.groupdict().get('column'): - col_name_lower = match.groupdict()['column'].lower() + col_name_lower = match.groupdict()['column'].lower().strip('[]') if age_seconds > constants.MAX_AGE_TO_PRESUME_ADDED_COL_IS_NULL_SECONDS: cols_with_too_old_changes.add(col_name_lower) else: diff --git a/cdc_kafka/change_index.py b/cdc_kafka/change_index.py index 1725ec5..9475804 100644 --- a/cdc_kafka/change_index.py +++ b/cdc_kafka/change_index.py @@ -50,25 +50,25 @@ def __repr__(self) -> str: # Converts from binary LSN/seqval to a string representation that is more friendly to some things that may # consume this data. The stringified form is also "SQL query ready" for pasting into SQL Server queries. - def to_avro_ready_dict(self) -> Dict[str, str]: + def as_dict(self) -> Dict[str, str]: return { constants.LSN_NAME: f'0x{self.lsn.hex()}', constants.SEQVAL_NAME: f'0x{self.seqval.hex()}', constants.OPERATION_NAME: constants.CDC_OPERATION_ID_TO_NAME[self.operation] } - @property - def is_probably_heartbeat(self) -> bool: - return self.seqval == HIGHEST_CHANGE_INDEX.seqval and self.operation == HIGHEST_CHANGE_INDEX.operation - @staticmethod - def from_avro_ready_dict(avro_dict: Dict[str, Any]) -> 'ChangeIndex': + def from_dict(source_dict: Dict[str, Any]) -> 'ChangeIndex': return ChangeIndex( - int(avro_dict[constants.LSN_NAME][2:], 16).to_bytes(10, "big"), - int(avro_dict[constants.SEQVAL_NAME][2:], 16).to_bytes(10, "big"), - constants.CDC_OPERATION_NAME_TO_ID[avro_dict[constants.OPERATION_NAME]] + int(source_dict[constants.LSN_NAME][2:], 16).to_bytes(10, "big"), + int(source_dict[constants.SEQVAL_NAME][2:], 16).to_bytes(10, "big"), + constants.CDC_OPERATION_NAME_TO_ID[source_dict[constants.OPERATION_NAME]] ) + @property + def is_probably_heartbeat(self) -> bool: + return self.seqval == HIGHEST_CHANGE_INDEX.seqval and self.operation == HIGHEST_CHANGE_INDEX.operation + LOWEST_CHANGE_INDEX = ChangeIndex(b'\x00' * 10, b'\x00' * 10, 0) HIGHEST_CHANGE_INDEX = ChangeIndex(b'\xff' * 10, b'\xff' * 10, 4) diff --git a/cdc_kafka/constants.py b/cdc_kafka/constants.py index b872b19..7847ab2 100644 --- a/cdc_kafka/constants.py +++ b/cdc_kafka/constants.py @@ -1,4 +1,5 @@ import datetime +from typing import Literal # Timing intervals @@ -21,14 +22,13 @@ KAFKA_REQUEST_TIMEOUT_SECS = 15 KAFKA_OAUTH_CB_POLL_TIMEOUT = 3 KAFKA_FULL_FLUSH_TIMEOUT_SECS = 30 -KAFKA_CONFIG_RELOAD_DELAY_SECS = 2 +KAFKA_CONFIG_RELOAD_DELAY_SECS = 1 # General MESSAGE_KEY_FIELD_NAME_WHEN_PK_ABSENT = '_row_hash' -DEFAULT_KEY_SCHEMA_COMPATIBILITY_LEVEL = 'FULL' -DEFAULT_VALUE_SCHEMA_COMPATIBILITY_LEVEL = 'FORWARD' -AVRO_SCHEMA_NAMESPACE = "cdc_to_kafka" +DEFAULT_KEY_SCHEMA_COMPATIBILITY_LEVEL: Literal["NONE", "FULL", "FORWARD", "BACKWARD"] = 'FULL' +DEFAULT_VALUE_SCHEMA_COMPATIBILITY_LEVEL: Literal["NONE", "FULL", "FORWARD", "BACKWARD"] = 'FORWARD' CDC_DB_SCHEMA_NAME = 'cdc' UNRECOGNIZED_COLUMN_DEFAULT_NAME = 'UNKNOWN_COL' VALIDATION_MAXIMUM_SAMPLE_SIZE_PER_TOPIC = 1_000_000 @@ -101,8 +101,3 @@ SNAPSHOT_LOGGING_MESSAGE = 'snapshot-logging' PROGRESS_DELETION_TOMBSTONE_MESSAGE = 'progress-deletion-tombstone' METRIC_REPORTING_MESSAGE = 'metric-reporting' - -ALL_KAFKA_MESSAGE_TYPES = ( - SINGLE_TABLE_CHANGE_MESSAGE, UNIFIED_TOPIC_CHANGE_MESSAGE, SINGLE_TABLE_SNAPSHOT_MESSAGE, - 
DELETION_CHANGE_TOMBSTONE_MESSAGE, CHANGE_PROGRESS_MESSAGE, SNAPSHOT_PROGRESS_MESSAGE, - PROGRESS_DELETION_TOMBSTONE_MESSAGE, METRIC_REPORTING_MESSAGE) diff --git a/cdc_kafka/kafka.py b/cdc_kafka/kafka.py index 74f00cd..f0feb65 100644 --- a/cdc_kafka/kafka.py +++ b/cdc_kafka/kafka.py @@ -1,18 +1,14 @@ import collections +import datetime import inspect -import io import json import logging import socket -import struct import time from types import TracebackType -from typing import List, Dict, Tuple, Any, Callable, Generator, Optional, Iterable, Set, Type +from typing import List, Dict, Tuple, Any, Generator, Optional, Set, Type -from avro.schema import Schema import confluent_kafka.admin -import confluent_kafka.avro -import fastavro from . import constants, kafka_oauth @@ -27,9 +23,8 @@ class KafkaClient(object): _instance = None def __init__(self, metrics_accumulator: 'accumulator.AccumulatorAbstract', bootstrap_servers: str, - schema_registry_url: str, extra_kafka_consumer_config: Dict[str, str | int], - extra_kafka_producer_config: Dict[str, str | int], disable_writing: bool = False, - transactional_id: Optional[str] = None) -> None: + extra_kafka_consumer_config: Dict[str, str | int], extra_kafka_producer_config: Dict[str, str | int], + disable_writing: bool = False, transactional_id: Optional[str] = None) -> None: if KafkaClient._instance is not None: raise Exception('KafkaClient class should be used as a singleton.') @@ -95,16 +90,7 @@ def __init__(self, metrics_accumulator: 'accumulator.AccumulatorAbstract', boots self._use_transactions = True self._producer: confluent_kafka.Producer = confluent_kafka.Producer(producer_config) - self._schema_registry: confluent_kafka.avro.CachedSchemaRegistryClient = \ - confluent_kafka.avro.CachedSchemaRegistryClient(schema_registry_url) self._admin: confluent_kafka.admin.AdminClient = confluent_kafka.admin.AdminClient(admin_config) - self._avro_serializer: confluent_kafka.avro.MessageSerializer = \ - confluent_kafka.avro.MessageSerializer(self._schema_registry) - self._avro_decoders: Dict[int, Callable[[io.BytesIO], Dict[str, Any]]] = dict() - self._schema_ids_to_names: Dict[int, str] = dict() - self._delivery_callbacks: Dict[str, List[Callable[ - [str, confluent_kafka.Message, Optional[Dict[str, Any]], Optional[Dict[str, Any]]], None - ]]] = collections.defaultdict(list) self._disable_writing = disable_writing self._creation_warned_topic_names: Set[str] = set() @@ -137,19 +123,13 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseEx time.sleep(1) # gives librdkafka threads more of a chance to exit properly before admin/producer are GC'd logger.info("Done.") - def register_delivery_callback(self, for_message_types: Iterable[str], - callback: Callable[[str, confluent_kafka.Message, Optional[Dict[str, Any]], - Optional[Dict[str, Any]]], None]) -> None: - for message_type in for_message_types: - if message_type not in constants.ALL_KAFKA_MESSAGE_TYPES: - raise Exception('Unrecognized message type: %s', message_type) - self._delivery_callbacks[message_type].append(callback) - # a return of None indicates the topic does not exist def get_topic_partition_count(self, topic_name: str) -> int: + if self._cluster_metadata.topics is None: + raise Exception('Unexpected state: no topic metadata') if topic_name not in self._cluster_metadata.topics: return 0 - return len(self._cluster_metadata.topics[topic_name].partitions) + return len(self._cluster_metadata.topics[topic_name].partitions or []) def begin_transaction(self) -> None: 
if not self._use_transactions: @@ -175,31 +155,24 @@ def commit_transaction(self) -> None: logger.debug('Kafka transaction commit from %s', f'{previous_frame[0]}, line {previous_frame[1]}') self._producer.commit_transaction() - def produce(self, topic: str, key: Optional[Dict[str, Any]], key_schema_id: int, value: Optional[Dict[str, Any]], - value_schema_id: int, message_type: str, copy_to_unified_topics: Optional[List[str]] = None, + def produce(self, topic: str, key: Optional[bytes], value: Optional[bytes], message_type: str, + copy_to_unified_topics: Optional[List[str]] = None, event_datetime: Optional[datetime.datetime] = None, + change_lsn: Optional[bytes] = None, operation_id: Optional[int] = None, extra_headers: Optional[Dict[str, str | bytes]] = None) -> None: if self._disable_writing: return start_time = time.perf_counter() - if key is None: - key_ser = None - else: - key_ser = self._avro_serializer.encode_record_with_schema_id(key_schema_id, key, True) - if value is None: # a deletion tombstone probably - value_ser = None + if event_datetime: + delivery_cb = lambda _, msg: self._metrics_accumulator.kafka_delivery_callback(msg, event_datetime) else: - value_ser = self._avro_serializer.encode_record_with_schema_id(value_schema_id, value, False) + delivery_cb = None while True: try: - # the callback function receives the binary-serialized payload, so instead of specifying it - # directly as the delivery callback we wrap it in a lambda that also captures and passes the - # original not-yet-serialized key and value so that we don't have to re-deserialize it later: self._producer.produce( - topic=topic, value=value_ser, key=key_ser, - on_delivery=lambda err, msg: self._delivery_callback(message_type, err, msg, key, value), + topic=topic, value=value, key=key, on_delivery=delivery_cb, headers={'cdc_to_kafka_message_type': message_type, **(extra_headers or {})} ) break @@ -210,8 +183,9 @@ def produce(self, topic: str, key: Optional[Dict[str, Any]], key_schema_id: int, logger.error('The following exception occurred producing to topic %s', topic) raise - elapsed = (time.perf_counter() - start_time) - self._metrics_accumulator.register_kafka_produce(elapsed, value, message_type) + elapsed = time.perf_counter() - start_time + self._metrics_accumulator.register_kafka_produce(elapsed, message_type, event_datetime, + change_lsn, operation_id) if copy_to_unified_topics: for unified_topic in copy_to_unified_topics: @@ -220,9 +194,7 @@ def produce(self, topic: str, key: Optional[Dict[str, Any]], key_schema_id: int, while True: try: self._producer.produce( - topic=unified_topic, value=value_ser, key=key_ser, - on_delivery=lambda err, msg: self._delivery_callback( - constants.UNIFIED_TOPIC_CHANGE_MESSAGE, err, msg, key, value), + topic=unified_topic, value=value, key=key, on_delivery=delivery_cb, headers={'cdc_to_kafka_message_type': constants.UNIFIED_TOPIC_CHANGE_MESSAGE, 'cdc_to_kafka_original_topic': topic, **(extra_headers or {})} ) @@ -234,8 +206,9 @@ def produce(self, topic: str, key: Optional[Dict[str, Any]], key_schema_id: int, logger.error('The following exception occurred producing to topic %s', unified_topic) raise - elapsed = (time.perf_counter() - start_time) - self._metrics_accumulator.register_kafka_produce(elapsed, value, constants.UNIFIED_TOPIC_CHANGE_MESSAGE) + elapsed = time.perf_counter() - start_time + self._metrics_accumulator.register_kafka_produce(elapsed, constants.UNIFIED_TOPIC_CHANGE_MESSAGE, + event_datetime, change_lsn, operation_id) def consume_all(self, topic_name: 
str) -> Generator[confluent_kafka.Message, None, None]: part_count = self.get_topic_partition_count(topic_name) @@ -271,8 +244,9 @@ def consume_all(self, topic_name: str) -> Generator[confluent_kafka.Message, Non continue if msg.error(): # noinspection PyProtectedMember - if msg.error().code() == confluent_kafka.KafkaError._PARTITION_EOF: - finished_parts[msg.partition()] = True + if (msg.error().code() == confluent_kafka.KafkaError._PARTITION_EOF # type: ignore[union-attr] + and msg.partition() is not None): + finished_parts[msg.partition()] = True # type: ignore[index] if all(finished_parts): break continue @@ -283,7 +257,6 @@ def consume_all(self, topic_name: str) -> Generator[confluent_kafka.Message, Non if ctr % 100000 == 0: logger.debug('consume_all has yielded %s messages so far from topic %s', ctr, topic_name) - self._set_decoded_msg(msg) yield msg consumer.close() @@ -322,15 +295,16 @@ def consume_bounded(self, topic_name: str, approx_max_recs: int, continue if msg.error(): # noinspection PyProtectedMember - if msg.error().code() == confluent_kafka.KafkaError._PARTITION_EOF: - finished_parts[msg.partition()] = True + if (msg.error().code() == confluent_kafka.KafkaError._PARTITION_EOF # type: ignore[union-attr] + and msg.partition() is not None): + finished_parts[msg.partition()] = True # type: ignore[index] if all(finished_parts): break continue else: raise confluent_kafka.KafkaException(msg.error()) - if msg.offset() > boundary_watermarks[msg.partition()][1]: - finished_parts[msg.partition()] = True + if msg.offset() > boundary_watermarks[msg.partition()][1]: # type: ignore[index, operator] + finished_parts[msg.partition()] = True # type: ignore[index] if all(finished_parts): break continue @@ -339,45 +313,27 @@ def consume_bounded(self, topic_name: str, approx_max_recs: int, if ctr % 100000 == 0: logger.debug('consume_bounded has yielded %s messages so far from topic %s', ctr, topic_name) - self._set_decoded_msg(msg) yield msg consumer.close() - def _set_decoded_msg(self, msg: confluent_kafka.Message) -> None: - # noinspection PyArgumentList - for msg_part, setter in ((msg.key(), msg.set_key), (msg.value(), msg.set_value)): - if msg_part is not None: - payload = io.BytesIO(msg_part) - _, schema_id = struct.unpack('>bI', payload.read(5)) - if schema_id not in self._avro_decoders: - self._set_decoder(schema_id) - decoder = self._avro_decoders[schema_id] - to_set = decoder(payload) - to_set['__avro_schema_name'] = self._schema_ids_to_names[schema_id] - setter(to_set) - payload.close() - - def _set_decoder(self, schema_id: int) -> None: - reg_schema = self._schema_registry.get_by_id(schema_id) - schema = fastavro.parse_schema(reg_schema.to_json()) - self._schema_ids_to_names[schema_id] = reg_schema.name - self._avro_decoders[schema_id] = lambda p: fastavro.schemaless_reader(p, schema) - def create_topic(self, topic_name: str, partition_count: int, replication_factor: Optional[int] = None, extra_config: Optional[Dict[str, str | int]] = None) -> None: if self._disable_writing: return if not replication_factor: + if self._cluster_metadata.brokers is None: + raise Exception('Unexpected state: no brokers metadata') replication_factor = min(len(self._cluster_metadata.brokers), 3) extra_config = extra_config or {} topic_config = {**{'cleanup.policy': 'compact'}, **extra_config} + topic_config_str = {k: str(v) for k, v in topic_config.items()} logger.info('Creating Kafka topic "%s" with %s partitions, replication factor %s, and config: %s', topic_name, - partition_count, replication_factor, 
json.dumps(topic_config)) - topic = confluent_kafka.admin.NewTopic(topic_name, partition_count, replication_factor, config=topic_config) + partition_count, replication_factor, json.dumps(topic_config_str)) + topic = confluent_kafka.admin.NewTopic(topic_name, partition_count, replication_factor, config=topic_config_str) self._admin.create_topics([topic])[topic_name].result() time.sleep(constants.KAFKA_CONFIG_RELOAD_DELAY_SECS) self._refresh_cluster_metadata() @@ -413,60 +369,10 @@ def get_topic_watermarks(self, topic_names: List[str]) -> Dict[str, List[Tuple[i def get_topic_config(self, topic_name: str) -> Any: resource = confluent_kafka.admin.ConfigResource( - restype=confluent_kafka.admin.ConfigResource.Type.TOPIC, name=topic_name) + restype=confluent_kafka.admin.ConfigResource.Type.TOPIC, name=topic_name) # type: ignore[attr-defined] result = self._admin.describe_configs([resource]) return result[resource].result() - # returns (key schema ID, value schema ID) - def register_schemas(self, topic_name: str, key_schema: Optional[Schema], value_schema: Schema, - key_schema_compatibility_level: str = constants.DEFAULT_KEY_SCHEMA_COMPATIBILITY_LEVEL, - value_schema_compatibility_level: str = constants.DEFAULT_VALUE_SCHEMA_COMPATIBILITY_LEVEL) \ - -> Tuple[int, int]: - # TODO: it turns out that if you try to re-register a schema that was previously registered but later superseded - # (e.g. in the case of adding and then later deleting a column), the schema registry will accept that and return - # you the previously-registered schema ID without updating the `latest` version associated with the registry - # subject, or verifying that the change is Avro-compatible. It seems like the way to handle this, per - # https://github.com/confluentinc/schema-registry/issues/1685, would be to detect the condition and delete the - # subject-version-number of that schema before re-registering it. Since subject-version deletion is not - # available in the `CachedSchemaRegistryClient` we use here--and since this is a rare case--I'm explicitly - # choosing to punt on it for the moment. The Confluent lib does now have a newer `SchemaRegistryClient` class - # which supports subject-version deletion, but changing this code to use it appears to be a non-trivial task. 
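For context on the TODO above (removed here along with the rest of register_schemas): the newer confluent_kafka.schema_registry.SchemaRegistryClient does expose subject-version operations, so the detect-and-delete workaround it describes could look roughly like the sketch below. This is an assumption-laden illustration, not code from this repository; register_forcing_latest is a hypothetical name, and it assumes delete_version is available in the client version in use.

```python
# Hedged sketch of the workaround described in the TODO above, using the newer
# confluent_kafka.schema_registry client rather than CachedSchemaRegistryClient.
from confluent_kafka.schema_registry import Schema, SchemaRegistryClient
from confluent_kafka.schema_registry.error import SchemaRegistryError


def register_forcing_latest(client: SchemaRegistryClient, subject: str, avro_schema_json: str) -> int:
    candidate = Schema(avro_schema_json, schema_type='AVRO')
    try:
        prior = client.lookup_schema(subject, candidate)  # was this exact schema registered before?
        latest = client.get_latest_version(subject)
        if prior.version != latest.version:
            # Previously superseded: drop that subject-version so re-registering bumps `latest`.
            client.delete_version(subject, prior.version)
    except SchemaRegistryError:
        pass  # schema (or subject) not seen before; a plain registration suffices
    return client.register_schema(subject, candidate)
```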
- - key_subject, value_subject = topic_name + '-key', topic_name + '-value' - registered = False - - if key_schema: - key_schema_id, current_key_schema, _ = self._schema_registry.get_latest_schema(key_subject) - if (current_key_schema is None or current_key_schema != key_schema) and not self._disable_writing: - logger.info('Key schema for subject %s does not exist or is outdated; registering now.', key_subject) - key_schema_id = self._schema_registry.register(key_subject, key_schema) - logger.debug('Schema registered for subject %s: %s', key_subject, key_schema) - if current_key_schema is None: - time.sleep(constants.KAFKA_CONFIG_RELOAD_DELAY_SECS) - self._schema_registry.update_compatibility(key_schema_compatibility_level, key_subject) - registered = True - else: - key_schema_id = 0 - - value_schema_id, current_value_schema, _ = self._schema_registry.get_latest_schema(value_subject) - if (current_value_schema is None or current_value_schema != value_schema) and not self._disable_writing: - logger.info('Value schema for subject %s does not exist or is outdated; registering now.', value_subject) - value_schema_id = self._schema_registry.register(value_subject, value_schema) - logger.debug('Schema registered for subject %s: %s', value_subject, value_schema) - if current_value_schema is None: - time.sleep(constants.KAFKA_CONFIG_RELOAD_DELAY_SECS) - self._schema_registry.update_compatibility(value_schema_compatibility_level, value_subject) - registered = True - - if registered: - # some older versions of the Confluent schema registry have a bug that leads to duplicate schema IDs in - # some circumstances; delay a bit if we actually registered a new schema, to give the registry a chance - # to become consistent (see https://github.com/confluentinc/schema-registry/pull/1003 and linked issues - # for context): - time.sleep(constants.KAFKA_CONFIG_RELOAD_DELAY_SECS) - - return key_schema_id, value_schema_id - def _refresh_cluster_metadata(self) -> None: self._cluster_metadata = self._get_cluster_metadata() @@ -478,14 +384,6 @@ def _get_cluster_metadata(self) -> confluent_kafka.admin.ClusterMetadata: raise Exception(f'Cluster metadata request to Kafka timed out') return metadata - def _delivery_callback(self, message_type: str, err: confluent_kafka.KafkaError, message: confluent_kafka.Message, - original_key: Optional[Dict[str, Any]], original_value: Optional[Dict[str, Any]]) -> None: - if err is not None: - raise confluent_kafka.KafkaException(f'Delivery error on topic {message.topic()}: {err}') - - for cb in self._delivery_callbacks[message_type]: - cb(message_type, message, original_key, original_value) - @staticmethod def _raise_kafka_error(err: confluent_kafka.KafkaError) -> None: if err.fatal(): diff --git a/cdc_kafka/kafka_oauth/__init__.py b/cdc_kafka/kafka_oauth/__init__.py index 44b1b0a..0fb1734 100644 --- a/cdc_kafka/kafka_oauth/__init__.py +++ b/cdc_kafka/kafka_oauth/__init__.py @@ -2,7 +2,6 @@ import importlib import os from abc import ABC, abstractmethod - from typing import TypeVar, Type, Tuple, Optional KafkaOauthProviderAbstractType = TypeVar('KafkaOauthProviderAbstractType', bound='KafkaOauthProviderAbstract') diff --git a/cdc_kafka/kafka_oauth/aws_msk.py b/cdc_kafka/kafka_oauth/aws_msk.py index 8adfdc0..1de2911 100644 --- a/cdc_kafka/kafka_oauth/aws_msk.py +++ b/cdc_kafka/kafka_oauth/aws_msk.py @@ -4,7 +4,7 @@ import os from typing import Tuple, TypeVar, Type, Optional -from aws_msk_iam_sasl_signer import MSKAuthTokenProvider +from aws_msk_iam_sasl_signer import 
MSKAuthTokenProvider # type: ignore[import-untyped] from . import KafkaOauthProviderAbstract diff --git a/cdc_kafka/main.py b/cdc_kafka/main.py index 4c5ed26..699b664 100644 --- a/cdc_kafka/main.py +++ b/cdc_kafka/main.py @@ -12,12 +12,13 @@ import pyodbc from . import clock_sync, kafka, tracked_tables, constants, options, validation, change_index, progress_tracking, \ - sql_query_subprocess, sql_queries, helpers, avro + sql_query_subprocess, sql_queries, helpers from .build_startup_state import build_tracked_tables_from_cdc_metadata, determine_start_points_and_finalize_tables, \ get_latest_capture_instances_by_fq_name, CaptureInstanceMetadata from .metric_reporting import accumulator from typing import TYPE_CHECKING + if TYPE_CHECKING: from . import parsed_row @@ -27,13 +28,13 @@ def run() -> None: logger.info('Starting...') opts: argparse.Namespace - opts, reporters = options.get_options_and_metrics_reporters() + opts, reporters, serializer = options.get_options_and_metrics_reporters() + disable_writes: bool = opts.run_validations or opts.report_progress_only logger.debug('Parsed configuration: %s', json.dumps(vars(opts))) - if not (opts.schema_registry_url and opts.kafka_bootstrap_servers and opts.db_conn_string - and opts.kafka_transactional_id): - raise Exception('Arguments schema_registry_url, kafka_bootstrap_servers, db_conn_string, and ' + if not (opts.kafka_bootstrap_servers and opts.db_conn_string and opts.kafka_transactional_id): + raise Exception('Arguments kafka_bootstrap_servers, db_conn_string, and ' 'kafka_transactional_id are all required.') redo_snapshot_for_new_instance: bool = \ @@ -52,9 +53,6 @@ def run() -> None: metrics_accumulator: accumulator.Accumulator = accumulator.Accumulator( db_conn, clock_syncer, opts.metrics_namespace, opts.process_hostname) - schema_generator: avro.AvroSchemaGenerator = avro.AvroSchemaGenerator( - opts.always_use_avro_longs, opts.avro_type_spec_overrides) - capture_instances_by_fq_name: Dict[str, CaptureInstanceMetadata] = get_latest_capture_instances_by_fq_name( db_conn, opts.capture_instance_version_strategy, opts.capture_instance_version_regex, opts.table_include_regex, opts.table_exclude_regex) @@ -67,35 +65,34 @@ def run() -> None: for ci in capture_instances_by_fq_name.values()] with kafka.KafkaClient( - metrics_accumulator, opts.kafka_bootstrap_servers, opts.schema_registry_url, - opts.extra_kafka_consumer_config, opts.extra_kafka_producer_config, - disable_writing=opts.run_validations or opts.report_progress_only, - transactional_id=opts.kafka_transactional_id + metrics_accumulator, opts.kafka_bootstrap_servers, opts.extra_kafka_consumer_config, + opts.extra_kafka_producer_config, disable_writing=disable_writes, + transactional_id=opts.kafka_transactional_id ) as kafka_client: progress_tracker = progress_tracking.ProgressTracker( - kafka_client, opts.progress_topic_name, opts.process_hostname, opts.snapshot_logging_topic_name + kafka_client, serializer, opts.progress_topic_name, opts.process_hostname, + opts.snapshot_logging_topic_name ) - kafka_client.register_delivery_callback(( - constants.SINGLE_TABLE_CHANGE_MESSAGE, constants.UNIFIED_TOPIC_CHANGE_MESSAGE, - constants.SINGLE_TABLE_SNAPSHOT_MESSAGE, constants.DELETION_CHANGE_TOMBSTONE_MESSAGE - ), metrics_accumulator.kafka_delivery_callback) tables: List[tracked_tables.TrackedTable] = build_tracked_tables_from_cdc_metadata( db_conn, metrics_accumulator, opts.topic_name_template, opts.snapshot_table_include_regex, - opts.snapshot_table_exclude_regex, opts.truncate_fields, 
capture_instance_names, opts.db_row_batch_size, - sql_query_processor, schema_generator, progress_tracker) + opts.snapshot_table_exclude_regex, opts.truncate_fields, capture_instance_names, + opts.db_row_batch_size, sql_query_processor, progress_tracker) - capture_instance_to_topic_map: Dict[str, str] = { - t.capture_instance_name: t.topic_name for t in tables} + capture_instance_to_topic_map: Dict[str, str] = {t.capture_instance_name: t.topic_name for t in tables} determine_start_points_and_finalize_tables( - kafka_client, db_conn, tables, schema_generator, progress_tracker, opts.lsn_gap_handling, + kafka_client, db_conn, tables, progress_tracker, opts.lsn_gap_handling, opts.new_follow_start_point, opts.partition_count, opts.replication_factor, opts.extra_topic_config, opts.run_validations, - redo_snapshot_for_new_instance, publish_duplicate_changes_from_new_instance, opts.report_progress_only) + redo_snapshot_for_new_instance, publish_duplicate_changes_from_new_instance, + opts.report_progress_only) if opts.report_progress_only: exit(0) + for table in tables: + serializer.register_table(table) + table_to_unified_topics_map: Dict[str, List[str]] = collections.defaultdict(list) unified_topic_to_tables_map: Dict[str, List[tracked_tables.TrackedTable]] = collections.defaultdict(list) @@ -108,7 +105,7 @@ def run() -> None: matched_tables = [table for table in tables if compiled_regex.match(table.fq_name)] if matched_tables: for matched_table in matched_tables: - table_to_unified_topics_map[matched_table.fq_name].append(unified_topic_name) + table_to_unified_topics_map[matched_table.topic_name].append(unified_topic_name) unified_topic_to_tables_map[unified_topic_name].append(matched_table) part_count = kafka_client.get_topic_partition_count(unified_topic_name) if part_count: @@ -130,7 +127,7 @@ def run() -> None: # those and the source DB data. It takes a while; probably don't run this on very large datasets! 
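One note on the serializer hand-off visible in the changes above: get_options_and_metrics_reporters() now also returns a serializer, each table is registered with it via serializer.register_table(table), and KafkaClient.produce() receives already-serialized key/value bytes. Purely to illustrate that bytes-in/bytes-out contract (the real SerializerAbstract lives in cdc_kafka/serializers and its methods take richer objects such as a ParsedRow, so treat the signatures here as placeholders), a minimal JSON-flavored stand-in might look like:

```python
# Hypothetical JSON serializer showing the shape of the new contract; NOT this project's
# AvroSerializer, and the parameter types are simplified placeholders.
import json
from typing import Any, Dict, Optional, Tuple


class JsonSerializerSketch:
    def register_table(self, topic_name: str) -> None:
        pass  # schemaless JSON has nothing to pre-register

    def serialize_table_data_message(self, key_dict: Dict[str, Any],
                                     value_dict: Optional[Dict[str, Any]]) -> Tuple[bytes, Optional[bytes]]:
        key = json.dumps(key_dict, sort_keys=True, default=str).encode('utf-8')
        value = None if value_dict is None else json.dumps(value_dict, default=str).encode('utf-8')
        return key, value

    def serialize_metrics_message(self, metrics_namespace: str,
                                  metrics: Dict[str, Any]) -> Tuple[bytes, bytes]:
        key = json.dumps({'metrics_namespace': metrics_namespace}).encode('utf-8')
        return key, json.dumps(metrics, default=str).encode('utf-8')
```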
if opts.run_validations: validator: validation.Validator = validation.Validator( - kafka_client, tables, progress_tracker, unified_topic_to_tables_map) + kafka_client, tables, progress_tracker, serializer, unified_topic_to_tables_map) validator.run() exit(0) @@ -222,20 +219,25 @@ def poll_periodic_tasks() -> bool: for t in tables: if not t.snapshot_complete: + last_row_retrieved: Optional[parsed_row.ParsedRow] = None for row in t.retrieve_snapshot_query_results(): - kafka_client.produce(row.destination_topic, row.key_dict, - row.avro_key_schema_id, row.value_dict, - row.avro_value_schema_id, - constants.SINGLE_TABLE_SNAPSHOT_MESSAGE, - extra_headers=row.extra_headers) - snapshot_progress_by_topic[row.destination_topic] = row.key_dict + key_ser, value_ser = serializer.serialize_table_data_message(row) + kafka_client.produce(row.destination_topic, key_ser, value_ser, + constants.SINGLE_TABLE_SNAPSHOT_MESSAGE, None, + row.event_db_time, None, + constants.SNAPSHOT_OPERATION_ID, row.extra_headers) + last_row_retrieved = row + if last_row_retrieved: + key_as_dict = dict(zip(t.key_field_names, + last_row_retrieved.ordered_key_field_values)) + snapshot_progress_by_topic[last_row_retrieved.destination_topic] = key_as_dict if t.snapshot_complete: progress_tracker.record_snapshot_progress( t.topic_name, constants.SNAPSHOT_COMPLETION_SENTINEL) snapshot_progress_by_topic.pop(row.destination_topic, None) completions_to_log.append(functools.partial( progress_tracker.log_snapshot_completed, t.topic_name, t.fq_name, - t.key_schema_id, t.value_schema_id, helpers.naive_utcnow(), row.key_dict + helpers.naive_utcnow(), key_as_dict )) snapshots_remain = not all([t.snapshot_complete for t in tables]) elif not lagging_change_tables: @@ -329,18 +331,17 @@ def poll_periodic_tasks() -> bool: f'a bug. Fix it! 
Prior: {last_produced_row}, current: {row}') last_produced_row = row queued_change_row_counts[row.destination_topic] -= 1 - - kafka_client.produce(row.destination_topic, row.key_dict, row.avro_key_schema_id, - row.value_dict, row.avro_value_schema_id, + key_ser, value_ser = serializer.serialize_table_data_message(row) + kafka_client.produce(row.destination_topic, key_ser, value_ser, constants.SINGLE_TABLE_CHANGE_MESSAGE, - table_to_unified_topics_map.get(row.table_fq_name, []), - extra_headers=row.extra_headers) + table_to_unified_topics_map.get(row.destination_topic, []), + row.event_db_time, row.change_idx.lsn, + row.operation_id, row.extra_headers) last_topic_produces[row.destination_topic] = helpers.naive_utcnow() - if not opts.disable_deletion_tombstones and row.operation_name == \ - constants.DELETE_OPERATION_NAME: - kafka_client.produce(row.destination_topic, row.key_dict, row.avro_key_schema_id, - None, row.avro_value_schema_id, + if not opts.disable_deletion_tombstones and row.operation_id == \ + constants.DELETE_OPERATION_ID: + kafka_client.produce(row.destination_topic, key_ser, None, constants.DELETION_CHANGE_TOMBSTONE_MESSAGE) progress_by_topic[row.destination_topic] = row.change_idx diff --git a/cdc_kafka/metric_reporting/accumulator.py b/cdc_kafka/metric_reporting/accumulator.py index 8411e11..d49a74e 100644 --- a/cdc_kafka/metric_reporting/accumulator.py +++ b/cdc_kafka/metric_reporting/accumulator.py @@ -1,6 +1,6 @@ import abc import datetime -from typing import List, Any, Dict, Optional +from typing import List, Optional import confluent_kafka import pyodbc @@ -28,12 +28,13 @@ def register_sleep(self, sleep_time_seconds: float) -> None: pass def register_db_query(self, seconds_elapsed: float, db_query_kind: str, retrieved_row_count: int) -> None: pass @abc.abstractmethod - def register_kafka_produce(self, seconds_elapsed: float, original_value: Optional[Dict[str, Any]], - message_type: str) -> None: pass + def register_kafka_produce(self, seconds_elapsed: float, message_type: str, + event_datetime: Optional[datetime.datetime] = None, change_lsn: Optional[bytes] = None, + operation_id: Optional[int] = None) -> None: pass @abc.abstractmethod - def kafka_delivery_callback(self, message_type: str, message: confluent_kafka.Message, - original_key: Dict[str, Any], original_value: Optional[Dict[str, Any]]) -> None: pass + def kafka_delivery_callback(self, message: confluent_kafka.Message, + event_datetime: datetime.datetime) -> None: pass class NoopAccumulator(AccumulatorAbstract): @@ -42,10 +43,11 @@ def end_and_get_values(self) -> metrics.Metrics: return metrics.Metrics() def register_sleep(self, sleep_time_seconds: float) -> None: pass def register_db_query(self, seconds_elapsed: float, db_query_kind: str, retrieved_row_count: int) -> None: pass - def register_kafka_produce(self, seconds_elapsed: float, original_value: Optional[Dict[str, Any]], - message_type: str) -> None: pass - def kafka_delivery_callback(self, message_type: str, message: confluent_kafka.Message, - original_key: Dict[str, Any], original_value: Optional[Dict[str, Any]]) -> None: pass + def register_kafka_produce(self, seconds_elapsed: float, message_type: str, + event_datetime: Optional[datetime.datetime] = None, change_lsn: Optional[bytes] = None, + operation_id: Optional[int] = None) -> None: pass + def kafka_delivery_callback(self, message: confluent_kafka.Message, + event_datetime: datetime.datetime) -> None: pass class Accumulator(AccumulatorAbstract): @@ -75,8 +77,8 @@ def reset_and_start(self) -> 
None: self._db_snapshot_queries_count: int = 0 self._db_snapshot_queries_total_time_sec: float = 0 self._db_snapshot_rows_retrieved_count: int = 0 - self._change_lsns_produced: sortedcontainers.SortedList = sortedcontainers.SortedList() - self._change_db_tran_end_times_produced: sortedcontainers.SortedList = sortedcontainers.SortedList() + self._change_lsns_produced: sortedcontainers.SortedList[bytes] = sortedcontainers.SortedList() + self._change_db_tran_end_times_produced: sortedcontainers.SortedList[datetime.datetime] = sortedcontainers.SortedList() self._e2e_latencies_sec: List[float] = [] self._kafka_produces_total_time_sec: float = 0 self._kafka_delivery_acks_count: int = 0 @@ -115,12 +117,12 @@ def end_and_get_values(self) -> metrics.Metrics: m.interval_delta_sec = interval_delta_sec m.earliest_change_lsn_produced = \ - (self._change_lsns_produced and self._change_lsns_produced[0]) or None + (self._change_lsns_produced and f'0x{self._change_lsns_produced[0].hex()}') or None m.earliest_change_db_tran_end_time_produced = \ (self._change_db_tran_end_times_produced and self._change_db_tran_end_times_produced[0]) \ or None m.latest_change_lsn_produced = \ - (self._change_lsns_produced and self._change_lsns_produced[-1]) or None + (self._change_lsns_produced and f'0x{self._change_lsns_produced[-1].hex()}') or None m.latest_change_db_tran_end_time_produced = \ (self._change_db_tran_end_times_produced and self._change_db_tran_end_times_produced[-1]) \ or None @@ -193,8 +195,9 @@ def register_db_query(self, seconds_elapsed: float, db_query_kind: str, retrieve else: raise Exception(f'Accumulator.register_db_query does not recognize db_query_kind "{db_query_kind}".') - def register_kafka_produce(self, seconds_elapsed: float, original_value: Optional[Dict[str, Any]], - message_type: str) -> None: + def register_kafka_produce(self, seconds_elapsed: float, message_type: str, + event_datetime: Optional[datetime.datetime] = None, change_lsn: Optional[bytes] = None, + operation_id: Optional[int] = None) -> None: self._kafka_produces_total_time_sec += seconds_elapsed if message_type in (constants.CHANGE_PROGRESS_MESSAGE, constants.SNAPSHOT_PROGRESS_MESSAGE, @@ -206,42 +209,35 @@ def register_kafka_produce(self, seconds_elapsed: float, original_value: Optiona self._messages_copied_to_unified_topics += 1 elif message_type == constants.SINGLE_TABLE_SNAPSHOT_MESSAGE: self._produced_snapshot_records_count += 1 - elif message_type == constants.SINGLE_TABLE_CHANGE_MESSAGE and original_value: - self._change_lsns_produced.add(original_value[constants.LSN_NAME]) - self._change_db_tran_end_times_produced.add(original_value[constants.EVENT_TIME_NAME]) - operation_name = original_value[constants.OPERATION_NAME] - if operation_name == constants.DELETE_OPERATION_NAME: + elif message_type == constants.SINGLE_TABLE_CHANGE_MESSAGE: + if change_lsn: + self._change_lsns_produced.add(change_lsn) + if event_datetime: + self._change_db_tran_end_times_produced.add(event_datetime) + if operation_id == constants.DELETE_OPERATION_ID: self._produced_delete_changes_count += 1 - elif operation_name == constants.INSERT_OPERATION_NAME: + elif operation_id == constants.INSERT_OPERATION_ID: self._produced_insert_changes_count += 1 - elif operation_name == constants.POST_UPDATE_OPERATION_NAME: + elif operation_id == constants.POST_UPDATE_OPERATION_ID: self._produced_update_changes_count += 1 else: - raise Exception(f'Accumulator.register_kafka_produce does not recognize operation name: ' - f'"{operation_name}".') + raise 
Exception(f'Accumulator.register_kafka_produce does not recognize operation ID: ' + f'"{operation_id}".') elif message_type == constants.SNAPSHOT_LOGGING_MESSAGE: pass else: raise Exception(f'Accumulator.register_kafka_produce does not recognize message type: "{message_type}".') - def kafka_delivery_callback(self, message_type: str, message: confluent_kafka.Message, - original_key: Optional[Dict[str, Any]], - original_value: Optional[Dict[str, Any]]) -> None: + def kafka_delivery_callback(self, message: confluent_kafka.Message, + event_datetime: datetime.datetime) -> None: self._kafka_delivery_acks_count += 1 - if message_type not in (constants.SINGLE_TABLE_CHANGE_MESSAGE, constants.UNIFIED_TOPIC_CHANGE_MESSAGE): - return - - if not original_value: - return - timestamp_type, timestamp = message.timestamp() if timestamp_type != confluent_kafka.TIMESTAMP_CREATE_TIME: produce_datetime = helpers.naive_utcnow() else: produce_datetime = datetime.datetime.fromtimestamp(timestamp / 1000.0, datetime.UTC).replace(tzinfo=None) - event_time = datetime.datetime.fromisoformat(original_value[constants.EVENT_TIME_NAME]) - db_commit_time = self._clock_syncer.db_time_to_utc(event_time) + db_commit_time = self._clock_syncer.db_time_to_utc(event_datetime) e2e_latency = (produce_datetime - db_commit_time).total_seconds() self._e2e_latencies_sec.append(e2e_latency) diff --git a/cdc_kafka/metric_reporting/kafka_reporter.py b/cdc_kafka/metric_reporting/kafka_reporter.py index e6e0a85..b463cc6 100644 --- a/cdc_kafka/metric_reporting/kafka_reporter.py +++ b/cdc_kafka/metric_reporting/kafka_reporter.py @@ -1,11 +1,12 @@ import argparse import os +from typing import Type, TypeVar from . import reporter_base from .. import kafka, constants - -from typing import Type, TypeVar from .metrics import Metrics +from ..serializers import SerializerAbstract +from ..serializers.avro import AvroSerializer KafkaReporterType = TypeVar('KafkaReporterType', bound='KafkaReporter') @@ -13,26 +14,18 @@ class KafkaReporter(reporter_base.ReporterBase): DEFAULT_TOPIC = '_cdc_to_kafka_metrics' - def __init__(self, metrics_topic: str) -> None: + def __init__(self, metrics_topic: str, opts: argparse.Namespace) -> None: self._metrics_topic: str = metrics_topic - self._schemas_registered: bool = False - self._metrics_key_schema_id: int = -1 - self._metrics_value_schema_id: int = -1 + self._serializer: SerializerAbstract = AvroSerializer( + opts.schema_registry_url, opts.always_use_avro_longs, opts.progress_topic_name, + opts.snapshot_logging_topic_name, opts.metrics_topic_name, opts.avro_type_spec_overrides, + disable_writes=True) # noinspection PyProtectedMember def emit(self, metrics: 'Metrics') -> None: metrics_dict = metrics.as_dict() - key = {'metrics_namespace': metrics_dict['metrics_namespace']} - - if not self._schemas_registered: - client = kafka.KafkaClient.get_instance() - self._metrics_key_schema_id, self._metrics_value_schema_id = (client.register_schemas( - self._metrics_topic, Metrics.METRICS_AVRO_KEY_SCHEMA, Metrics.METRICS_AVRO_VALUE_SCHEMA)) - self._schemas_registered = True - - kafka.KafkaClient.get_instance().produce( - self._metrics_topic, key, self._metrics_key_schema_id, metrics_dict, self._metrics_value_schema_id, - constants.METRIC_REPORTING_MESSAGE) + key, value = self._serializer.serialize_metrics_message(metrics_dict['metrics_namespace'], metrics_dict) + kafka.KafkaClient.get_instance().produce(self._metrics_topic, key, value, constants.METRIC_REPORTING_MESSAGE) @staticmethod def add_arguments(parser: 
argparse.ArgumentParser) -> None: @@ -44,4 +37,4 @@ def add_arguments(parser: argparse.ArgumentParser) -> None: @classmethod def construct_with_options(cls: Type[KafkaReporterType], opts: argparse.Namespace) -> KafkaReporterType: metrics_topic: str = opts.kafka_metrics_topic or KafkaReporter.DEFAULT_TOPIC - return cls(metrics_topic) + return cls(metrics_topic, opts) diff --git a/cdc_kafka/metric_reporting/local_file_reporter.py b/cdc_kafka/metric_reporting/local_file_reporter.py index 37ddce9..77dc56f 100644 --- a/cdc_kafka/metric_reporting/local_file_reporter.py +++ b/cdc_kafka/metric_reporting/local_file_reporter.py @@ -3,11 +3,10 @@ import logging import os import pathlib +from typing import TYPE_CHECKING, TypeVar, Type from . import reporter_base -from typing import TYPE_CHECKING, TypeVar, Type - if TYPE_CHECKING: from .metrics import Metrics diff --git a/cdc_kafka/metric_reporting/metrics.py b/cdc_kafka/metric_reporting/metrics.py index b4fa078..59f582c 100644 --- a/cdc_kafka/metric_reporting/metrics.py +++ b/cdc_kafka/metric_reporting/metrics.py @@ -1,10 +1,5 @@ -import json from typing import Any, Dict -import confluent_kafka.avro - -from .. import constants - class Metrics(object): FIELDS_AND_TYPES = [ @@ -58,31 +53,6 @@ class Metrics(object): ] FIELD_NAMES = {ft[0] for ft in FIELDS_AND_TYPES} - METRICS_SCHEMA_VERSION = '2' - - METRICS_AVRO_KEY_SCHEMA = confluent_kafka.avro.loads(json.dumps({ - "name": f"{constants.AVRO_SCHEMA_NAMESPACE}__metrics_v{METRICS_SCHEMA_VERSION}__key", - "namespace": constants.AVRO_SCHEMA_NAMESPACE, - "type": "record", - "fields": [ - { - "name": "metrics_namespace", - "type": "string" - } - ] - })) - - METRICS_AVRO_VALUE_SCHEMA = confluent_kafka.avro.loads(json.dumps({ - "name": f"{constants.AVRO_SCHEMA_NAMESPACE}__metrics_v{METRICS_SCHEMA_VERSION}__value", - "namespace": constants.AVRO_SCHEMA_NAMESPACE, - "type": "record", - "fields": [ - { - "name": k, - "type": v - } for (k, v) in FIELDS_AND_TYPES - ] - })) def __setattr__(self, attr: str, value: Any) -> None: if attr not in Metrics.FIELD_NAMES: diff --git a/cdc_kafka/metric_reporting/reporter_base.py b/cdc_kafka/metric_reporting/reporter_base.py index 716b62c..67ba26d 100644 --- a/cdc_kafka/metric_reporting/reporter_base.py +++ b/cdc_kafka/metric_reporting/reporter_base.py @@ -1,7 +1,6 @@ import argparse import datetime from abc import ABC, abstractmethod - from typing import TYPE_CHECKING, Optional, TypeVar, Type if TYPE_CHECKING: diff --git a/cdc_kafka/metric_reporting/stdout_reporter.py b/cdc_kafka/metric_reporting/stdout_reporter.py index 1ae627a..aa41be8 100644 --- a/cdc_kafka/metric_reporting/stdout_reporter.py +++ b/cdc_kafka/metric_reporting/stdout_reporter.py @@ -1,11 +1,10 @@ import argparse import json import logging +from typing import TYPE_CHECKING, TypeVar, Type from . import reporter_base -from typing import TYPE_CHECKING, TypeVar, Type - if TYPE_CHECKING: from .metrics import Metrics diff --git a/cdc_kafka/options.py b/cdc_kafka/options.py index 132e6a2..5e7bbca 100644 --- a/cdc_kafka/options.py +++ b/cdc_kafka/options.py @@ -3,11 +3,11 @@ import json import os import socket +from typing import Tuple, List, Optional, Callable -from typing import Tuple, List from . 
import constants, kafka_oauth from .metric_reporting import reporter_base - +from .serializers import SerializerAbstract # String constants for options with discrete choices: CAPTURE_INSTANCE_VERSION_STRATEGY_REGEX = 'regex' @@ -15,6 +15,8 @@ LSN_GAP_HANDLING_RAISE_EXCEPTION = 'raise_exception' LSN_GAP_HANDLING_BEGIN_NEW_SNAPSHOT = 'begin_new_snapshot' LSN_GAP_HANDLING_IGNORE = 'ignore' +NEW_FOLLOW_START_POINT_EARLIEST = 'earliest' +NEW_FOLLOW_START_POINT_LATEST = 'latest' NEW_CAPTURE_INSTANCE_SNAPSHOT_HANDLING_BEGIN_NEW = 'begin_new_snapshot' NEW_CAPTURE_INSTANCE_SNAPSHOT_HANDLING_IGNORE = 'ignore' NEW_CAPTURE_INSTANCE_OVERLAP_HANDLING_REPUBLISH = 'republish_from_new_instance' @@ -32,7 +34,9 @@ def str2bool(v: str) -> bool: raise argparse.ArgumentTypeError('Boolean value expected.') -def get_options_and_metrics_reporters() -> Tuple[argparse.Namespace, List[reporter_base.ReporterBase]]: +def get_options_and_metrics_reporters( + arg_adder: Optional[Callable[[argparse.ArgumentParser, ], None]] = None) -> Tuple[ + argparse.Namespace, List[reporter_base.ReporterBase], SerializerAbstract]: p = argparse.ArgumentParser() # Required @@ -44,10 +48,6 @@ def get_options_and_metrics_reporters() -> Tuple[argparse.Namespace, List[report default=os.environ.get('KAFKA_BOOTSTRAP_SERVERS'), help='Host and port for your Kafka cluster, e.g. "localhost:9092"') - p.add_argument('--schema-registry-url', - default=os.environ.get('SCHEMA_REGISTRY_URL'), - help='URL to your Confluent Schema Registry, e.g. "http://localhost:8081"') - p.add_argument('--kafka-transactional-id', default=os.environ.get('KAFKA_TRANSACTIONAL_ID'), help='An identifier of your choosing that should stay stable across restarts of a particularly-' @@ -153,6 +153,18 @@ def get_options_and_metrics_reporters() -> Tuple[argparse.Namespace, List[report f"the LSN of the latest change published to Kafka. Defaults to " f"`{LSN_GAP_HANDLING_RAISE_EXCEPTION}`") + p.add_argument('--new-follow-start-point', + choices=(NEW_FOLLOW_START_POINT_EARLIEST, NEW_FOLLOW_START_POINT_LATEST), + default=os.environ.get('NEW_FOLLOW_START_POINT', NEW_FOLLOW_START_POINT_LATEST), + help=f"Controls how much change data history to read from SQL Server capture tables, for any tables " + f"that are being followed by this process for the first time. Value " + f"`{NEW_FOLLOW_START_POINT_EARLIEST}` will pull all existing data from the capture tables; " + f"value `{NEW_FOLLOW_START_POINT_LATEST}` will only process change data added after this " + f"process starts following the table. Note that use of `{NEW_FOLLOW_START_POINT_EARLIEST}` " + f"with unified topics may lead to LSN regressions in the sequence of unified topic messages " + f"in the case where new tables are added to a previously-tracked set. This setting does not " + f"affect the behavior of table snapshots. Defaults to `{NEW_FOLLOW_START_POINT_LATEST}`") + p.add_argument('--unified-topics', default=os.environ.get('UNIFIED_TOPICS', {}), type=json.loads, help=f'A string that is a JSON object mapping topic names to various configuration parameters as ' @@ -202,6 +214,12 @@ def get_options_and_metrics_reporters() -> Tuple[argparse.Namespace, List[report help="Runs count validations between messages in the Kafka topic and rows in the change and " "source tables, then quits. 
Respects the table inclusion/exclusion regexes.") + p.add_argument('--message-serializer', + default=os.environ.get('MESSAGE_SERIALIZER', + 'cdc_kafka.serializers.avro.AvroSerializer'), + help="The serializer class (from this project's `serializers` module) used to serialize messages" + "sent to Kafka.") + p.add_argument('--metrics-reporters', default=os.environ.get('METRICS_REPORTERS', 'cdc_kafka.metric_reporting.stdout_reporter.StdoutReporter'), @@ -235,20 +253,13 @@ def get_options_and_metrics_reporters() -> Tuple[argparse.Namespace, List[report p.add_argument('--truncate-fields', default=os.environ.get('TRUNCATE_FIELDS', {}), type=json.loads, help='Optional JSON object that maps schema.table.column names to an integer max number of ' - 'characters that should be copied into the Kafka message for that field\'s values. The schema, ' - 'table, and column names are case-insensitive. Example: `{"dbo.order.gift_note": 65536}`. ' - 'Note that this truncates based on _character_ length, not _byte_ length!') - - p.add_argument('--avro-type-spec-overrides', - default=os.environ.get('AVRO_TYPE_SPEC_OVERRIDES', {}), type=json.loads, - help='Optional JSON object that maps schema.table.column names to a string or object indicating the ' - 'Avro schema type specification you want to use for the field. This will override the default ' - 'mapping of SQL types to Avro types otherwise used and found in avro.py. Note that setting ' - 'this only changes the generated schema and will NOT affect the way values are passed to the ' - 'Avro serialization library, so any overriding type specified should be compatible with the ' - 'SQL/Python types of the actual data. Example: `{"dbo.order.orderid": "long"}` could be used ' - 'to specify the use of an Avro `long` type for a source DB column that is only a 32-bit INT, ' - 'perhaps in preparation for a future DB column change.') + 'UTF-8 encoded bytes that should be serialized into the Kafka message for that field\'s ' + 'values. Only applicable to string types; will raise an exception if used for non-strings. ' + 'Truncation respects UTF-8 character boundaries and will not break in the middle of 2- or ' + '4-byte characters. The schema, table, and column names are case-insensitive. Example: ' + '`{"dbo.order.gift_note": 65536}`. When a field is truncated via this mechanism, a Kafka ' + 'message header of the form key: `cdc_to_kafka_truncated_field__`, value ' + '`,` will be added to the message.') p.add_argument('--terminate-on-capture-instance-change', type=str2bool, nargs='?', const=True, @@ -268,25 +279,14 @@ def get_options_and_metrics_reporters() -> Tuple[argparse.Namespace, List[report "then exits without changing any state. Can be handy for validating other configuration such " "as the regexes used to control which tables are followed and/or snapshotted.") - p.add_argument('--always-use-avro-longs', - type=str2bool, nargs='?', const=True, - default=str2bool(os.environ.get('ALWAYS_USE_AVRO_LONGS', '0')), - help="Defaults to False. If set to True, Avro schemas produced/registered by this process will " - "use the Avro `long` type instead of the `int` type for fields corresponding to SQL Server " - "INT, SMALLINT, or TINYINT columns. This can be used to future-proof in cases where the column " - "size may need to be upgraded in the future, at the potential cost of increased storage or " - "memory space needs in consuming processes. 
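The reworked `--truncate-fields` help above now promises byte-based truncation that respects UTF-8 character boundaries and records the before/after byte lengths in a message header. The code enforcing that is not part of this excerpt, so the following is only an illustrative sketch of one way to truncate to a byte budget without splitting a multi-byte character; the helper name and return shape are invented here, not taken from the project:

```python
def truncate_utf8(value: str, max_bytes: int) -> tuple[str, int, int]:
    """Trim value to at most max_bytes of UTF-8 without splitting a character.

    Returns (truncated_value, original_byte_length, truncated_byte_length),
    mirroring the kind of before/after byte counts the new truncation header
    is described as carrying.
    """
    encoded = value.encode('utf-8')
    if len(encoded) <= max_bytes:
        return value, len(encoded), len(encoded)
    # The input is valid UTF-8, so the only undecodable bytes in the clipped
    # prefix are the tail of a character cut by the byte limit; errors='ignore'
    # drops exactly those, never emitting half of a 2-, 3- or 4-byte character.
    truncated = encoded[:max_bytes].decode('utf-8', errors='ignore')
    return truncated, len(encoded), len(truncated.encode('utf-8'))


assert truncate_utf8('naïve', 3) == ('na', 6, 2)  # the 2-byte 'ï' does not fit at byte 3
```

Encoding once and decoding the clipped prefix with `errors='ignore'` is simpler than scanning backwards for continuation bytes by hand and gives the same boundary guarantee.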
Note that if this change is made for existing " - "topics, the schema registration attempt will violate Avro FORWARD compatibility checks (the " - "default used by this process), meaning that you may need to manually override the schema " - "registry compatibility level for any such topics first.") - p.add_argument('--db-row-batch-size', type=int, default=os.environ.get('DB_ROW_BATCH_SIZE', 2000), help="Maximum number of rows to retrieve in a single change data or snapshot query. Default 2000.") kafka_oauth.add_kafka_oauth_arg(p) - + if arg_adder: + arg_adder(p) opts, _ = p.parse_known_args() reporter_classes: List[reporter_base.ReporterBase] = [] @@ -305,4 +305,12 @@ def get_options_and_metrics_reporters() -> Tuple[argparse.Namespace, List[report for reporter_class in reporter_classes: reporters.append(reporter_class.construct_with_options(opts)) - return opts, reporters + package_module, class_name = opts.message_serializer.rsplit('.', 1) + module = importlib.import_module(package_module) + serializer_class: SerializerAbstract = getattr(module, class_name) + serializer_class.add_arguments(p) + opts, _ = p.parse_known_args() + disable_writes: bool = opts.run_validations or opts.report_progress_only + serializer = serializer_class.construct_with_options(opts, disable_writes) + + return opts, reporters, serializer diff --git a/cdc_kafka/parsed_row.py b/cdc_kafka/parsed_row.py index cb80a83..da59345 100644 --- a/cdc_kafka/parsed_row.py +++ b/cdc_kafka/parsed_row.py @@ -1,30 +1,25 @@ import datetime -from typing import Any, Dict, Sequence, Optional +from typing import Any, Sequence, List, Optional, Dict from . import change_index class ParsedRow(object): - __slots__ = 'table_fq_name', 'row_kind', 'operation_name', 'event_db_time', 'change_idx', \ - 'ordered_key_field_values', 'destination_topic', 'avro_key_schema_id', 'avro_value_schema_id', 'key_dict', \ - 'value_dict', 'extra_headers' + __slots__ = ('destination_topic', 'operation_id', 'cdc_update_mask', 'event_db_time', + 'change_idx', 'ordered_key_field_values', 'table_data_cols', 'extra_headers') - def __init__(self, table_fq_name: str, row_kind: str, operation_name: str, event_db_time: datetime.datetime, - change_idx: change_index.ChangeIndex, ordered_key_field_values: Sequence[Any], destination_topic: str, - avro_key_schema_id: int, avro_value_schema_id: int, key_dict: Dict[str, Any], - value_dict: Dict[str, Any], extra_headers: Optional[Dict[str, str | bytes]] = None) -> None: - self.table_fq_name: str = table_fq_name - self.row_kind: str = row_kind - self.operation_name: str = operation_name + def __init__(self, destination_topic: str, operation_id: int, cdc_update_mask: bytes, + event_db_time: datetime.datetime, change_idx: change_index.ChangeIndex, + ordered_key_field_values: Sequence[Any], table_data_cols: List[Any], + extra_headers: Optional[Dict[str, str | bytes]] = None) -> None: + self.destination_topic: str = destination_topic + self.operation_id: int = operation_id + self.cdc_update_mask: bytes = cdc_update_mask self.event_db_time: datetime.datetime = event_db_time self.change_idx: change_index.ChangeIndex = change_idx self.ordered_key_field_values: Sequence[Any] = ordered_key_field_values - self.destination_topic: str = destination_topic - self.avro_key_schema_id: int = avro_key_schema_id - self.avro_value_schema_id: int = avro_value_schema_id - self.key_dict: Dict[str, Any] = key_dict - self.value_dict: Dict[str, Any] = value_dict + self.table_data_cols: List[Any] = table_data_cols self.extra_headers: Optional[Dict[str, 
str | bytes]] = extra_headers def __repr__(self) -> str: - return f'ParsedRow from {self.table_fq_name} of kind {self.row_kind}, change index {self.change_idx}' + return f'ParsedRow for topic {self.destination_topic}, change index {self.change_idx}' diff --git a/cdc_kafka/progress_reset_tool.py b/cdc_kafka/progress_reset_tool.py index 7fa47a6..434d7e8 100755 --- a/cdc_kafka/progress_reset_tool.py +++ b/cdc_kafka/progress_reset_tool.py @@ -1,39 +1,26 @@ import argparse -import json import logging import os import socket -from cdc_kafka import kafka, constants, progress_tracking, options, kafka_oauth +from . import kafka, constants, progress_tracking, options from .metric_reporting import accumulator logger = logging.getLogger(__name__) def main() -> None: - p = argparse.ArgumentParser() - p.add_argument('--topic-names', required=True, - default=os.environ.get('TOPIC_NAMES')) - p.add_argument('--progress-kind', required=True, - choices=(constants.CHANGE_ROWS_KIND, constants.ALL_PROGRESS_KINDS, constants.SNAPSHOT_ROWS_KIND), - default=os.environ.get('PROGRESS_KIND')) - p.add_argument('--schema-registry-url', required=True, - default=os.environ.get('SCHEMA_REGISTRY_URL')) - p.add_argument('--kafka-bootstrap-servers', required=True, - default=os.environ.get('KAFKA_BOOTSTRAP_SERVERS')) - p.add_argument('--progress-topic-name', required=True, - default=os.environ.get('PROGRESS_TOPIC_NAME')) - p.add_argument('--snapshot-logging-topic-name', required=True, - default=os.environ.get('SNAPSHOT_LOGGING_TOPIC_NAME')) - p.add_argument('--execute', - type=options.str2bool, nargs='?', const=True, - default=options.str2bool(os.environ.get('EXECUTE', '0'))) - p.add_argument('--extra-kafka-producer-config', - default=os.environ.get('EXTRA_KAFKA_PRODUCER_CONFIG', {}), type=json.loads) - p.add_argument('--extra-kafka-consumer-config', - default=os.environ.get('EXTRA_KAFKA_CONSUMER_CONFIG', {}), type=json.loads) - kafka_oauth.add_kafka_oauth_arg(p) - opts, _ = p.parse_known_args() + def add_args(p: argparse.ArgumentParser) -> None: + p.add_argument('--topic-names', required=True, + default=os.environ.get('TOPIC_NAMES')) + p.add_argument('--progress-kind', required=True, + choices=(constants.CHANGE_ROWS_KIND, constants.ALL_PROGRESS_KINDS, constants.SNAPSHOT_ROWS_KIND), + default=os.environ.get('PROGRESS_KIND')) + p.add_argument('--execute', + type=options.str2bool, nargs='?', const=True, + default=options.str2bool(os.environ.get('EXECUTE', '0'))) + + opts, _, serializer = options.get_options_and_metrics_reporters(add_args) logger.info(f""" @@ -48,11 +35,11 @@ def main() -> None: """) - with kafka.KafkaClient(accumulator.NoopAccumulator(), opts.kafka_bootstrap_servers, opts.schema_registry_url, + with kafka.KafkaClient(accumulator.NoopAccumulator(), opts.kafka_bootstrap_servers, opts.extra_kafka_consumer_config, opts.extra_kafka_producer_config, disable_writing=True) as kafka_client: - progress_tracker = progress_tracking.ProgressTracker(kafka_client, opts.progress_topic_name, socket.getfqdn(), - opts.snapshot_logging_topic_name) + progress_tracker = progress_tracking.ProgressTracker(kafka_client, serializer, opts.progress_topic_name, + socket.getfqdn(), opts.snapshot_logging_topic_name) progress_entries = progress_tracker.get_prior_progress() def act(topic: str, progress_kind: str) -> None: diff --git a/cdc_kafka/progress_topic_validator.py b/cdc_kafka/progress_topic_validator.py index 57e2c98..beb13ba 100644 --- a/cdc_kafka/progress_topic_validator.py +++ b/cdc_kafka/progress_topic_validator.py @@ -2,16 +2,15 @@ 
import collections import copy import datetime -import json import logging import os import re from typing import Dict, Optional, Set -import confluent_kafka from tabulate import tabulate -from cdc_kafka import kafka, constants, progress_tracking, options, helpers, kafka_oauth +from . import kafka, constants, progress_tracking, options, helpers +from .serializers import DeserializedMessage from .metric_reporting import accumulator logger = logging.getLogger(__name__) @@ -21,8 +20,8 @@ class TopicProgressInfo(object): def __init__(self) -> None: self.change_progress_count: int = 0 self.snapshot_progress_count: int = 0 - self.last_change_progress: Optional[confluent_kafka.Message] = None - self.last_snapshot_progress: Optional[confluent_kafka.Message] = None + self.last_change_progress: Optional[DeserializedMessage] = None + self.last_snapshot_progress: Optional[DeserializedMessage] = None self.distinct_change_tables: Set[str] = set() self.reset_count: int = 0 self.evolution_count: int = 0 @@ -31,31 +30,19 @@ def __init__(self) -> None: def main() -> None: - p = argparse.ArgumentParser() - p.add_argument('--topics-to-include-regex', - default=os.environ.get('TOPICS_TO_INCLUDE_REGEX', '.*')) - p.add_argument('--topics-to-exclude-regex', - default=os.environ.get('TOPICS_TO_EXCLUDE_REGEX')) - p.add_argument('--schema-registry-url', - default=os.environ.get('SCHEMA_REGISTRY_URL')) - p.add_argument('--kafka-bootstrap-servers', - default=os.environ.get('KAFKA_BOOTSTRAP_SERVERS')) - p.add_argument('--progress-topic-name', - default=os.environ.get('PROGRESS_TOPIC_NAME', '_cdc_to_kafka_progress')) - p.add_argument('--extra-kafka-consumer-config', - default=os.environ.get('EXTRA_KAFKA_CONSUMER_CONFIG', {}), type=json.loads) - kafka_oauth.add_kafka_oauth_arg(p) - p.add_argument('--show-all', - type=options.str2bool, nargs='?', const=True, - default=options.str2bool(os.environ.get('SHOW_ALL', '0'))) - opts, _ = p.parse_known_args() - - if not (opts.schema_registry_url and opts.kafka_bootstrap_servers): - raise Exception('Arguments schema_registry_url and kafka_bootstrap_servers are required.') + def add_args(p: argparse.ArgumentParser) -> None: + p.add_argument('--topics-to-include-regex', + default=os.environ.get('TOPICS_TO_INCLUDE_REGEX', '.*')) + p.add_argument('--topics-to-exclude-regex', + default=os.environ.get('TOPICS_TO_EXCLUDE_REGEX')) + p.add_argument('--show-all', + type=options.str2bool, nargs='?', const=True, + default=options.str2bool(os.environ.get('SHOW_ALL', '0'))) + + opts, _, serializer = options.get_options_and_metrics_reporters(add_args) with kafka.KafkaClient(accumulator.NoopAccumulator(), opts.kafka_bootstrap_servers, - opts.schema_registry_url, opts.extra_kafka_consumer_config, {}, - disable_writing=True) as kafka_client: + opts.extra_kafka_consumer_config, {}, disable_writing=True) as kafka_client: if kafka_client.get_topic_partition_count(opts.progress_topic_name) is None: logger.error('Progress topic %s not found.', opts.progress_topic_name) exit(1) @@ -78,8 +65,11 @@ def main() -> None: logger.info('Read %s messages so far; last was %s', msg_ctr, helpers.format_coordinates(msg)) # noinspection PyTypeChecker - msg_key = dict(msg.key()) - topic, kind = msg_key['topic_name'], msg_key['progress_kind'] + deser_msg = serializer.deserialize(msg) + if deser_msg.key_dict is None: + continue + + topic, kind = deser_msg.key_dict['topic_name'], deser_msg.key_dict['progress_kind'] if not topic_include_re.match(topic): continue @@ -89,22 +79,22 @@ def main() -> None: prior = 
copy.copy(topic_info.get(topic)) # noinspection PyArgumentList - if msg.value() is None: + if not deser_msg.value_dict: logger.warning('%s progress for topic %s reset at %s', kind, topic, helpers.format_coordinates(msg)) topic_info[topic].reset_count += 1 continue # noinspection PyTypeChecker,PyArgumentList - current_change_table = dict(msg.value())['change_table_name'] + current_change_table = deser_msg.value_dict['change_table_name'] topic_info[topic].distinct_change_tables.add(current_change_table) - current_pe = progress_tracking.ProgressEntry.from_message(msg) + current_pe = progress_tracking.ProgressEntry.from_message(deser_msg) if kind == constants.CHANGE_ROWS_KIND: if not current_pe.change_index: raise Exception('Unexpected state.') current_change_index = current_pe.change_index topic_info[topic].change_progress_count += 1 - topic_info[topic].last_change_progress = msg + topic_info[topic].last_change_progress = deser_msg if current_change_index.is_probably_heartbeat: topic_info[topic].heartbeat_count += 1 @@ -117,7 +107,7 @@ def main() -> None: current_change_index.is_probably_heartbeat: topic_info[topic].problem_count += 1 logger.warning('Duplicate change entry for topic %s between %s and %s', topic, - helpers.format_coordinates(prior.last_change_progress), + helpers.format_coordinates(prior.last_change_progress.raw_msg), helpers.format_coordinates(msg)) if prior_change_index > current_change_index: topic_info[topic].problem_count += 1 @@ -126,7 +116,7 @@ def main() -> None: Prior : progress message %s, index %s Current: progress message %s, index %s ''' - logger.error(log_msg, topic, helpers.format_coordinates(prior.last_change_progress), + logger.error(log_msg, topic, helpers.format_coordinates(prior.last_change_progress.raw_msg), prior_change_index, helpers.format_coordinates(msg), current_change_index) if kind == constants.SNAPSHOT_ROWS_KIND: @@ -134,7 +124,7 @@ def main() -> None: raise Exception('Unexpected state.') current_snapshot_index = current_pe.snapshot_index topic_info[topic].snapshot_progress_count += 1 - topic_info[topic].last_snapshot_progress = msg + topic_info[topic].last_snapshot_progress = deser_msg if prior and prior.last_snapshot_progress: prior_pe = progress_tracking.ProgressEntry.from_message(prior.last_snapshot_progress) @@ -157,7 +147,8 @@ def main() -> None: Prior : progress message %s, index %s Current: progress message %s, index %s ''' - logger.error(log_msg, topic, helpers.format_coordinates(prior.last_snapshot_progress), + logger.error(log_msg, topic, + helpers.format_coordinates(prior.last_snapshot_progress.raw_msg), prior_pe.snapshot_index, helpers.format_coordinates(msg), current_pe.snapshot_index) @@ -178,12 +169,13 @@ def main() -> None: table = [[k, v.change_progress_count, v.snapshot_progress_count, - 'yes' if (progress_tracking.ProgressEntry.from_message(v.last_snapshot_progress).snapshot_index == + 'yes' if (v.last_snapshot_progress and + progress_tracking.ProgressEntry.from_message(v.last_snapshot_progress).snapshot_index == constants.SNAPSHOT_COMPLETION_SENTINEL) else 'no', len(v.distinct_change_tables), - datetime.datetime.fromtimestamp(v.last_snapshot_progress.timestamp()[1] / 1000, + datetime.datetime.fromtimestamp(v.last_snapshot_progress.raw_msg.timestamp()[1] / 1000, datetime.UTC) if v.last_snapshot_progress else None, - datetime.datetime.fromtimestamp(v.last_change_progress.timestamp()[1] / 1000, + datetime.datetime.fromtimestamp(v.last_change_progress.raw_msg.timestamp()[1] / 1000, datetime.UTC) if v.last_change_progress else 
None, v.reset_count, v.problem_count, diff --git a/cdc_kafka/progress_tracking.py b/cdc_kafka/progress_tracking.py index a3aad31..9de8fff 100644 --- a/cdc_kafka/progress_tracking.py +++ b/cdc_kafka/progress_tracking.py @@ -1,192 +1,36 @@ import datetime -import json import logging -from typing import Dict, Tuple, Any, Optional, Mapping, TypeVar, Type - -import confluent_kafka.avro +from typing import Dict, Tuple, Any, Optional, Mapping, TypeVar, Type, TYPE_CHECKING from . import constants, helpers, tracked_tables from .change_index import ChangeIndex +from .serializers import SerializerAbstract, DeserializedMessage -from typing import TYPE_CHECKING if TYPE_CHECKING: from .kafka import KafkaClient import confluent_kafka logger = logging.getLogger(__name__) -PROGRESS_TRACKING_SCHEMA_VERSION = '2' -PROGRESS_TRACKING_AVRO_KEY_SCHEMA = confluent_kafka.avro.loads(json.dumps({ - "name": f"{constants.AVRO_SCHEMA_NAMESPACE}__progress_tracking_v{PROGRESS_TRACKING_SCHEMA_VERSION}__key", - "namespace": constants.AVRO_SCHEMA_NAMESPACE, - "type": "record", - "fields": [ - { - "name": "topic_name", - "type": "string" - }, - { - "name": "progress_kind", - "type": { - "type": "enum", - "name": "progress_kind", - "symbols": [ - constants.CHANGE_ROWS_KIND, - constants.SNAPSHOT_ROWS_KIND - ] - } - } - ] -})) -PROGRESS_TRACKING_AVRO_VALUE_SCHEMA = confluent_kafka.avro.loads(json.dumps({ - "name": f"{constants.AVRO_SCHEMA_NAMESPACE}__progress_tracking_v{PROGRESS_TRACKING_SCHEMA_VERSION}__value", - "namespace": constants.AVRO_SCHEMA_NAMESPACE, - "type": "record", - "fields": [ - { - "name": "source_table_name", - "type": "string" - }, - { - "name": "change_table_name", - "type": "string" - }, - # ------------------------------------------------------------------------------------------------ - # These next two are defunct/deprecated as of v4 but remain here to ease the upgrade transition - # for anyone with existing progress recorded by earlier versions: - { - "name": "last_ack_partition", - "type": ["null", "int"] - }, - { - "name": "last_ack_offset", - "type": ["null", "long"] - }, - # ------------------------------------------------------------------------------------------------ - { - "name": "last_ack_position", - "type": [ - { - "type": "record", - "name": f"{constants.CHANGE_ROWS_KIND}_progress", - "namespace": constants.AVRO_SCHEMA_NAMESPACE, - "fields": [ - { - "name": constants.LSN_NAME, - "type": "string", - }, - { - "name": constants.SEQVAL_NAME, - "type": "string", - }, - { - "name": constants.OPERATION_NAME, - "type": { - "type": "enum", - "name": constants.OPERATION_NAME, - "symbols": list(constants.CDC_OPERATION_NAME_TO_ID.keys()) - } - } - ] - }, - { - "type": "record", - "name": f"{constants.SNAPSHOT_ROWS_KIND}_progress", - "namespace": constants.AVRO_SCHEMA_NAMESPACE, - "fields": [ - { - "name": "key_fields", - "type": { - "type": "map", - "values": ["string", "long"] - } - } - ] - } - ] - } - ] -})) - -SNAPSHOT_LOGGING_SCHEMA_VERSION = '1' -SNAPSHOT_LOGGING_AVRO_VALUE_SCHEMA = confluent_kafka.avro.loads(json.dumps({ - "name": f"{constants.AVRO_SCHEMA_NAMESPACE}__snapshot_logging_v{SNAPSHOT_LOGGING_SCHEMA_VERSION}__value", - "namespace": constants.AVRO_SCHEMA_NAMESPACE, - "type": "record", - "fields": [ - { - "name": "topic_name", - "type": "string" - }, - { - "name": "table_name", - "type": "string" - }, - { - "name": "action", - "type": "string" - }, - { - "name": "process_hostname", - "type": "string" - }, - { - "name": "event_time_utc", - "type": "string" - }, - { - "name": 
"key_schema_id", - "type": ["null", "long"] - }, - { - "name": "value_schema_id", - "type": ["null", "long"] - }, - { - "name": "partition_watermarks_low", - "type": ["null", { - "type": "map", - "values": "long" - }] - }, - { - "name": "partition_watermarks_high", - "type": ["null", { - "type": "map", - "values": "long" - }] - }, - { - "name": "starting_snapshot_index", - "type": ["null", { - "type": "map", - "values": ["string", "long"] - }] - }, - { - "name": "ending_snapshot_index", - "type": ["null", { - "type": "map", - "values": ["string", "long"] - }] - } - ] -})) ProgressEntryType = TypeVar('ProgressEntryType', bound='ProgressEntry') class ProgressEntry(object): @classmethod - def from_message(cls: Type[ProgressEntryType], message: 'confluent_kafka.Message') -> ProgressEntryType: + def from_message(cls: Type[ProgressEntryType], message: DeserializedMessage) -> ProgressEntryType: # noinspection PyTypeChecker,PyArgumentList - k, v = dict(message.key()), dict(message.value()) + k, v = message.key_dict, message.value_dict + + if k is None or v is None: + raise Exception("Malformed message received by ProgressEntry.from_message") + kind: str = k['progress_kind'] if kind not in (constants.CHANGE_ROWS_KIND, constants.SNAPSHOT_ROWS_KIND): raise Exception(f"Unrecognized progress kind from message: {kind}") - msg_coordinates = helpers.format_coordinates(message) + msg_coordinates = helpers.format_coordinates(message.raw_msg) if kind == constants.SNAPSHOT_ROWS_KIND: return cls(kind, k['topic_name'], v['source_table_name'], v['change_table_name'], @@ -194,7 +38,7 @@ def from_message(cls: Type[ProgressEntryType], message: 'confluent_kafka.Message else: return cls(kind, k['topic_name'], v['source_table_name'], v['change_table_name'], - None, ChangeIndex.from_avro_ready_dict(v['last_ack_position']), msg_coordinates) + None, ChangeIndex.from_dict(v['last_ack_position']), msg_coordinates) def __init__(self, progress_kind: str, topic_name: str, source_table_name: str, change_table_name: str, snapshot_index: Optional[Mapping[str, str | int]] = None, @@ -218,10 +62,12 @@ def key(self) -> Dict[str, str]: } @property - def value(self) -> Dict[str, Any]: + def value(self) -> Optional[Dict[str, Any]]: + if not (self.change_index or self.snapshot_index): + return None pos: Dict[str, Any] if self.change_index: - pos = self.change_index.to_avro_ready_dict() + pos = self.change_index.as_dict() else: pos = {'key_fields': self.snapshot_index} return { @@ -238,22 +84,16 @@ def __repr__(self) -> str: class ProgressTracker(object): _instance = None - def __init__(self, kafka_client: 'KafkaClient', progress_topic_name: str, process_hostname: str, - snapshot_logging_topic_name: Optional[str] = None) -> None: + def __init__(self, kafka_client: 'KafkaClient', serializer: SerializerAbstract, progress_topic_name: str, + process_hostname: str, snapshot_logging_topic_name: Optional[str] = None) -> None: if ProgressTracker._instance is not None: raise Exception('ProgressTracker class should be used as a singleton.') self._kafka_client: 'KafkaClient' = kafka_client + self._serializer: SerializerAbstract = serializer self._progress_topic_name: str = progress_topic_name self._process_hostname: str = process_hostname self._snapshot_logging_topic_name: Optional[str] = snapshot_logging_topic_name - self._progress_key_schema_id, self._progress_value_schema_id = kafka_client.register_schemas( - progress_topic_name, PROGRESS_TRACKING_AVRO_KEY_SCHEMA, PROGRESS_TRACKING_AVRO_VALUE_SCHEMA) - if snapshot_logging_topic_name: - _, 
self._snapshot_logging_schema_id = kafka_client.register_schemas( - snapshot_logging_topic_name, None, SNAPSHOT_LOGGING_AVRO_VALUE_SCHEMA) - else: - self._snapshot_logging_schema_id = 0 self._last_recorded_progress_by_topic: Dict[str, ProgressEntry] = {} self._topic_to_source_table_map: Dict[str, str] = {} self._topic_to_change_table_map: Dict[str, str] = {} @@ -277,12 +117,12 @@ def record_changes_progress(self, topic_name: str, change_index: ChangeIndex) -> change_index=change_index ) + key, value = self._serializer.serialize_progress_tracking_message(progress_entry) + self._kafka_client.produce( topic=self._progress_topic_name, - key=progress_entry.key, - key_schema_id=self._progress_key_schema_id, - value=progress_entry.value, - value_schema_id=self._progress_value_schema_id, + key=key, + value=value, message_type=constants.CHANGE_PROGRESS_MESSAGE ) @@ -297,17 +137,16 @@ def record_snapshot_progress(self, topic_name: str, snapshot_index: Mapping[str, snapshot_index=snapshot_index ) + key, value = self._serializer.serialize_progress_tracking_message(progress_entry) + self._kafka_client.produce( topic=self._progress_topic_name, - key=progress_entry.key, - key_schema_id=self._progress_key_schema_id, - value=progress_entry.value, - value_schema_id=self._progress_value_schema_id, + key=key, + value=value, message_type=constants.SNAPSHOT_PROGRESS_MESSAGE ) def _log_snapshot_event(self, topic_name: str, table_name: str, action: str, - key_schema_id: Optional[int] = None, value_schema_id: Optional[int] = None, event_time: Optional[datetime.datetime] = None, starting_snapshot_index: Optional[Mapping[str, str | int]] = None, ending_snapshot_index: Optional[Mapping[str, str | int]] = None) -> None: @@ -332,38 +171,31 @@ def _log_snapshot_event(self, topic_name: str, table_name: str, action: str, "process_hostname": self._process_hostname, "starting_snapshot_index": starting_snapshot_index, "table_name": table_name, - "topic_name": topic_name, - "key_schema_id": key_schema_id, - "value_schema_id": value_schema_id + "topic_name": topic_name } logger.debug('Logging snapshot event: %s', msg) - + _, value = self._serializer.serialize_snapshot_logging_message(msg) self._kafka_client.produce( topic=self._snapshot_logging_topic_name, key=None, - key_schema_id=0, - value=msg, - value_schema_id=self._snapshot_logging_schema_id, + value=value, message_type=constants.SNAPSHOT_LOGGING_MESSAGE ) - def log_snapshot_started(self, topic_name: str, table_name: str, key_schema_id: int, value_schema_id: int, + def log_snapshot_started(self, topic_name: str, table_name: str, starting_snapshot_index: Mapping[str, str | int]) -> None: return self._log_snapshot_event(topic_name, table_name, constants.SNAPSHOT_LOG_ACTION_STARTED, - key_schema_id=key_schema_id, value_schema_id=value_schema_id, starting_snapshot_index=starting_snapshot_index) - def log_snapshot_resumed(self, topic_name: str, table_name: str, key_schema_id: int, value_schema_id: int, + def log_snapshot_resumed(self, topic_name: str, table_name: str, starting_snapshot_index: Mapping[str, str | int]) -> None: return self._log_snapshot_event(topic_name, table_name, constants.SNAPSHOT_LOG_ACTION_RESUMED, - key_schema_id=key_schema_id, value_schema_id=value_schema_id, starting_snapshot_index=starting_snapshot_index) - def log_snapshot_completed(self, topic_name: str, table_name: str, key_schema_id: int, value_schema_id: int, - event_time: datetime.datetime, ending_snapshot_index: Mapping[str, str | int]) -> None: + def log_snapshot_completed(self, topic_name: 
str, table_name: str, event_time: datetime.datetime, + ending_snapshot_index: Mapping[str, str | int]) -> None: return self._log_snapshot_event(topic_name, table_name, constants.SNAPSHOT_LOG_ACTION_COMPLETED, - key_schema_id=key_schema_id, value_schema_id=value_schema_id, event_time=event_time, ending_snapshot_index=ending_snapshot_index) def log_snapshot_progress_reset(self, topic_name: str, table_name: str, is_auto_reset: bool, @@ -395,39 +227,28 @@ def maybe_create_snapshot_logging_topic(self) -> None: # the keys in the returned dictionary are tuples of (topic_name, progress_kind) def get_prior_progress(self) -> Dict[Tuple[str, str], ProgressEntry]: - result: Dict[Tuple[str, str], ProgressEntry] = {} - messages: Dict[Tuple[str, str], confluent_kafka.Message] = {} + raw_msgs: Dict[bytes | str | None, confluent_kafka.Message] = {} progress_msg_ctr = 0 for progress_msg in self._kafka_client.consume_all(self._progress_topic_name): progress_msg_ctr += 1 - # noinspection PyTypeChecker - msg_key = dict(progress_msg.key()) - result_key = (msg_key['topic_name'], msg_key['progress_kind']) - # noinspection PyArgumentList if progress_msg.value() is None: - if result_key in result: - del result[result_key] + if progress_msg.key() is not None and progress_msg.key() in raw_msgs: + del raw_msgs[progress_msg.key()] continue - - curr_entry = ProgressEntry.from_message(message=progress_msg) - prior_entry = result.get(result_key) - if (prior_entry and prior_entry.change_index and curr_entry and curr_entry.change_index and - prior_entry.change_index > curr_entry.change_index): - prior_message = messages[result_key] - logger.error( - 'WARNING: Progress topic %s contains unordered entries for %s! Prior: p%s:o%s (%s), ' - 'pos %s; Current: p%s:o%s (%s), pos %s', self._progress_topic_name, result_key, - prior_message.partition(), prior_message.offset(), - datetime.datetime.fromtimestamp(prior_message.timestamp()[1] / 1000, datetime.UTC), - prior_entry.change_index, progress_msg.partition(), progress_msg.offset(), - datetime.datetime.fromtimestamp(progress_msg.timestamp()[1] / 1000, datetime.UTC), - curr_entry.change_index) - result[result_key] = curr_entry # last read for a given key will win - messages[result_key] = progress_msg + raw_msgs[progress_msg.key()] = progress_msg logger.info('Read %s prior progress messages from Kafka topic %s', progress_msg_ctr, self._progress_topic_name) + + result: Dict[Tuple[str, str], ProgressEntry] = {} + for msg in raw_msgs.values(): + deser_msg = self._serializer.deserialize(msg) + if deser_msg.key_dict is None: + raise Exception('Unexpected state: None value from deserializing progress message key') + result[(deser_msg.key_dict['topic_name'], deser_msg.key_dict['progress_kind'])] = \ + ProgressEntry.from_message(message=deser_msg) + return result def reset_progress(self, topic_name: str, kind_to_reset: str, source_table_name: str, is_auto_reset: bool, @@ -436,22 +257,26 @@ def reset_progress(self, topic_name: str, kind_to_reset: str, source_table_name: matched = False if kind_to_reset in (constants.CHANGE_ROWS_KIND, constants.ALL_PROGRESS_KINDS): - key = { - 'topic_name': topic_name, - 'progress_kind': constants.CHANGE_ROWS_KIND, - } - self._kafka_client.produce(self._progress_topic_name, key, self._progress_key_schema_id, None, - self._progress_value_schema_id, constants.PROGRESS_DELETION_TOMBSTONE_MESSAGE) + progress_entry = ProgressEntry(constants.CHANGE_ROWS_KIND, topic_name, '', '') + key, _ = self._serializer.serialize_progress_tracking_message(progress_entry) + 
self._kafka_client.produce( + topic=self._progress_topic_name, + key=key, + value=None, + message_type=constants.PROGRESS_DELETION_TOMBSTONE_MESSAGE + ) logger.info('Deleted existing change rows progress records for topic %s.', topic_name) matched = True if kind_to_reset in (constants.SNAPSHOT_ROWS_KIND, constants.ALL_PROGRESS_KINDS): - key = { - 'topic_name': topic_name, - 'progress_kind': constants.SNAPSHOT_ROWS_KIND, - } - self._kafka_client.produce(self._progress_topic_name, key, self._progress_key_schema_id, None, - self._progress_value_schema_id, constants.PROGRESS_DELETION_TOMBSTONE_MESSAGE) + progress_entry = ProgressEntry(constants.SNAPSHOT_ROWS_KIND, topic_name, '', '') + key, _ = self._serializer.serialize_progress_tracking_message(progress_entry) + self._kafka_client.produce( + topic=self._progress_topic_name, + key=key, + value=None, + message_type=constants.PROGRESS_DELETION_TOMBSTONE_MESSAGE + ) logger.info('Deleted existing snapshot progress records for topic %s.', topic_name) self.maybe_create_snapshot_logging_topic() self.log_snapshot_progress_reset(topic_name, source_table_name, is_auto_reset, diff --git a/cdc_kafka/replayer.py b/cdc_kafka/replayer.py index 220c7d4..f120c19 100755 --- a/cdc_kafka/replayer.py +++ b/cdc_kafka/replayer.py @@ -44,14 +44,15 @@ import time from datetime import datetime, UTC from multiprocessing.synchronize import Event as EventClass -from typing import Set, Any, List, Dict, Tuple, NamedTuple +from typing import Set, Any, List, Dict, Tuple, NamedTuple, Optional -import ctds +import ctds # type: ignore[import-untyped] from confluent_kafka import Consumer, KafkaError, TopicPartition, Message, OFFSET_BEGINNING +from confluent_kafka.admin import TopicMetadata from confluent_kafka.serialization import SerializationContext, MessageField from confluent_kafka.schema_registry import SchemaRegistryClient from confluent_kafka.schema_registry.avro import AvroDeserializer -from faster_fifo import Queue +from faster_fifo import Queue # type: ignore[import-not-found] from cdc_kafka import kafka_oauth @@ -246,6 +247,8 @@ def main() -> None: default=os.environ.get('TARGET_DB_TABLE_NAME')) p.add_argument('--cols-to-not-sync', default=os.environ.get('COLS_TO_NOT_SYNC', '')) + p.add_argument('--primary-key-fields-override', + default=os.environ.get('PRIMARY_KEY_FIELDS_OVERRIDE', '')) p.add_argument('--progress-tracking-namespace', default=os.environ.get('PROGRESS_TRACKING_NAMESPACE', 'default')) p.add_argument('--progress-tracking-table-schema', @@ -287,7 +290,7 @@ def main() -> None: delete_temp_table_name: str = temp_table_base_name + '_delete' merge_temp_table_name: str = temp_table_base_name + '_merge' cols_to_not_sync: set[str] = set([c.strip().lower() for c in opts.cols_to_not_sync.split(',')]) - cols_to_not_sync.remove('') + cols_to_not_sync.discard('') proc_id: str = f'{socket.getfqdn()}+{int(datetime.now().timestamp())}' stop_event: EventClass = mp.Event() # For faster_fifo the ctor arg here is the queue byte size, not its item count size: @@ -305,15 +308,19 @@ def main() -> None: target=consumer_process, name='consumer', args=(opts, stop_event, queue, progress, proc_id, logger)) consumer_subprocess.start() - cursor.execute(f''' -SELECT [COLUMN_NAME] -FROM [INFORMATION_SCHEMA].[KEY_COLUMN_USAGE] -WHERE OBJECTPROPERTY(OBJECT_ID([CONSTRAINT_SCHEMA] + '.' 
+ QUOTENAME([CONSTRAINT_NAME])), 'IsPrimaryKey') = 1 -AND [TABLE_SCHEMA] = :0 -AND [TABLE_NAME] = :1 -ORDER BY [ORDINAL_POSITION] - ''', (opts.target_db_table_schema, opts.target_db_table_name)) - primary_key_field_names: List[str] = [r[0] for r in cursor.fetchall()] + primary_key_field_names: List[str] + if opts.primary_key_fields_override.strip(): + primary_key_field_names = [x.strip() for x in opts.primary_key_fields_override.split(',')] + else: + cursor.execute(f''' + SELECT [COLUMN_NAME] + FROM [INFORMATION_SCHEMA].[KEY_COLUMN_USAGE] + WHERE OBJECTPROPERTY(OBJECT_ID([CONSTRAINT_SCHEMA] + '.' + QUOTENAME([CONSTRAINT_NAME])), 'IsPrimaryKey') = 1 + AND [TABLE_SCHEMA] = :0 + AND [TABLE_NAME] = :1 + ORDER BY [ORDINAL_POSITION] + ''', (opts.target_db_table_schema, opts.target_db_table_name)) + primary_key_field_names = [r[0] for r in cursor.fetchall()] field_names: List[str] = [] datetime_field_names: set[str] = set() @@ -442,7 +449,7 @@ def main() -> None: for _, (op, val) in queued_upserts.items(): # CTDS unfortunately completely ignores values for target-table IDENTITY cols when doing # a bulk_insert, so in that case we have to fall back to the slower MERGE mechanism: - if (not has_identity_col) and op in ('Snapshot', 'Insert'): + if (not has_identity_col) and op == 'Insert': # in ('Snapshot', 'Insert'): -- Snapshots can hit PK collisions! :( inserts.append(val) else: merges.append(val) @@ -559,7 +566,12 @@ def consumer_process(opts: argparse.Namespace, stop_event: EventClass, queue: Qu start_offset_by_partition: Dict[int, int] = { p.source_topic_partition: p.last_handled_message_offset + 1 for p in progress } - partitions: List[int] = consumer.list_topics(topic=opts.replay_topic).topics[opts.replay_topic].partitions + topics_meta: Dict[str, TopicMetadata] | None = consumer.list_topics( + topic=opts.replay_topic).topics # type: ignore[call-arg] + if topics_meta is None: + raise Exception(f'No partitions found for topic {opts.replay_topic}') + else: + partitions = list((topics_meta[opts.replay_topic].partitions or {}).keys()) topic_partitions: List[TopicPartition] = [TopicPartition( opts.replay_topic, p, start_offset_by_partition.get(p, OFFSET_BEGINNING) ) for p in partitions] @@ -580,19 +592,28 @@ def consumer_process(opts: argparse.Namespace, stop_event: EventClass, queue: Qu if msg_ctr % 5_000 == 0: logger.debug(f'Reached %s, apx queue depth %s', format_coordinates(msg), queue.qsize()) - if msg.error(): + err = msg.error() + if err: # noinspection PyProtectedMember - if msg.error().code() == KafkaError._PARTITION_EOF: + if err.code() == KafkaError._PARTITION_EOF: break else: raise Exception(msg.error()) - # noinspection PyArgumentList,PyTypeChecker - msg_key: Dict[str, Any] = avro_deserializer( - msg.key(), SerializationContext(msg.topic(), MessageField.KEY)) - # noinspection PyArgumentList,PyTypeChecker - msg_val: Dict[str, Any] = avro_deserializer( - msg.value(), SerializationContext(msg.topic(), MessageField.VALUE)) + topic = msg.topic() + if topic is None: + raise Exception('Unexpected None value for message topic()') + + msg_key: Optional[Dict[str, Any]] = None + raw_key = msg.key() + if raw_key is not None: + # noinspection PyNoneFunctionAssignment + msg_key = avro_deserializer(raw_key, SerializationContext(topic, MessageField.KEY)) # type: ignore[func-returns-value, arg-type] + msg_val: Optional[Dict[str, Any]] = None + raw_val = msg.value() + if raw_val is not None: + # noinspection PyNoneFunctionAssignment + msg_val = avro_deserializer(raw_val, SerializationContext(topic, 
MessageField.VALUE)) # type: ignore[func-returns-value, arg-type] queue.put((msg.partition(), msg.offset(), msg.timestamp()[1], msg_key, msg_val)) diff --git a/cdc_kafka/serializers/__init__.py b/cdc_kafka/serializers/__init__.py new file mode 100644 index 0000000..d1f32d1 --- /dev/null +++ b/cdc_kafka/serializers/__init__.py @@ -0,0 +1,59 @@ +import argparse +from abc import ABC, abstractmethod +from typing import TypeVar, Type, Tuple, Dict, Any, Optional, TYPE_CHECKING + +import confluent_kafka + +if TYPE_CHECKING: + from ..parsed_row import ParsedRow + from ..progress_tracking import ProgressEntry + from ..tracked_tables import TrackedTable + +SerializerAbstractType = TypeVar('SerializerAbstractType', bound='SerializerAbstract') + + +class DeserializedMessage(object): + def __init__(self, raw_msg: confluent_kafka.Message, key_dict: Optional[Dict[str, Any]], + value_dict: Optional[Dict[str, Any]]): + self.raw_msg = raw_msg + self.key_dict = key_dict + self.value_dict = value_dict + + +class SerializerAbstract(ABC): + + @abstractmethod + def register_table(self, table: 'TrackedTable') -> None: + pass + + @abstractmethod + def serialize_table_data_message(self, row: 'ParsedRow') -> Tuple[bytes, bytes]: + pass + + @abstractmethod + def serialize_progress_tracking_message(self, progress_entry: 'ProgressEntry') -> Tuple[bytes, Optional[bytes]]: + pass + + @abstractmethod + def serialize_metrics_message(self, metrics_namespace: str, metrics: Dict[str, Any]) -> Tuple[bytes, bytes]: + pass + + @abstractmethod + def serialize_snapshot_logging_message(self, snapshot_log: Dict[str, Any]) -> Tuple[None, bytes]: + pass + + @abstractmethod + def deserialize(self, msg: confluent_kafka.Message) -> DeserializedMessage: + pass + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser) -> None: + pass + + @classmethod + @abstractmethod + def construct_with_options(cls: Type[SerializerAbstractType], opts: argparse.Namespace, + disable_writes: bool) -> SerializerAbstractType: + pass + + diff --git a/cdc_kafka/serializers/avro.py b/cdc_kafka/serializers/avro.py new file mode 100644 index 0000000..c1a4192 --- /dev/null +++ b/cdc_kafka/serializers/avro.py @@ -0,0 +1,735 @@ +import argparse +import collections +import datetime +import decimal +import functools +import io +import itertools +import json +import logging +import os +import struct +import time +from typing import Tuple, TypeVar, Type, List, Any, Dict, Callable, Literal, Sequence, Optional + +import confluent_kafka.avro +from avro.errors import AvroOutOfScaleException +from avro.schema import Schema +from bitarray import bitarray + +from . import SerializerAbstract, DeserializedMessage +from .. 
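The new `cdc_kafka/serializers/__init__.py` above defines `SerializerAbstract`, the interface that the new `--message-serializer` option loads dynamically, with `AvroSerializer` as the default. As a rough sketch of what an alternative implementation might look like (this `JsonSerializer` is hypothetical and omits the CDC metadata fields that the Avro serializer folds into each value), a subclass mainly has to remember per-table field names at `register_table` time and turn keys/values into bytes and back:

```python
import argparse
import json
from typing import Any, Dict, Optional, Tuple, Type

import confluent_kafka

from cdc_kafka.parsed_row import ParsedRow
from cdc_kafka.progress_tracking import ProgressEntry
from cdc_kafka.serializers import DeserializedMessage, SerializerAbstract
from cdc_kafka.tracked_tables import TrackedTable


class JsonSerializer(SerializerAbstract):
    """Hypothetical schema-less serializer, for illustration only."""

    def __init__(self) -> None:
        # topic name -> (key field names, value field names), captured at registration
        self._fields_by_topic: Dict[str, Tuple[Tuple[str, ...], Tuple[str, ...]]] = {}

    def register_table(self, table: TrackedTable) -> None:
        self._fields_by_topic[table.topic_name] = (
            tuple(f.name for f in table.key_fields),
            tuple(f.name for f in table.value_fields),
        )

    def serialize_table_data_message(self, row: ParsedRow) -> Tuple[bytes, bytes]:
        key_names, value_names = self._fields_by_topic[row.destination_topic]
        key = dict(zip(key_names, row.ordered_key_field_values))
        # NOTE: unlike AvroSerializer, this skips operation/LSN/event-time metadata.
        value = dict(zip(value_names, row.table_data_cols))
        return (json.dumps(key, default=str).encode(),
                json.dumps(value, default=str).encode())

    def serialize_progress_tracking_message(self, progress_entry: ProgressEntry) -> Tuple[bytes, Optional[bytes]]:
        value = progress_entry.value
        return (json.dumps(progress_entry.key).encode(),
                None if value is None else json.dumps(value, default=str).encode())

    def serialize_metrics_message(self, metrics_namespace: str, metrics: Dict[str, Any]) -> Tuple[bytes, bytes]:
        return (json.dumps({'metrics_namespace': metrics_namespace}).encode(),
                json.dumps(metrics, default=str).encode())

    def serialize_snapshot_logging_message(self, snapshot_log: Dict[str, Any]) -> Tuple[None, bytes]:
        return None, json.dumps(snapshot_log, default=str).encode()

    def deserialize(self, msg: confluent_kafka.Message) -> DeserializedMessage:
        key, value = msg.key(), msg.value()
        return DeserializedMessage(msg,
                                   json.loads(key) if key else None,
                                   json.loads(value) if value else None)

    @classmethod
    def construct_with_options(cls: Type['JsonSerializer'], opts: argparse.Namespace,
                               disable_writes: bool) -> 'JsonSerializer':
        return cls()
```

A class like this could then be selected with something like `--message-serializer my_project.json_serializer.JsonSerializer` (a hypothetical module path), since the reworked `options.py` appears to resolve the option by importing whatever dotted path it is given and calling the class's `add_arguments` and `construct_with_options` hooks.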
import constants +from ..metric_reporting.metrics import Metrics +from ..options import str2bool +from ..parsed_row import ParsedRow +from ..progress_tracking import ProgressEntry +from ..tracked_tables import TrackedTable + +logger = logging.getLogger(__name__) + +AvroSerializerType = TypeVar('AvroSerializerType', bound='AvroSerializer') + +COMPARE_CANONICAL_EVERY_NTH = 50_000 +AVRO_SCHEMA_NAMESPACE = "cdc_to_kafka" + +PROGRESS_TRACKING_SCHEMA_VERSION = '2' +PROGRESS_TRACKING_AVRO_KEY_SCHEMA = confluent_kafka.avro.loads(json.dumps({ + "name": f"{AVRO_SCHEMA_NAMESPACE}__progress_tracking_v{PROGRESS_TRACKING_SCHEMA_VERSION}__key", + "namespace": AVRO_SCHEMA_NAMESPACE, + "type": "record", + "fields": [ + { + "name": "topic_name", + "type": "string" + }, + { + "name": "progress_kind", + "type": { + "type": "enum", + "name": "progress_kind", + "symbols": [ + constants.CHANGE_ROWS_KIND, + constants.SNAPSHOT_ROWS_KIND + ] + } + } + ] +})) +PROGRESS_TRACKING_AVRO_VALUE_SCHEMA = confluent_kafka.avro.loads(json.dumps({ + "name": f"{AVRO_SCHEMA_NAMESPACE}__progress_tracking_v{PROGRESS_TRACKING_SCHEMA_VERSION}__value", + "namespace": AVRO_SCHEMA_NAMESPACE, + "type": "record", + "fields": [ + { + "name": "source_table_name", + "type": "string" + }, + { + "name": "change_table_name", + "type": "string" + }, + # ------------------------------------------------------------------------------------------------ + # These next two are defunct/deprecated as of v4 but remain here to ease the upgrade transition + # for anyone with existing progress recorded by earlier versions: + { + "name": "last_ack_partition", + "type": ["null", "int"] + }, + { + "name": "last_ack_offset", + "type": ["null", "long"] + }, + # ------------------------------------------------------------------------------------------------ + { + "name": "last_ack_position", + "type": [ + { + "type": "record", + "name": f"{constants.CHANGE_ROWS_KIND}_progress", + "namespace": AVRO_SCHEMA_NAMESPACE, + "fields": [ + { + "name": constants.LSN_NAME, + "type": "string", + }, + { + "name": constants.SEQVAL_NAME, + "type": "string", + }, + { + "name": constants.OPERATION_NAME, + "type": { + "type": "enum", + "name": constants.OPERATION_NAME, + "symbols": list(constants.CDC_OPERATION_NAME_TO_ID.keys()) + } + } + ] + }, + { + "type": "record", + "name": f"{constants.SNAPSHOT_ROWS_KIND}_progress", + "namespace": AVRO_SCHEMA_NAMESPACE, + "fields": [ + { + "name": "key_fields", + "type": { + "type": "map", + "values": ["string", "long"] + } + } + ] + } + ] + } + ] +})) + +SNAPSHOT_LOGGING_SCHEMA_VERSION = '1' +SNAPSHOT_LOGGING_AVRO_VALUE_SCHEMA = confluent_kafka.avro.loads(json.dumps({ + "name": f"{AVRO_SCHEMA_NAMESPACE}__snapshot_logging_v{SNAPSHOT_LOGGING_SCHEMA_VERSION}__value", + "namespace": AVRO_SCHEMA_NAMESPACE, + "type": "record", + "fields": [ + { + "name": "topic_name", + "type": "string" + }, + { + "name": "table_name", + "type": "string" + }, + { + "name": "action", + "type": "string" + }, + { + "name": "process_hostname", + "type": "string" + }, + { + "name": "event_time_utc", + "type": "string" + }, + { + "name": "key_schema_id", + "type": ["null", "long"] + }, + { + "name": "value_schema_id", + "type": ["null", "long"] + }, + { + "name": "partition_watermarks_low", + "type": ["null", { + "type": "map", + "values": "long" + }] + }, + { + "name": "partition_watermarks_high", + "type": ["null", { + "type": "map", + "values": "long" + }] + }, + { + "name": "starting_snapshot_index", + "type": ["null", { + "type": "map", + "values": ["string", 
"long"] + }] + }, + { + "name": "ending_snapshot_index", + "type": ["null", { + "type": "map", + "values": ["string", "long"] + }] + } + ] +})) + +METRICS_SCHEMA_VERSION = '2' + +METRICS_AVRO_KEY_SCHEMA = confluent_kafka.avro.loads(json.dumps({ + "name": f"{AVRO_SCHEMA_NAMESPACE}__metrics_v{METRICS_SCHEMA_VERSION}__key", + "namespace": AVRO_SCHEMA_NAMESPACE, + "type": "record", + "fields": [ + { + "name": "metrics_namespace", + "type": "string" + } + ] +})) + +METRICS_AVRO_VALUE_SCHEMA = confluent_kafka.avro.loads(json.dumps({ + "name": f"{AVRO_SCHEMA_NAMESPACE}__metrics_v{METRICS_SCHEMA_VERSION}__value", + "namespace": AVRO_SCHEMA_NAMESPACE, + "type": "record", + "fields": [ + { + "name": k, + "type": v + } for (k, v) in Metrics.FIELDS_AND_TYPES + ] +})) + + +class AvroTableDataSerializerMetadata(object): + __slots__ = ('key_schema_id', 'value_schema_id', 'key_field_ordinals', 'ordered_serializers', + 'key_field_names', 'value_field_names', 'all_cols_updated_enum_bytes') + + @staticmethod + def get_all_cols_updated_enum_bytes(col_count: int) -> bytes: + buf = io.BytesIO() + int_to_int(buf, col_count) + for p in range(1, col_count + 1): + int_to_int(buf, p) + buf.write(b'\x00') + return buf.getvalue() + + def __init__(self, key_schema_id: int, value_schema_id: int, key_field_ordinals: Sequence[int], + ordered_serializers: List[Callable[[io.BytesIO, Any], None]], + key_field_names: Sequence[str], value_field_names: Sequence[str]) -> None: + self.key_schema_id: int = key_schema_id + self.value_schema_id: int = value_schema_id + self.key_field_ordinals: Tuple[int, ...] = tuple(key_field_ordinals) + self.ordered_serializers: List[Callable[[io.BytesIO, Any], None]] = ordered_serializers + self.key_field_names: Tuple[str, ...] = tuple(key_field_names) + self.value_field_names: Tuple[str, ...] = tuple(value_field_names) + self.all_cols_updated_enum_bytes: bytes = AvroTableDataSerializerMetadata.get_all_cols_updated_enum_bytes( + len(ordered_serializers)) + + +class AvroSchemaGenerator(object): + _instance = None + + def __init__(self, always_use_avro_longs: bool, + avro_type_spec_overrides: Dict[str, str | Dict[str, str | int]]) -> None: + if AvroSchemaGenerator._instance is not None: + raise Exception('AvroSchemaGenerator class should be used as a singleton.') + + self.always_use_avro_longs: bool = always_use_avro_longs + self.normalized_avro_type_overrides: Dict[Tuple[str, str, str], str | Dict[str, str | int]] = {} + for k, v in avro_type_spec_overrides.items(): + if k.count('.') != 2: + raise Exception(f'Avro type spec override "{k}" was incorrectly specified. Please key this config in ' + 'the form .
.') + sn, tn, cn = k.split('.') + self.normalized_avro_type_overrides[(sn.lower(), tn.lower(), cn.lower())] = v + + AvroSchemaGenerator._instance = self + + def generate_key_schema(self, table: TrackedTable) -> Schema: + key_schema_fields = [self.get_record_field_schema( + table.schema_name, table.table_name, kf.name, kf.sql_type_name, kf.decimal_precision, + kf.decimal_scale, False + ) for kf in table.key_fields] + schema_json = { + "name": f"{table.schema_name}_{table.table_name}_cdc__key", + "namespace": AVRO_SCHEMA_NAMESPACE, + "type": "record", + "fields": key_schema_fields + } + return confluent_kafka.avro.loads(json.dumps(schema_json)) + + def generate_value_schema(self, table: TrackedTable) -> Schema: + # In CDC tables, all columns are nullable so that if the column is dropped from the source table, the capture + # instance need not be updated. We align with that by making the Avro value schema for all captured fields + # nullable (which also helps with maintaining future Avro schema compatibility). + value_schema_fields = [self.get_record_field_schema( + table.schema_name, table.table_name, vf.name, vf.sql_type_name, vf.decimal_precision, + vf.decimal_scale, True + ) for vf in table.value_fields] + value_field_names = [f.name for f in table.value_fields] + value_fields_plus_metadata_fields = AvroSchemaGenerator.get_cdc_metadata_fields_avro_schemas( + table.schema_name, table.table_name, value_field_names) + value_schema_fields + schema_json = { + "name": f"{table.schema_name}_{table.table_name}_cdc__value", + "namespace": AVRO_SCHEMA_NAMESPACE, + "type": "record", + "fields": value_fields_plus_metadata_fields + } + return confluent_kafka.avro.loads(json.dumps(schema_json)) + + def get_record_field_schema(self, db_schema_name: str, db_table_name: str, field_name: str, sql_type_name: str, + decimal_precision: int, decimal_scale: int, make_nullable: bool) -> Dict[str, Any]: + override_type = self.normalized_avro_type_overrides.get( + (db_schema_name.lower(), db_table_name.lower(), field_name.lower())) + if override_type: + avro_type = override_type + else: + if sql_type_name in ('decimal', 'numeric', 'money', 'smallmoney'): + if (not decimal_precision) or decimal_scale is None: + raise Exception(f"Field '{field_name}': For SQL decimal, money, or numeric types, the scale and " + f"precision must be provided.") + avro_type = { + "type": "bytes", + "logicalType": "decimal", + "precision": decimal_precision, + "scale": decimal_scale + } + elif sql_type_name == 'bigint': + avro_type = "long" + elif sql_type_name == 'bit': + avro_type = "boolean" + elif sql_type_name == 'float': + avro_type = "double" + elif sql_type_name == 'real': + avro_type = "float" + elif sql_type_name in ('int', 'smallint', 'tinyint'): + avro_type = "long" if self.always_use_avro_longs else "int" + # For date and time we don't respect always_use_avro_longs since the underlying type being `int` for these + # logical types is spelled out in the Avro spec: + elif sql_type_name == 'date': + avro_type = {"type": "int", "logicalType": "date"} + elif sql_type_name == 'time': + avro_type = {"type": "int", "logicalType": "time-millis"} + elif sql_type_name in ('datetime', 'datetime2', 'datetimeoffset', 'smalldatetime', + 'xml') + constants.SQL_STRING_TYPES: + avro_type = "string" + elif sql_type_name == 'uniqueidentifier': + avro_type = {"type": "string", "logicalType": "uuid"} + elif sql_type_name in ('binary', 'image', 'varbinary', 'rowversion'): + avro_type = "bytes" + else: + raise Exception(f"Field '{field_name}': I am 
unsure how to convert SQL type {sql_type_name} to Avro") + + if make_nullable: + return { + "name": field_name, + "type": [ + "null", + avro_type + ], + "default": None + } + else: + return { + "name": field_name, + "type": avro_type + } + + @staticmethod + # These fields are common to all change/snapshot data messages published to Kafka by this process + def get_cdc_metadata_fields_avro_schemas(db_schema_name: str, db_table_name: str, + source_field_names: List[str]) -> List[Dict[str, Any]]: + return [ + { + "name": constants.OPERATION_NAME, + "type": { + "type": "enum", + "name": f'{db_schema_name}_{db_table_name}{constants.OPERATION_NAME}', + "symbols": list(constants.CDC_OPERATION_NAME_TO_ID.keys()) + } + }, + { + # as ISO 8601 timestamp... either the change's tran_end_time OR the time the snapshot row was read: + "name": constants.EVENT_TIME_NAME, + "type": "string" + }, + { + "name": constants.LSN_NAME, + "type": ["null", "string"] + }, + { + "name": constants.SEQVAL_NAME, + "type": ["null", "string"] + }, + { + # Messages will list the names of all fields that were updated in the event (for snapshots or CDC insert + # records this will be all rows): + "name": constants.UPDATED_FIELDS_NAME, + "type": { + "type": "array", + "items": { + "type": "enum", + "name": f'{db_schema_name}_{db_table_name}{constants.UPDATED_FIELDS_NAME}', + "default": constants.UNRECOGNIZED_COLUMN_DEFAULT_NAME, + "symbols": [constants.UNRECOGNIZED_COLUMN_DEFAULT_NAME] + source_field_names + } + } + } + ] + + +class AvroSerializer(SerializerAbstract): + def __init__(self, schema_registry_url: str, always_use_avro_longs: bool, progress_topic_name: str, + snapshot_logging_topic_name: str, metrics_topic_name: str, + avro_type_spec_overrides: Dict[str, str | Dict[str, str | int]], disable_writes: bool) -> None: + self.always_use_avro_longs: bool = always_use_avro_longs + self.avro_type_spec_overrides: Dict[str, str | Dict[str, str | int]] = avro_type_spec_overrides + self.disable_writes: bool = disable_writes + self._schema_registry: confluent_kafka.avro.CachedSchemaRegistryClient = \ + confluent_kafka.avro.CachedSchemaRegistryClient(schema_registry_url) # type: ignore[call-arg] + self._confluent_serializer: confluent_kafka.avro.MessageSerializer = \ + confluent_kafka.avro.MessageSerializer(self._schema_registry) + self._tables: Dict[str, AvroTableDataSerializerMetadata] = {} + self._schema_generator: AvroSchemaGenerator = AvroSchemaGenerator( + always_use_avro_longs, avro_type_spec_overrides) + self._canonical_compare_ctr: Dict[str, int] = collections.defaultdict(int) + if progress_topic_name: + self._progress_key_schema_id: int = self._get_or_register_schema( + f'{progress_topic_name}-key', PROGRESS_TRACKING_AVRO_KEY_SCHEMA, + constants.DEFAULT_KEY_SCHEMA_COMPATIBILITY_LEVEL) + self._progress_value_schema_id: int = self._get_or_register_schema( + f'{progress_topic_name}-value', PROGRESS_TRACKING_AVRO_VALUE_SCHEMA, + constants.DEFAULT_VALUE_SCHEMA_COMPATIBILITY_LEVEL) + if snapshot_logging_topic_name: + self._snapshot_logging_schema_id: int = self._get_or_register_schema( + f'{snapshot_logging_topic_name}-value', SNAPSHOT_LOGGING_AVRO_VALUE_SCHEMA, + constants.DEFAULT_VALUE_SCHEMA_COMPATIBILITY_LEVEL) + if metrics_topic_name: + self._metrics_key_schema_id: int = self._get_or_register_schema( + f'{metrics_topic_name}-key', METRICS_AVRO_KEY_SCHEMA, + constants.DEFAULT_KEY_SCHEMA_COMPATIBILITY_LEVEL) + self._metrics_value_schema_id: int = self._get_or_register_schema( + f'{metrics_topic_name}-value', 
METRICS_AVRO_VALUE_SCHEMA, + constants.DEFAULT_VALUE_SCHEMA_COMPATIBILITY_LEVEL) + + def _get_or_register_schema(self, subject_name: str, schema: Schema, + compatibility_level: Literal["NONE", "FULL", "FORWARD", "BACKWARD"]) -> int: + # TODO: it turns out that if you try to re-register a schema that was previously registered but later superseded + # (e.g. in the case of adding and then later deleting a column), the schema registry will accept that and return + # you the previously-registered schema ID without updating the `latest` version associated with the registry + # subject, or verifying that the change is Avro-compatible. It seems like the way to handle this, per + # https://github.com/confluentinc/schema-registry/issues/1685, would be to detect the condition and delete the + # subject-version-number of that schema before re-registering it. Since subject-version deletion is not + # available in the `CachedSchemaRegistryClient` we use here--and since this is a rare case--I'm explicitly + # choosing to punt on it for the moment. The Confluent lib does now have a newer `SchemaRegistryClient` class + # which supports subject-version deletion, but changing this code to use it appears to be a non-trivial task. + + current_schema: Schema + schema_id_str, current_schema, _ = self._schema_registry.get_latest_schema(subject_name) + schema_id: int + if (current_schema is None or current_schema != schema) and not self.disable_writes: + logger.info('Schema for subject %s does not exist or is outdated; registering now.', subject_name) + schema_id = self._schema_registry.register(subject_name, schema) + logger.debug('Schema registered for subject %s: %s', subject_name, schema) + time.sleep(constants.KAFKA_CONFIG_RELOAD_DELAY_SECS) + if current_schema is None: + self._schema_registry.update_compatibility(compatibility_level, subject_name) + else: + schema_id = int(schema_id_str) + return schema_id + + def register_table(self, table: TrackedTable) -> None: + key_schema = self._schema_generator.generate_key_schema(table) + value_schema = self._schema_generator.generate_value_schema(table) + key_schema_id: int = self._get_or_register_schema(f'{table.topic_name}-key', key_schema, + constants.DEFAULT_KEY_SCHEMA_COMPATIBILITY_LEVEL) + value_schema_id: int = self._get_or_register_schema(f'{table.topic_name}-value', value_schema, + constants.DEFAULT_VALUE_SCHEMA_COMPATIBILITY_LEVEL) + ordered_serializers: List[Callable[[io.BytesIO, Any], None]] = [] + + serializer: Callable[[io.BytesIO, Any], None] + for vf in table.value_fields: + sql_type_name: str = vf.sql_type_name.lower() + if sql_type_name in ('decimal', 'numeric', 'money', 'smallmoney'): + if (not vf.decimal_precision) or vf.decimal_scale is None: + raise Exception(f"Field '{vf.name}': For SQL decimal, money, or numeric types, the scale and " + f"precision must be provided.") + serializer = functools.partial(decimal_to_decimal, scale=vf.decimal_scale) + elif sql_type_name == 'bit': + serializer = bool_to_bool + elif sql_type_name == 'float': + serializer = float_to_double + elif sql_type_name == 'real': + serializer = float_to_float + elif sql_type_name in ('int', 'smallint', 'tinyint', 'bigint'): + serializer = int_to_int + elif sql_type_name == 'date': + serializer = date_to_int + elif sql_type_name == 'time': + serializer = time_to_int + elif sql_type_name in ('datetime', 'datetime2', 'datetimeoffset', 'smalldatetime'): + serializer = datetime_to_string + elif sql_type_name in ('xml', 'uniqueidentifier') + constants.SQL_STRING_TYPES: + serializer 
= string_to_string + elif sql_type_name in ('binary', 'image', 'varbinary', 'rowversion'): + serializer = bytes_to_bytes + else: + raise Exception(f"Field '{vf.name}': I am unsure how to convert SQL type {sql_type_name} to Avro") + + ordered_serializers.append(serializer) + + self._tables[table.topic_name] = AvroTableDataSerializerMetadata( + key_schema_id, value_schema_id, table.key_field_source_table_ordinals, ordered_serializers, + table.key_field_names, table.value_field_names) + + def serialize_table_data_message(self, row: ParsedRow) -> Tuple[bytes, bytes]: + metadata = self._tables[row.destination_topic] + key_writer = io.BytesIO() + key_writer.write(struct.pack('>bI', 0, metadata.key_schema_id)) + value_writer = io.BytesIO() + value_writer.write(struct.pack('>bI', 0, metadata.value_schema_id)) + int_to_int(value_writer, row.operation_id) + as_bytes: bytes = row.event_db_time.isoformat().encode("utf-8") + int_to_int(value_writer, len(as_bytes)) + value_writer.write(struct.pack(f"{len(as_bytes)}s", as_bytes)) + if row.change_idx is None or row.operation_id == constants.SNAPSHOT_OPERATION_ID: + value_writer.write(b'\x00\x00') + else: + value_writer.write(b'\x02') + as_bytes = f',0x{row.change_idx.lsn.hex()}'.encode("utf-8") + value_writer.write(struct.pack("23s", as_bytes)) + value_writer.write(b'\x02') + as_bytes = f',0x{row.change_idx.seqval.hex()}'.encode("utf-8") + value_writer.write(struct.pack("23s", as_bytes)) + if row.operation_id in (constants.SNAPSHOT_OPERATION_ID, constants.INSERT_OPERATION_ID, + constants.DELETE_OPERATION_ID): + value_writer.write(metadata.all_cols_updated_enum_bytes) + else: + bits = bitarray() + bits.frombytes(row.cdc_update_mask) + bits.reverse() + int_to_int(value_writer, bits.count()) + for i, bit in enumerate(bits): + if bit: + int_to_int(value_writer, i + 1) + value_writer.write(b'\x00') + for ix, f in enumerate(metadata.ordered_serializers): + if row.table_data_cols[ix] is None: + value_writer.write(b'\x00') + else: + value_writer.write(b'\x02') + f(value_writer, row.table_data_cols[ix]) + for ix in metadata.key_field_ordinals: + f = metadata.ordered_serializers[ix - 1] + datum = row.table_data_cols[ix - 1] + f(key_writer, datum) + serialized_key: bytes = key_writer.getvalue() + key_writer.close() + serialized_value: bytes = value_writer.getvalue() + value_writer.close() + + if self._canonical_compare_ctr.get(row.destination_topic, 0) % COMPARE_CANONICAL_EVERY_NTH == 0: + self.compare_canonical(metadata, row, serialized_key, serialized_value) + + self._canonical_compare_ctr[row.destination_topic] += 1 + + return serialized_key, serialized_value + + # Not normally used, but here to help debug custom Avro serialization by comparing against that used + # by the Confluent library: + def compare_canonical(self, metadata: AvroTableDataSerializerMetadata, row: ParsedRow, + serialized_key: bytes, serialized_value: bytes) -> None: + + key_dict = dict(zip(metadata.key_field_names, row.ordered_key_field_values)) + value_dict = dict(zip(metadata.value_field_names, row.table_data_cols)) + + if row.operation_id == constants.SNAPSHOT_OPERATION_ID: + value_dict[constants.OPERATION_NAME] = constants.SNAPSHOT_OPERATION_NAME + value_dict[constants.LSN_NAME] = None + value_dict[constants.SEQVAL_NAME] = None + else: + change_idx = row.change_idx + value_dict.update(change_idx.as_dict()) + + if row.operation_id in (constants.PRE_UPDATE_OPERATION_ID, constants.POST_UPDATE_OPERATION_ID): + arr = bitarray() + arr.frombytes(row.cdc_update_mask) + arr.reverse() + 
value_dict[constants.UPDATED_FIELDS_NAME] = list(itertools.compress(metadata.value_field_names, arr)) + else: + value_dict[constants.UPDATED_FIELDS_NAME] = list(metadata.value_field_names) + + value_dict[constants.EVENT_TIME_NAME] = row.event_db_time.isoformat() + dates_transformed = {k: v.isoformat() for k, v in value_dict.items() if type(v) is datetime.datetime} + value_dict.update(dates_transformed) + + comp_key: bytes = self._confluent_serializer.encode_record_with_schema_id( # type: ignore[assignment] + metadata.key_schema_id, key_dict, True) + if serialized_key != comp_key: + # import pdb; pdb.set_trace() + raise Exception( + f'Avro serialization does not match the canonical library serialization. Key {key_dict} for message ' + f'to topic {row.destination_topic} with schema ID {metadata.key_schema_id} serialized as ' + f'"{serialized_key!r}" but canonical form was "{comp_key!r}".') + + comp_value: bytes = self._confluent_serializer.encode_record_with_schema_id( # type: ignore[assignment] + metadata.value_schema_id, value_dict, False) + if serialized_value != comp_value: + # import pdb; pdb.set_trace() + raise Exception( + f'Avro serialization does not match the canonical library serialization. Value {value_dict} for ' + f'message to topic {row.destination_topic} with schema ID {metadata.value_schema_id} serialized ' + f'as "{serialized_value!r}" but canonical form was "{comp_value!r}".') + + def serialize_progress_tracking_message(self, progress_entry: ProgressEntry) -> Tuple[bytes, Optional[bytes]]: + k: bytes = self._confluent_serializer.encode_record_with_schema_id( # type: ignore[assignment] + self._progress_key_schema_id, progress_entry.key, True) + if progress_entry.value is None: + return k, None + v: bytes = self._confluent_serializer.encode_record_with_schema_id( # type: ignore[assignment] + self._progress_value_schema_id, progress_entry.value, False) + return k, v + + def serialize_metrics_message(self, metrics_namespace: str, metrics: Dict[str, Any]) -> Tuple[bytes, bytes]: + k: bytes = self._confluent_serializer.encode_record_with_schema_id( # type: ignore[assignment] + self._metrics_key_schema_id, {'metrics_namespace': metrics_namespace}, True) + v: bytes = self._confluent_serializer.encode_record_with_schema_id( # type: ignore[assignment] + self._metrics_value_schema_id, metrics, False) + return k, v + + def serialize_snapshot_logging_message(self, snapshot_log: Dict[str, Any]) -> Tuple[None, bytes]: + v: bytes = self._confluent_serializer.encode_record_with_schema_id( # type: ignore[assignment] + self._snapshot_logging_schema_id, snapshot_log, False) + return None, v + + def deserialize(self, msg: confluent_kafka.Message) -> DeserializedMessage: + # noinspection PyArgumentList + return DeserializedMessage(msg, + self._confluent_serializer.decode_message(msg.key(), is_key=True), + self._confluent_serializer.decode_message(msg.value(), is_key=False)) + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + '--schema-registry-url', default=os.environ.get('SCHEMA_REGISTRY_URL'), + help='URL to your Confluent Schema Registry, e.g. "http://localhost:8081"') + parser.add_argument( + '--always-use-avro-longs', type=str2bool, nargs='?', const=True, + default=str2bool(os.environ.get('ALWAYS_USE_AVRO_LONGS', '0')), + help="Defaults to False. 
If set to True, Avro schemas produced/registered by this process will " + "use the Avro `long` type instead of the `int` type for fields corresponding to SQL Server " + "INT, SMALLINT, or TINYINT columns. This can be used to future-proof in cases where the column " + "size may need to be upgraded in the future, at the potential cost of increased storage or " + "memory space needs in consuming processes. Note that if this change is made for existing " + "topics, the schema registration attempt will violate Avro FORWARD compatibility checks (the " + "default used by this process), meaning that you may need to manually override the schema " + "registry compatibility level for any such topics first.") + parser.add_argument( + '--avro-type-spec-overrides', + default=os.environ.get('AVRO_TYPE_SPEC_OVERRIDES', {}), type=json.loads, + help='Optional JSON object that maps schema.table.column names to a string or object indicating the ' + 'Avro schema type specification you want to use for the field. This will override the default ' + 'mapping of SQL types to Avro types otherwise used and found in avro.py. Note that setting ' + 'this only changes the generated schema and will NOT affect the way values are passed to the ' + 'Avro serialization library, so any overriding type specified should be compatible with the ' + 'SQL/Python types of the actual data. Example: `{"dbo.order.orderid": "long"}` could be used ' + 'to specify the use of an Avro `long` type for a source DB column that is only a 32-bit INT, ' + 'perhaps in preparation for a future DB column change.') + + @classmethod + def construct_with_options(cls: Type[AvroSerializerType], opts: argparse.Namespace, + disable_writes: bool) -> AvroSerializerType: + if not opts.schema_registry_url: + raise Exception('AvroSerializer cannot be used without specifying a value for SCHEMA_REGISTRY_URL') + metrics_topic_name: str = hasattr(opts, 'kafka_metrics_topic') and opts.kafka_metrics_topic or '' + return cls(opts.schema_registry_url, opts.always_use_avro_longs, opts.progress_topic_name, + opts.snapshot_logging_topic_name, metrics_topic_name, opts.avro_type_spec_overrides, + disable_writes) + + +def decimal_to_decimal(writer: io.BytesIO, datum: decimal.Decimal, scale: int) -> None: + sign, digits, exp = datum.as_tuple() + if (-1 * int(exp)) > scale: + raise AvroOutOfScaleException(scale, datum, exp) # type: ignore[no-untyped-call] + + unscaled_datum = 0 + for digit in digits: + unscaled_datum = (unscaled_datum * 10) + digit + + bits_req = unscaled_datum.bit_length() + 1 + if sign: + unscaled_datum = (1 << bits_req) - unscaled_datum + + bytes_req = bits_req // 8 + padding_bits = ~((1 << bits_req) - 1) if sign else 0 + packed_bits = padding_bits | unscaled_datum + + bytes_req += 1 if (bytes_req << 3) < bits_req else 0 + int_to_int(writer, bytes_req) + for index in range(bytes_req - 1, -1, -1): + bits_to_write = packed_bits >> (8 * index) + writer.write(bytearray([bits_to_write & 0xFF])) + + +def int_to_int(writer: io.BytesIO, datum: int) -> None: + datum = (datum << 1) ^ (datum >> 63) + while datum & ~0x7F: + writer.write(bytearray([(datum & 0x7F) | 0x80])) + datum >>= 7 + writer.write(bytearray([datum])) + + +def bool_to_bool(writer: io.BytesIO, datum: bool) -> None: + writer.write(bytearray([bool(datum)])) + + +def float_to_double(writer: io.BytesIO, datum: float) -> None: + writer.write(struct.Struct("<d").pack(datum)) + + +def float_to_float(writer: io.BytesIO, datum: float) -> None: + writer.write(struct.Struct("<f").pack(datum)) + + +def date_to_int(writer: io.BytesIO, datum: datetime.date) -> None: + delta_date = datum - datetime.date(1970, 1, 1) + int_to_int(writer, delta_date.days) + + +def 
time_to_int(writer: io.BytesIO, datum: datetime.time) -> None: + milliseconds = datum.hour * 3600000 + datum.minute * 60000 + datum.second * 1000 + datum.microsecond // 1000 + int_to_int(writer, milliseconds) + + +def datetime_to_string(writer: io.BytesIO, datum: datetime.datetime) -> None: + as_bytes: bytes = datum.isoformat().encode("utf-8") + int_to_int(writer, len(as_bytes)) + writer.write(struct.pack(f"{len(as_bytes)}s", as_bytes)) + + +def bytes_to_bytes(writer: io.BytesIO, datum: bytes) -> None: + int_to_int(writer, len(datum)) + writer.write(struct.pack(f"{len(datum)}s", datum)) + + +def string_to_string(writer: io.BytesIO, datum: str) -> None: + as_bytes: bytes = datum.encode("utf-8") + int_to_int(writer, len(as_bytes)) + writer.write(struct.pack(f"{len(as_bytes)}s", as_bytes)) diff --git a/cdc_kafka/show_snapshot_history.py b/cdc_kafka/show_snapshot_history.py index afe7bcb..0470dd7 100644 --- a/cdc_kafka/show_snapshot_history.py +++ b/cdc_kafka/show_snapshot_history.py @@ -4,10 +4,10 @@ from typing import List, Any, Dict import confluent_kafka + from tabulate import tabulate -from cdc_kafka import kafka, kafka_oauth -from . import constants +from . import kafka, constants, options from .metric_reporting import accumulator logger = logging.getLogger(__name__) @@ -21,18 +21,14 @@ def main() -> None: - p = argparse.ArgumentParser() - p.add_argument('--topic-names', required=True) - p.add_argument('--schema-registry-url', required=True) - p.add_argument('--kafka-bootstrap-servers', required=True) - p.add_argument('--snapshot-logging-topic-name', required=True) - p.add_argument('--extra-kafka-consumer-config', type=json.loads, default={}) - p.add_argument('--script-output-file', type=argparse.FileType('w')) - p.add_argument('--extra-kafka-cli-command-arg', nargs='*') - - kafka_oauth.add_kafka_oauth_arg(p) - opts, _ = p.parse_known_args() - topic_names: List[str] = opts.topic_names.split(',') + def add_args(p: argparse.ArgumentParser) -> None: + p.add_argument('--topic-names', required=True) + p.add_argument('--script-output-file', type=argparse.FileType('w')) + p.add_argument('--extra-kafka-cli-command-arg', nargs='*') + + opts, _, serializer = options.get_options_and_metrics_reporters(add_args) + print(opts.topic_names) + topic_names: List[str] = [x.strip() for x in opts.topic_names.strip().split(',') if x.strip()] display_table: List[List[str]] = [] completions_seen_since_start: Dict[str, bool] = {tn: False for tn in topic_names} last_starts: Dict[str, Dict[str, Any]] = {tn: {} for tn in topic_names} @@ -52,11 +48,11 @@ def main() -> None: "Value schema ID", ] - with kafka.KafkaClient(accumulator.NoopAccumulator(), opts.kafka_bootstrap_servers, opts.schema_registry_url, + with kafka.KafkaClient(accumulator.NoopAccumulator(), opts.kafka_bootstrap_servers, opts.extra_kafka_consumer_config, {}, disable_writing=True) as kafka_client: for msg in kafka_client.consume_all(opts.snapshot_logging_topic_name): - # noinspection PyTypeChecker,PyArgumentList - log = dict(msg.value()) + deser_msg = serializer.deserialize(msg) + log = deser_msg.value_dict or {} consumed_count += 1 if log['topic_name'] not in topic_names: continue @@ -82,7 +78,7 @@ def main() -> None: print(f''' Consumed {consumed_count} messages from snapshot logging topic {opts.snapshot_logging_topic_name}. -{relevant_count} were related to requested topics {opts.topic_names}. +{relevant_count} were related to requested topics {topic_names}. 
''') if not relevant_count: @@ -94,7 +90,14 @@ def main() -> None: watermarks = kafka_client.get_topic_watermarks(topic_names) for topic_name in topic_names: - all_topic_configs = kafka_client.get_topic_config(topic_name) + try: + all_topic_configs = kafka_client.get_topic_config(topic_name) + except confluent_kafka.KafkaException as e: + if e.args[0].code() == confluent_kafka.KafkaError.UNKNOWN_TOPIC_OR_PART: + print(f'Topic {topic_name} does not seem to exist; skipping.') + continue + else: + raise topic_has_delete_cleanup_policy = 'delete' in all_topic_configs['cleanup.policy'].value topic_level_retention_configs = { k: v.value for k, v in all_topic_configs.items() diff --git a/cdc_kafka/sql_query_subprocess.py b/cdc_kafka/sql_query_subprocess.py index 1dfaa71..837bee3 100644 --- a/cdc_kafka/sql_query_subprocess.py +++ b/cdc_kafka/sql_query_subprocess.py @@ -30,7 +30,7 @@ class SQLQueryResult(NamedTuple): reflected_query_request_metadata: Any query_executed_utc: datetime.datetime query_took_sec: float - result_rows: List[pyodbc.Row | parsed_row.ParsedRow] + result_rows: List[parsed_row.ParsedRow] query_params: Sequence[Any] @@ -116,7 +116,7 @@ def querier_thread(self) -> None: start_time = time.perf_counter() with db_conn.cursor() as cursor: if request.query_param_types is not None: - cursor.setinputsizes(request.query_param_types) + cursor.setinputsizes(request.query_param_types) # type: ignore[arg-type] retry_count = 0 while True: try: @@ -146,10 +146,12 @@ def querier_thread(self) -> None: SQLQueryResult(request.queue_name, request.query_metadata_to_reflect, query_executed_utc, query_took_sec, result_rows, request.query_params) ) - except (KeyboardInterrupt, pyodbc.OperationalError) as exc: + except pyodbc.OperationalError as exc: # 08S01 is the error code for "Communication link failure" which may be raised in response to KeyboardInterrupt - if exc is pyodbc.OperationalError and not exc.args[0].startswith('08S01'): + if not exc.args[0].startswith('08S01'): raise exc + except KeyboardInterrupt: + pass except Exception as exc: logger.exception('SQL querier thread raised an exception.', exc_info=exc) finally: diff --git a/cdc_kafka/tracked_tables.py b/cdc_kafka/tracked_tables.py index 8b1f2b5..783d61c 100644 --- a/cdc_kafka/tracked_tables.py +++ b/cdc_kafka/tracked_tables.py @@ -1,16 +1,12 @@ import hashlib -import itertools import logging import uuid -from typing import Tuple, Dict, List, Any, Callable, Optional, Generator, TYPE_CHECKING, Mapping, Sequence +from typing import Tuple, List, Any, Optional, Generator, TYPE_CHECKING, Mapping, Sequence, Dict -from avro.schema import Schema -import bitarray import pyodbc -from .avro import avro_transform_fn_from_sql_type from . 
import constants, change_index, options, sql_queries, sql_query_subprocess, parsed_row, \ - helpers, avro, progress_tracking + helpers, progress_tracking if TYPE_CHECKING: from .metric_reporting import accumulator @@ -20,7 +16,7 @@ class TrackedField(object): __slots__ = 'name', 'sql_type_name', 'change_table_ordinal', 'primary_key_ordinal', 'decimal_precision', \ - 'decimal_scale', 'transform_fn', 'truncate_after' + 'decimal_scale', 'truncate_after' def __init__(self, name: str, sql_type_name: str, change_table_ordinal: int, primary_key_ordinal: int, decimal_precision: int, decimal_scale: int, truncate_after: int = 0) -> None: @@ -30,24 +26,22 @@ def __init__(self, name: str, sql_type_name: str, change_table_ordinal: int, pri self.primary_key_ordinal: int = primary_key_ordinal self.decimal_precision: int = decimal_precision self.decimal_scale: int = decimal_scale - self.transform_fn: Optional[Callable[[Any], Any]] = avro_transform_fn_from_sql_type(sql_type_name) - if truncate_after and self.sql_type_name not in constants.SQL_STRING_TYPES: - raise Exception(f'A truncation length was specified for field {name} but it does not appear to be a ' - f'string field (SQL type is {sql_type_name}).') + if truncate_after: + if self.sql_type_name not in constants.SQL_STRING_TYPES: + raise Exception(f'A truncation length was specified for field {name} but it does not appear to be a ' + f'string field (SQL type is {sql_type_name}).') self.truncate_after: int = truncate_after class TrackedTable(object): def __init__(self, db_conn: pyodbc.Connection, metrics_accumulator: 'accumulator.Accumulator', - sql_query_processor: sql_query_subprocess.SQLQueryProcessor, - schema_generator: avro.AvroSchemaGenerator, schema_name: str, table_name: str, + sql_query_processor: sql_query_subprocess.SQLQueryProcessor, schema_name: str, table_name: str, capture_instance_name: str, topic_name: str, min_lsn: bytes, snapshot_allowed: bool, db_row_batch_size: int, progress_tracker: 'progress_tracking.ProgressTracker') -> None: self._db_conn: pyodbc.Connection = db_conn self._metrics_accumulator: 'accumulator.Accumulator' = metrics_accumulator self._sql_query_processor: sql_query_subprocess.SQLQueryProcessor = sql_query_processor - self._schema_generator: avro.AvroSchemaGenerator = schema_generator self.schema_name: str = schema_name self.table_name: str = table_name self.capture_instance_name: str = capture_instance_name @@ -56,23 +50,20 @@ def __init__(self, db_conn: pyodbc.Connection, metrics_accumulator: 'accumulator self.fq_name: str = f'{schema_name}.{table_name}' self.db_row_batch_size: int = db_row_batch_size self.progress_tracker: progress_tracking.ProgressTracker = progress_tracker + self.truncate_indexes: Dict[int, int] = {} # Most of the below properties are not set until sometime after `finalize_table` is called: - self.key_schema_id: int = -1 - self.value_schema_id: int = -1 - self.key_schema: Optional[Schema] = None - self.value_schema: Optional[Schema] = None self.key_fields: Tuple[TrackedField, ...] = tuple() self.value_fields: Tuple[TrackedField, ...] = tuple() + self.key_field_names: Tuple[str, ...] = tuple() + self.value_field_names: List[str] = [] + self.key_field_source_table_ordinals: Tuple[int, ...] = tuple() self.max_polled_change_index: change_index.ChangeIndex = change_index.LOWEST_CHANGE_INDEX self.change_reads_are_lagging: bool = False self.snapshot_complete: bool = False self.min_lsn: bytes = min_lsn - self._key_field_names: Tuple[str, ...] 
= tuple() - self._key_field_source_table_ordinals: Tuple[int, ...] = tuple() - self._value_field_names: List[str] = [] self._last_read_key_for_snapshot: Optional[Sequence[Any]] = None self._odbc_columns: Tuple[pyodbc.Row, ...] = tuple() self._change_rows_query: str = '' @@ -94,16 +85,19 @@ def __init__(self, db_conn: pyodbc.Connection, metrics_accumulator: 'accumulator def last_read_key_for_snapshot_display(self) -> Optional[str]: if not self._last_read_key_for_snapshot: return None - return ', '.join([f'{k}: {v}' for k, v in zip(self._key_field_names, self._last_read_key_for_snapshot)]) + return ', '.join([f'{k}: {v}' for k, v in zip(self.key_field_names, self._last_read_key_for_snapshot)]) def append_field(self, field: TrackedField) -> None: + field_ix: int = len(self._fields_added_pending_finalization) self._fields_added_pending_finalization.append(field) + if field.truncate_after: + self.truncate_indexes[field_ix] = field.truncate_after def get_source_table_count(self, low_key: Tuple[Any, ...], high_key: Tuple[Any, ...]) -> int: with self._db_conn.cursor() as cursor: - q, p = sql_queries.get_table_count(self.schema_name, self.table_name, self._key_field_names, + q, p = sql_queries.get_table_count(self.schema_name, self.table_name, self.key_field_names, self._odbc_columns) - cursor.setinputsizes(p) + cursor.setinputsizes(p) # type: ignore[arg-type] cursor.execute(q, low_key + high_key) res: int = cursor.fetchval() return res @@ -113,7 +107,7 @@ def get_change_table_counts(self, highest_change_index: change_index.ChangeIndex deletes, inserts, updates = 0, 0, 0 q, p = sql_queries.get_change_table_count_by_operation( helpers.quote_name(helpers.get_fq_change_table_name(self.capture_instance_name))) - cursor.setinputsizes(p) + cursor.setinputsizes(p) # type: ignore[arg-type] cursor.execute(q, (highest_change_index.lsn, highest_change_index.seqval, highest_change_index.operation)) for row in cursor.fetchall(): if row[1] == 1: @@ -132,7 +126,6 @@ def finalize_table( self, start_after_change_table_index: change_index.ChangeIndex, prior_change_table_max_index: Optional[change_index.ChangeIndex], start_from_key_for_snapshot: Optional[Mapping[str, Any]], lsn_gap_handling: str, - schema_id_getter: Optional[Callable[[str, Schema, Schema], Tuple[int, int]]] = None, allow_progress_writes: bool = False ) -> None: if self._finalized: @@ -169,7 +162,7 @@ def finalize_table( self.max_polled_change_index = start_after_change_table_index self.value_fields = tuple(sorted(self._fields_added_pending_finalization, key=lambda f: f.change_table_ordinal)) - self._value_field_names = [f.name for f in self.value_fields] + self.value_field_names = [f.name for f in self.value_fields] self._fields_added_pending_finalization = [] key_fields = [f for f in self.value_fields if f.primary_key_ordinal is not None] @@ -181,22 +174,13 @@ def finalize_table( self._has_pk = True self.key_fields = tuple(sorted(key_fields, key=lambda f: f.primary_key_ordinal)) - self._key_field_names = tuple([kf.name for kf in self.key_fields]) - self._key_field_source_table_ordinals = tuple([kf.change_table_ordinal for kf in self.key_fields]) - self.key_schema = self._schema_generator.generate_key_schema(self.schema_name, self.table_name, - self.key_fields) - self.value_schema = self._schema_generator.generate_value_schema(self.schema_name, self.table_name, - self.value_fields) - - if schema_id_getter: - self.key_schema_id, self.value_schema_id = schema_id_getter(self.topic_name, self.key_schema, - self.value_schema) - + self.key_field_names = 
tuple([kf.name for kf in self.key_fields]) + self.key_field_source_table_ordinals = tuple([kf.change_table_ordinal for kf in self.key_fields]) capture_instance_name = helpers.get_fq_change_table_name(self.capture_instance_name) with self._db_conn.cursor() as cursor: q, p = sql_queries.get_change_table_index_cols() - cursor.setinputsizes(p) + cursor.setinputsizes(p) # type: ignore[arg-type] cursor.execute(q, capture_instance_name) change_table_clustered_idx_cols = [r[0] for r in cursor.fetchall()] @@ -212,14 +196,14 @@ def finalize_table( f'were: {change_table_clustered_idx_cols}') self._change_rows_query, self._change_rows_query_param_types = sql_queries.get_change_rows( - self.db_row_batch_size, helpers.quote_name(capture_instance_name), self._value_field_names, + self.db_row_batch_size, helpers.quote_name(capture_instance_name), self.value_field_names, change_table_clustered_idx_cols) if not self.snapshot_allowed: self.snapshot_complete = True else: columns_actually_on_base_table = {x[3] for x in self._odbc_columns} - columns_no_longer_on_base_table = set(self._value_field_names) - columns_actually_on_base_table + columns_no_longer_on_base_table = set(self.value_field_names) - columns_actually_on_base_table if columns_no_longer_on_base_table: logger.warning('Some column(s) found in the capture instance appear to no longer be present on base ' @@ -228,14 +212,14 @@ def finalize_table( ', '.join(columns_no_longer_on_base_table)) if self._has_pk: self._snapshot_rows_query, self._snapshot_rows_query_param_types = sql_queries.get_snapshot_rows( - self.db_row_batch_size, self.schema_name, self.table_name, self._value_field_names, - columns_no_longer_on_base_table, self._key_field_names, False, self._odbc_columns) + self.db_row_batch_size, self.schema_name, self.table_name, self.value_field_names, + columns_no_longer_on_base_table, self.key_field_names, False, self._odbc_columns) if start_from_key_for_snapshot == constants.SNAPSHOT_COMPLETION_SENTINEL: self.snapshot_complete = True elif start_from_key_for_snapshot: key_min_tuple = tuple(self._get_min_key_value() or []) - start_key_tuple = tuple([start_from_key_for_snapshot[kfn] for kfn in self._key_field_names]) + start_key_tuple = tuple([start_from_key_for_snapshot[kfn] for kfn in self.key_field_names]) if key_min_tuple and key_min_tuple == start_key_tuple: self.snapshot_complete = True @@ -244,27 +228,23 @@ def finalize_table( constants.SNAPSHOT_COMPLETION_SENTINEL) else: if allow_progress_writes: - start_key_dict = dict(zip(self._key_field_names, start_key_tuple)) - self.progress_tracker.log_snapshot_resumed(self.topic_name, self.fq_name, - self.key_schema_id, self.value_schema_id, - start_key_dict) + start_key_dict = dict(zip(self.key_field_names, start_key_tuple)) + self.progress_tracker.log_snapshot_resumed(self.topic_name, self.fq_name, start_key_dict) self._last_read_key_for_snapshot = start_key_tuple else: key_max = self._get_max_key_value() if key_max: - key_max_map = dict(zip(self._key_field_names, key_max)) + key_max_map = dict(zip(self.key_field_names, key_max)) logger.info('Table %s is starting a full snapshot, working back from max key %s', self.fq_name, key_max_map) self._initial_snapshot_rows_query, _ = sql_queries.get_snapshot_rows( - self.db_row_batch_size, self.schema_name, self.table_name, self._value_field_names, - columns_no_longer_on_base_table, self._key_field_names, True, self._odbc_columns) + self.db_row_batch_size, self.schema_name, self.table_name, self.value_field_names, + columns_no_longer_on_base_table, 
self.key_field_names, True, self._odbc_columns) if allow_progress_writes: - self.progress_tracker.log_snapshot_started(self.topic_name, self.fq_name, - self.key_schema_id, self.value_schema_id, - key_max_map) + self.progress_tracker.log_snapshot_started(self.topic_name, self.fq_name, key_max_map) self._last_read_key_for_snapshot = None else: @@ -331,7 +311,7 @@ def retrieve_snapshot_query_results(self) -> Generator[parsed_row.ParsedRow, Non last_read: str = '' if self._last_read_key_for_snapshot: last_read = ', '.join([f'{k}: {v}' for k, v in - zip(self._key_field_names, self._last_read_key_for_snapshot)]) + zip(self.key_field_names, self._last_read_key_for_snapshot)]) logger.info("SNAPSHOT COMPLETED for table %s. Last read key: (%s)", self.fq_name, last_read) self._last_read_key_for_snapshot = None self.snapshot_complete = True @@ -379,9 +359,9 @@ def get_change_rows_per_second(self) -> int: @staticmethod def cut_str_to_bytes(s: str, max_bytes: int) -> Tuple[int, str]: # Mostly copied from https://github.com/halloleo/unicut/blob/master/truncate.py - def safe_b_of_i(b, i): + def safe_b_of_i(encoded: bytes, i: int) -> int: try: - return b[i] + return encoded[i] except IndexError: return 0 @@ -414,68 +394,41 @@ def safe_b_of_i(b, i): def _parse_db_row(self, db_row: pyodbc.Row) -> parsed_row.ParsedRow: operation_id, event_db_time, lsn, seqval, update_mask, *table_cols = db_row - operation_name = constants.CDC_OPERATION_ID_TO_NAME[operation_id] - - value_dict = {} - extra_headers = {} - for ix, fld in enumerate(self.value_fields): - val = table_cols[ix] - if fld.transform_fn: - val = fld.transform_fn(val) - # The '* 4' below is because that's the maximum possible byte length of a UTF-8 encoded character--just - # trying to optimize away the extra code inside the `if` block when possible: - if val and fld.truncate_after and len(val) * 4 > fld.truncate_after: - original_encoded_length = len(val.encode('utf-8')) - if original_encoded_length > fld.truncate_after: - new_encoded_length, val = TrackedTable.cut_str_to_bytes(val, fld.truncate_after) - extra_headers[f'cdc_to_kafka_truncated_field__{fld.name}'] = \ - f'{original_encoded_length},{new_encoded_length}' - value_dict[fld.name] = val if operation_id == constants.SNAPSHOT_OPERATION_ID: - row_kind = constants.SNAPSHOT_ROWS_KIND change_idx = change_index.LOWEST_CHANGE_INDEX - value_dict[constants.OPERATION_NAME] = constants.SNAPSHOT_OPERATION_NAME - value_dict[constants.LSN_NAME] = None - value_dict[constants.SEQVAL_NAME] = None else: - row_kind = constants.CHANGE_ROWS_KIND change_idx = change_index.ChangeIndex(lsn, seqval, operation_id) - value_dict.update(change_idx.to_avro_ready_dict()) - if operation_id in (constants.PRE_UPDATE_OPERATION_ID, constants.POST_UPDATE_OPERATION_ID): - value_dict[constants.UPDATED_FIELDS_NAME] = self._updated_col_names_from_mask(update_mask) - else: - value_dict[constants.UPDATED_FIELDS_NAME] = self._value_field_names + extra_headers: Dict[str, str | bytes] = {} - value_dict[constants.EVENT_TIME_NAME] = event_db_time.isoformat() + for ix, max_length in self.truncate_indexes.items(): + # The '* 4' below is because that's the maximum possible byte length of a UTF-8 encoded character--just + # trying to optimize away the extra code inside the `if` block when possible: + if table_cols[ix] and len(table_cols[ix]) * 4 > max_length: + original_encoded_length = len(table_cols[ix].encode('utf-8')) + if original_encoded_length > max_length: + new_encoded_length, table_cols[ix] = TrackedTable.cut_str_to_bytes(table_cols[ix], 
max_length) + extra_headers[f'cdc_to_kafka_truncated_field__{self.value_field_names[ix]}'] = \ + f'{original_encoded_length},{new_encoded_length}' ordered_key_field_values: List[Any] if self._has_pk: - ordered_key_field_values = [table_cols[kfo - 1] for kfo in self._key_field_source_table_ordinals] + ordered_key_field_values = [table_cols[kfo - 1] for kfo in self.key_field_source_table_ordinals] else: # CAUTION: this strategy for handling PK-less tables means that if columns are added or removed from the # capture instance in the future, the key value computed for the same source table row will change: m = hashlib.md5() - m.update(str(zip(self._value_field_names, table_cols)).encode('utf8')) + m.update(str(zip(self.value_field_names, table_cols)).encode('utf8')) row_hash = str(uuid.uuid5(uuid.UUID(bytes=m.digest()), self.fq_name)) ordered_key_field_values = [row_hash] - key_dict: Dict[str, Any] = dict(zip(self._key_field_names, ordered_key_field_values)) - - return parsed_row.ParsedRow(self.fq_name, row_kind, operation_name, event_db_time, change_idx, - tuple(ordered_key_field_values), self.topic_name, self.key_schema_id, - self.value_schema_id, key_dict, value_dict, extra_headers) - - def _updated_col_names_from_mask(self, cdc_update_mask: bytes) -> List[str]: - arr = bitarray.bitarray() - arr.frombytes(cdc_update_mask) - arr.reverse() - return list(itertools.compress(self._value_field_names, arr)) + return parsed_row.ParsedRow(self.topic_name, operation_id, update_mask, event_db_time, change_idx, + tuple(ordered_key_field_values), table_cols, extra_headers) def _get_max_key_value(self) -> Optional[Tuple[Any, ...]]: with self._db_conn.cursor() as cursor: - q, _ = sql_queries.get_max_key_value(self.schema_name, self.table_name, self._key_field_names) + q, _ = sql_queries.get_max_key_value(self.schema_name, self.table_name, self.key_field_names) cursor.execute(q) row: pyodbc.Row | None = cursor.fetchone() if row: @@ -484,7 +437,7 @@ def _get_max_key_value(self) -> Optional[Tuple[Any, ...]]: def _get_min_key_value(self) -> Optional[Tuple[Any, ...]]: with self._db_conn.cursor() as cursor: - q, _ = sql_queries.get_min_key_value(self.schema_name, self.table_name, self._key_field_names) + q, _ = sql_queries.get_min_key_value(self.schema_name, self.table_name, self.key_field_names) cursor.execute(q) row: pyodbc.Row | None = cursor.fetchone() if row: diff --git a/cdc_kafka/validation.py b/cdc_kafka/validation.py index ac947ea..32d701d 100644 --- a/cdc_kafka/validation.py +++ b/cdc_kafka/validation.py @@ -6,9 +6,8 @@ from typing import List, Iterable, Dict, Tuple, Any, Optional, TYPE_CHECKING, Set, Mapping, Sequence from uuid import UUID -import confluent_kafka - from . import constants, change_index, kafka, helpers +from .serializers import SerializerAbstract, DeserializedMessage if TYPE_CHECKING: from . 
import tracked_tables, progress_tracking @@ -42,7 +41,7 @@ def __repr__(self) -> str: return str(self.uuid).upper() -def extract_key_tuple(table: 'tracked_tables.TrackedTable', message: Mapping[str, str | int]) -> Tuple[Any, ...]: +def extract_key_tuple(table: 'tracked_tables.TrackedTable', message: Mapping[str, Any]) -> Tuple[Any, ...]: key_bits: List[Any] = [] for kf in table.key_fields: if kf.sql_type_name == 'uniqueidentifier': @@ -107,25 +106,37 @@ def __repr__(self) -> str: 'missing_offsets': self.missing_offsets, }) - def process_message(self, message: confluent_kafka.Message) -> None: + def process_message(self, message: DeserializedMessage) -> None: self.total_count += 1 + raw_partition = message.raw_msg.partition() + if raw_partition is None: + raise Exception('Unexpected state: None value for message partition.') + else: + partition: int = raw_partition + + raw_offset = message.raw_msg.offset() + if raw_offset is None: + raise Exception('Unexpected state: None value for message offset.') + else: + offset: int = raw_offset - if message.partition() not in self._last_processed_offset_by_partition: - self._last_processed_offset_by_partition[message.partition()] = -1 + if partition not in self._last_processed_offset_by_partition: + self._last_processed_offset_by_partition[partition] = -1 - self.missing_offsets += (message.offset() - self._last_processed_offset_by_partition[message.partition()] - 1) - self._last_processed_offset_by_partition[message.partition()] = message.offset() + self.missing_offsets += (offset - self._last_processed_offset_by_partition[partition] - 1) + self._last_processed_offset_by_partition[partition] = offset # noinspection PyArgumentList - if message.value() is None: + if not message.raw_msg.value(): self.tombstone_count += 1 return - # noinspection PyTypeChecker,PyArgumentList - message_body = dict(message.value()) - key = extract_key_tuple(self.table, message_body) - coordinates = helpers.format_coordinates(message) - operation_name = message_body[constants.OPERATION_NAME] + if message.value_dict is None: + raise Exception('Unexpected state') + + key = extract_key_tuple(self.table, message.value_dict) + coordinates = helpers.format_coordinates(message.raw_msg) + operation_name = message.value_dict[constants.OPERATION_NAME] if operation_name == constants.SNAPSHOT_OPERATION_NAME: self.snapshot_count += 1 @@ -134,19 +145,19 @@ def process_message(self, message: confluent_kafka.Message) -> None: self.min_snapshot_key_seen = key if self.max_snapshot_key_seen is None or key > self.max_snapshot_key_seen: self.max_snapshot_key_seen = key - if message.partition() in self._last_snapshot_key_seen_for_partition and \ - self._last_snapshot_key_seen_for_partition[message.partition()] < key: + if partition in self._last_snapshot_key_seen_for_partition and \ + self._last_snapshot_key_seen_for_partition[partition] < key: self.snapshot_key_order_regressions_count += 1 logger.debug( "Snapshot key order regression for %s: value %s at coordinates %s --> value %s at coordinates %s", self.table.fq_name, - self._last_snapshot_key_seen_for_partition[message.partition()], - self._last_snapshot_coordinates_seen_for_partition[message.partition()], + self._last_snapshot_key_seen_for_partition[partition], + self._last_snapshot_coordinates_seen_for_partition[partition], key, coordinates ) - self._last_snapshot_coordinates_seen_for_partition[message.partition()] = coordinates - self._last_snapshot_key_seen_for_partition[message.partition()] = key + 
self._last_snapshot_coordinates_seen_for_partition[partition] = coordinates + self._last_snapshot_key_seen_for_partition[partition] = key return if operation_name == constants.DELETE_OPERATION_NAME: @@ -155,7 +166,7 @@ def process_message(self, message: confluent_kafka.Message) -> None: self.keys_seen_in_changes.add(key) - msg_change_index = change_index.ChangeIndex.from_avro_ready_dict(message_body) + msg_change_index = change_index.ChangeIndex.from_dict(message.value_dict) if msg_change_index.lsn < self.table.min_lsn: # the live change table has been truncated and no longer has this entry return @@ -170,17 +181,17 @@ def process_message(self, message: confluent_kafka.Message) -> None: self.unknown_operation_count += 1 return - change_idx = change_index.ChangeIndex.from_avro_ready_dict(message_body) + change_idx = change_index.ChangeIndex.from_dict(message.value_dict) if self.min_change_index_seen is None or change_idx < self.min_change_index_seen: self.min_change_index_seen = change_idx if self.max_change_index_seen is None or change_idx > self.max_change_index_seen: self.max_change_index_seen = change_idx - self.max_change_index_seen_coordinates = helpers.format_coordinates(message) - if message.partition() in self._last_change_index_seen_for_partition and \ - self._last_change_index_seen_for_partition[message.partition()] > change_idx: + self.max_change_index_seen_coordinates = helpers.format_coordinates(message.raw_msg) + if partition in self._last_change_index_seen_for_partition and \ + self._last_change_index_seen_for_partition[partition] > change_idx: self.change_index_order_regressions_count += 1 - self._last_change_index_seen_for_partition[message.partition()] = change_idx - event_time = datetime.datetime.fromisoformat(message_body[constants.EVENT_TIME_NAME]) + self._last_change_index_seen_for_partition[partition] = change_idx + event_time = datetime.datetime.fromisoformat(message.value_dict[constants.EVENT_TIME_NAME]) if self.latest_change_seen is None or event_time > self.latest_change_seen: self.latest_change_seen = event_time return @@ -188,11 +199,12 @@ def process_message(self, message: confluent_kafka.Message) -> None: class Validator(object): def __init__(self, kafka_client: 'kafka.KafkaClient', tables: Iterable['tracked_tables.TrackedTable'], - progress_tracker: 'progress_tracking.ProgressTracker', + progress_tracker: 'progress_tracking.ProgressTracker', serializer: SerializerAbstract, unified_topic_to_tables_map: Dict[str, List['tracked_tables.TrackedTable']]) -> None: self._kafka_client: 'kafka.KafkaClient' = kafka_client self._tables_by_name: Dict[str, 'tracked_tables.TrackedTable'] = {t.fq_name: t for t in tables} self._progress_tracker: 'progress_tracking.ProgressTracker' = progress_tracker + self._serializer: SerializerAbstract = serializer self._unified_topic_to_tables_map: Dict[str, List['tracked_tables.TrackedTable']] = unified_topic_to_tables_map def run(self) -> None: @@ -214,13 +226,13 @@ def run(self) -> None: unified_topic_name, unified_topic_tables, watermarks_by_topic[unified_topic_name]) total_tables = len(self._tables_by_name) - for table_name, table in self._tables_by_name.items(): - logger.info('Processing table %s (%d/%d)', table_name, len(summaries_by_single_table) + 1, total_tables) - summaries_by_single_table[table_name] = self._process_single_table_topic( + for source_topic_name, table in self._tables_by_name.items(): + logger.info('Processing table %s (%d/%d)', source_topic_name, len(summaries_by_single_table) + 1, total_tables) + 
summaries_by_single_table[source_topic_name] = self._process_single_table_topic( table, watermarks_by_topic[table.topic_name]) - for table_name, summary in summaries_by_single_table.items(): - table = self._tables_by_name[table_name] + for source_topic_name, summary in summaries_by_single_table.items(): + table = self._tables_by_name[source_topic_name] failures, warnings, infos = [], [], [] progress_entry = progress.get((table.topic_name, constants.SNAPSHOT_ROWS_KIND)) @@ -331,16 +343,16 @@ def run(self) -> None: failures.append(f'Found {db_update_rows} update entries in DB change table but {summary.update_count} ' f'in Kafka topic') - print(f'\nSummary for table {table_name} in single-table topic {table.topic_name}:') + print(f'\nSummary for table {source_topic_name} in single-table topic {table.topic_name}:') for info in infos: - print(f' INFO: {info} ({table_name})') + print(f' INFO: {info} ({source_topic_name})') if not (warnings or failures): - print(f' OK: No problems! ({table_name})') + print(f' OK: No problems! ({source_topic_name})') else: for warning in warnings: - print(f' WARN: {warning} ({table_name})') + print(f' WARN: {warning} ({source_topic_name})') for failure in failures: - print(f' FAIL: {failure} ({table_name})') + print(f' FAIL: {failure} ({source_topic_name})') if failures: db_data = json.dumps({ @@ -358,24 +370,24 @@ def run(self) -> None: warnings = ut_result['warnings'] failures = ut_result['failures'] - for table_name, table_summary in ut_result['table_summaries'].items(): + for source_topic_name, table_summary in ut_result['table_summaries'].items(): if table_summary.latest_change_seen is None: - warnings.append(f'For table {table_name}: No change entries found!') + warnings.append(f'For source topic {source_topic_name}: No change entries found!') elif (helpers.naive_utcnow() - table_summary.latest_change_seen) > datetime.timedelta(days=1): - warnings.append(f'For table {table_name}: Last change entry seen in Kafka was dated ' + warnings.append(f'For source topic {source_topic_name}: Last change entry seen in Kafka was dated ' f'{table_summary.latest_change_seen}.') if table_summary.change_index_order_regressions_count: failures.append( - f'For table {table_name}: Kafka topic contained ' + f'For source topic {source_topic_name}: Kafka topic contained ' f'{table_summary.change_index_order_regressions_count} regressions in change index ordering.') if table_summary.unknown_operation_count: failures.append( - f'For table {table_name}: Topic contained {table_summary.unknown_operation_count} messages ' - f'with an unknown operation type.') + f'For source topic {source_topic_name}: Topic contained {table_summary.unknown_operation_count} ' + f'messages with an unknown operation type.') if table_summary.snapshot_count: failures.append( - f'For table {table_name}: Topic contained {table_summary.snapshot_count} unexpected snapshot ' - f'records.') + f'For source topic {source_topic_name}: Topic contained {table_summary.snapshot_count} ' + f'unexpected snapshot records.') for warning in warnings: print(f' WARN: {warning}') @@ -389,7 +401,8 @@ def _process_single_table_topic(self, table: 'tracked_tables.TrackedTable', for msg in self._kafka_client.consume_bounded( table.topic_name, constants.VALIDATION_MAXIMUM_SAMPLE_SIZE_PER_TOPIC, captured_watermarks): msg_count += 1 - table_summary.process_message(msg) + deser_msg = self._serializer.deserialize(msg) + table_summary.process_message(deser_msg) logger.info('Validation: consumed %s records from topic %s', msg_count, 
table.topic_name) return table_summary @@ -397,7 +410,7 @@ def _process_unified_topic(self, topic_name: str, expected_tables: Iterable['tra captured_watermarks: List[Tuple[int, int]]) -> Dict[str, Any]: logger.info('Validation: consuming records from unified topic %s', topic_name) - table_summaries: Dict[str, TableMessagesSummary] = {t.fq_name: TableMessagesSummary(t) for t in expected_tables} + table_summaries: Dict[str, TableMessagesSummary] = {t.topic_name: TableMessagesSummary(t) for t in expected_tables} warnings: List[str] = [] failures: List[str] = [] sample_regression_indices: List[str] = [] @@ -412,16 +425,27 @@ def _process_unified_topic(self, topic_name: str, expected_tables: Iterable['tra for msg in self._kafka_client.consume_bounded( topic_name, constants.VALIDATION_MAXIMUM_SAMPLE_SIZE_PER_TOPIC, captured_watermarks): total_messages_read += 1 + deser_msg = self._serializer.deserialize(msg) # noinspection PyArgumentList - if msg.value() is None: + if not msg.value(): tombstones_count += 1 continue # noinspection PyTypeChecker,PyArgumentList - msg_val = dict(msg.value()) - msg_table = msg_val.pop('__avro_schema_name').replace('_cdc__value', '').replace('_', '.', 1) - msg_change_index = change_index.ChangeIndex.from_avro_ready_dict(msg_val) + raw_headers = msg.headers() + if raw_headers is None: + raise Exception('Unexpected state: Headers missing from unified topic message') + msg_table_topic_raw = dict(raw_headers)['cdc_to_kafka_original_topic'] + if type(msg_table_topic_raw) is str: + msg_table_topic = msg_table_topic_raw + elif type(msg_table_topic_raw) is bytes: + msg_table_topic = msg_table_topic_raw.decode('utf-8') + else: + raise Exception("Unexpected data type in headers") + if deser_msg.value_dict is None: + raise Exception('Unexpected state: Missing value_dict from unified topic message') + msg_change_index = change_index.ChangeIndex.from_dict(deser_msg.value_dict) if prior_read_change_index is not None and prior_read_change_index > msg_change_index: if len(sample_regression_indices) < 10: @@ -431,13 +455,13 @@ def _process_unified_topic(self, topic_name: str, expected_tables: Iterable['tra lsn_regressions_count += 1 prior_read_change_index = msg_change_index - prior_read_partition = msg.partition() - prior_read_offset = msg.offset() + prior_read_partition = msg.partition() or prior_read_partition + prior_read_offset = msg.offset() or prior_read_offset - if msg_table in table_summaries: - table_summaries[msg_table].process_message(msg) + if msg_table_topic in table_summaries: + table_summaries[msg_table_topic].process_message(deser_msg) else: - unexpected_table_msg_counts[msg_table] += 1 + unexpected_table_msg_counts[msg_table_topic] += 1 if lsn_regressions_count: failures.append(f'{lsn_regressions_count} LSN ordering regressions encountered, with examples ' @@ -447,7 +471,7 @@ def _process_unified_topic(self, topic_name: str, expected_tables: Iterable['tra failures.append(f'{tombstones_count} unexpected deletion tombstones encountered') if unexpected_table_msg_counts: - warnings.append(f'Topic contained messages from unanticipated source tables: ' + warnings.append(f'Topic contained messages from unanticipated source topics: ' f'{json.dumps(unexpected_table_msg_counts)}') return { diff --git a/requirements.txt b/requirements.txt index c84156f..d619a9d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,6 @@ avro==1.11.3 aws-msk-iam-sasl-signer-python==1.0.1 bitarray==2.9.2 confluent-kafka==2.3.0 -fastavro==1.9.4 Jinja2==3.1.3 pyodbc==5.1.0 requests==2.31.0 
@@ -12,3 +11,14 @@ tabulate==0.9.0 # Only used by replayer.py; if using the Docker image, this requires that you apt install freetds-dev and python3-dev too: # ctds==1.14.0 +# faster-fifo==1.4.5 + +# Helpers if you're doing development on this project: +# ipython==8.23.0 +# line-profiler==4.1.2 +# mypy==1.10.0 +# sortedcontainers-stubs==2.4.2 +# types-confluent-kafka==1.2.0 +# types-requests==2.31.0.20240406 +# types-tabulate==0.9.0.20240106 +# vulture==2.11
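
Note on the hand-rolled Avro encoding introduced above: serialize_table_data_message writes zigzag varints and length-prefixed strings directly instead of routing every row through the Confluent serializer, and compare_canonical exists only to spot-check that output. The snippet below is a minimal standalone sketch in the same spirit and is not part of the patch: it assumes nothing beyond the avro package already pinned in requirements.txt, reproduces the int_to_int helper from this diff for illustration, and checks it against avro.io.BinaryEncoder.write_long. It also shows why a literal "," (0x2C) can serve as the Avro length prefix for the 22-character "0x..." LSN/seqval strings packed with struct.pack("23s", ...).

import io

import avro.io  # avro==1.11.3, already listed in requirements.txt


def int_to_int(writer: io.BytesIO, datum: int) -> None:
    # Mirrors the helper added in this diff: zigzag-encode, then emit base-128 varint groups.
    datum = (datum << 1) ^ (datum >> 63)
    while datum & ~0x7F:
        writer.write(bytearray([(datum & 0x7F) | 0x80]))
        datum >>= 7
    writer.write(bytearray([datum]))


for value in (0, 1, -1, 22, 300, -7_000_000_000):
    ours = io.BytesIO()
    int_to_int(ours, value)

    theirs = io.BytesIO()
    avro.io.BinaryEncoder(theirs).write_long(value)

    assert ours.getvalue() == theirs.getvalue(), value
    print(value, ours.getvalue().hex())

# Output includes "22 2c": the varint for 22 is the single byte 0x2C (","), which is why the
# serializer can prepend "," as the string-length prefix of a 22-character "0x..." hex value.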