diff --git a/cdc_kafka/kafka.py b/cdc_kafka/kafka.py
index c6e3172..74f00cd 100644
--- a/cdc_kafka/kafka.py
+++ b/cdc_kafka/kafka.py
@@ -176,7 +176,8 @@ def commit_transaction(self) -> None:
         self._producer.commit_transaction()
 
     def produce(self, topic: str, key: Optional[Dict[str, Any]], key_schema_id: int, value: Optional[Dict[str, Any]],
-                value_schema_id: int, message_type: str, copy_to_unified_topics: Optional[List[str]] = None) -> None:
+                value_schema_id: int, message_type: str, copy_to_unified_topics: Optional[List[str]] = None,
+                extra_headers: Optional[Dict[str, str | bytes]] = None) -> None:
         if self._disable_writing:
             return
 
@@ -199,7 +200,7 @@ def produce(self, topic: str, key: Optional[Dict[str, Any]], key_schema_id: int,
                 self._producer.produce(
                     topic=topic, value=value_ser, key=key_ser,
                     on_delivery=lambda err, msg: self._delivery_callback(message_type, err, msg, key, value),
-                    headers={'cdc_to_kafka_message_type': message_type}
+                    headers={'cdc_to_kafka_message_type': message_type, **(extra_headers or {})}
                 )
                 break
             except BufferError:
@@ -223,7 +224,7 @@ def produce(self, topic: str, key: Optional[Dict[str, Any]], key_schema_id: int,
                         on_delivery=lambda err, msg: self._delivery_callback(
                             constants.UNIFIED_TOPIC_CHANGE_MESSAGE, err, msg, key, value),
                         headers={'cdc_to_kafka_message_type': constants.UNIFIED_TOPIC_CHANGE_MESSAGE,
-                                 'cdc_to_kafka_original_topic': topic}
+                                 'cdc_to_kafka_original_topic': topic, **(extra_headers or {})}
                     )
                     break
                 except BufferError:
diff --git a/cdc_kafka/main.py b/cdc_kafka/main.py
index 82c53c8..4c5ed26 100644
--- a/cdc_kafka/main.py
+++ b/cdc_kafka/main.py
@@ -226,7 +226,8 @@ def poll_periodic_tasks() -> bool:
                         kafka_client.produce(row.destination_topic, row.key_dict,
                                              row.avro_key_schema_id, row.value_dict,
                                              row.avro_value_schema_id,
-                                             constants.SINGLE_TABLE_SNAPSHOT_MESSAGE)
+                                             constants.SINGLE_TABLE_SNAPSHOT_MESSAGE,
+                                             extra_headers=row.extra_headers)
                         snapshot_progress_by_topic[row.destination_topic] = row.key_dict
                         if t.snapshot_complete:
                             progress_tracker.record_snapshot_progress(
@@ -332,7 +333,8 @@ def poll_periodic_tasks() -> bool:
                     kafka_client.produce(row.destination_topic, row.key_dict, row.avro_key_schema_id,
                                          row.value_dict, row.avro_value_schema_id,
                                          constants.SINGLE_TABLE_CHANGE_MESSAGE,
-                                         table_to_unified_topics_map.get(row.table_fq_name, []))
+                                         table_to_unified_topics_map.get(row.table_fq_name, []),
+                                         extra_headers=row.extra_headers)
                     last_topic_produces[row.destination_topic] = helpers.naive_utcnow()
 
                     if not opts.disable_deletion_tombstones and row.operation_name == \
diff --git a/cdc_kafka/parsed_row.py b/cdc_kafka/parsed_row.py
index e0cc394..cb80a83 100644
--- a/cdc_kafka/parsed_row.py
+++ b/cdc_kafka/parsed_row.py
@@ -1,5 +1,5 @@
 import datetime
-from typing import Any, Dict, Sequence
+from typing import Any, Dict, Sequence, Optional
 
 from . import change_index
 
@@ -7,12 +7,12 @@ class ParsedRow(object):
     __slots__ = 'table_fq_name', 'row_kind', 'operation_name', 'event_db_time', 'change_idx', \
         'ordered_key_field_values', 'destination_topic', 'avro_key_schema_id', 'avro_value_schema_id', 'key_dict', \
-        'value_dict'
+        'value_dict', 'extra_headers'
 
     def __init__(self, table_fq_name: str, row_kind: str, operation_name: str, event_db_time: datetime.datetime,
                  change_idx: change_index.ChangeIndex, ordered_key_field_values: Sequence[Any],
                  destination_topic: str, avro_key_schema_id: int, avro_value_schema_id: int, key_dict: Dict[str, Any],
-                 value_dict: Dict[str, Any]) -> None:
+                 value_dict: Dict[str, Any], extra_headers: Optional[Dict[str, str | bytes]] = None) -> None:
         self.table_fq_name: str = table_fq_name
         self.row_kind: str = row_kind
         self.operation_name: str = operation_name
@@ -24,6 +24,7 @@ def __init__(self, table_fq_name: str, row_kind: str, operation_name: str, event
         self.avro_value_schema_id: int = avro_value_schema_id
         self.key_dict: Dict[str, Any] = key_dict
         self.value_dict: Dict[str, Any] = value_dict
+        self.extra_headers: Optional[Dict[str, str | bytes]] = extra_headers
 
     def __repr__(self) -> str:
         return f'ParsedRow from {self.table_fq_name} of kind {self.row_kind}, change index {self.change_idx}'
diff --git a/cdc_kafka/tracked_tables.py b/cdc_kafka/tracked_tables.py
index 8efc56b..8b1f2b5 100644
--- a/cdc_kafka/tracked_tables.py
+++ b/cdc_kafka/tracked_tables.py
@@ -20,7 +20,7 @@ class TrackedField(object):
     __slots__ = 'name', 'sql_type_name', 'change_table_ordinal', 'primary_key_ordinal', 'decimal_precision', \
-        'decimal_scale', 'transform_fn'
+        'decimal_scale', 'transform_fn', 'truncate_after'
 
     def __init__(self, name: str, sql_type_name: str, change_table_ordinal: int, primary_key_ordinal: int,
                  decimal_precision: int, decimal_scale: int, truncate_after: int = 0) -> None:
@@ -32,17 +32,10 @@ def __init__(self, name: str, sql_type_name: str, change_table_ordinal: int, pri
         self.decimal_scale: int = decimal_scale
         self.transform_fn: Optional[Callable[[Any], Any]] = avro_transform_fn_from_sql_type(sql_type_name)
 
-        if truncate_after:
-            if self.sql_type_name not in constants.SQL_STRING_TYPES:
-                raise Exception(f'A truncation length was specified for field {name} but it does not appear to be a '
-                                f'string field (SQL type is {sql_type_name}).')
-
-            if self.transform_fn:
-                orig_transform = self.transform_fn
-                # TODO: this prevents orig_transform from ever receiving a None argument; is that okay??
-                self.transform_fn = lambda x: orig_transform(x)[:int(truncate_after)] if x is not None else x
-            else:
-                self.transform_fn = lambda x: x[:int(truncate_after)] if x is not None else x
+        if truncate_after and self.sql_type_name not in constants.SQL_STRING_TYPES:
+            raise Exception(f'A truncation length was specified for field {name} but it does not appear to be a '
+                            f'string field (SQL type is {sql_type_name}).')
+        self.truncate_after: int = truncate_after
 
 
 class TrackedTable(object):
@@ -383,12 +376,61 @@ def get_change_rows_per_second(self) -> int:
             cursor.execute(q)
             return cursor.fetchval() or 0
 
+    @staticmethod
+    def cut_str_to_bytes(s: str, max_bytes: int) -> Tuple[int, str]:
+        # Mostly copied from https://github.com/halloleo/unicut/blob/master/truncate.py
+        def safe_b_of_i(b, i):
+            try:
+                return b[i]
+            except IndexError:
+                return 0
+
+        if s == '' or max_bytes < 1:
+            return 0, ''
+
+        b = s[:max_bytes].encode('utf-8')[:max_bytes]
+
+        if b[-1] & 0b10000000:
+            last_11x_index = [
+                i
+                for i in range(-1, -5, -1)
+                if safe_b_of_i(b, i) & 0b11000000 == 0b11000000
+            ][0]
+
+            last_11x = b[last_11x_index]
+            last_char_length = 1
+            if not last_11x & 0b00100000:
+                last_char_length = 2
+            elif not last_11x & 0b0010000:
+                last_char_length = 3
+            elif not last_11x & 0b0001000:
+                last_char_length = 4
+
+            if last_char_length > -last_11x_index:
+                # remove the incomplete character
+                b = b[:last_11x_index]
+
+        return len(b), b.decode('utf-8')
+
     def _parse_db_row(self, db_row: pyodbc.Row) -> parsed_row.ParsedRow:
         operation_id, event_db_time, lsn, seqval, update_mask, *table_cols = db_row
         operation_name = constants.CDC_OPERATION_ID_TO_NAME[operation_id]
 
-        value_dict = {fld.name: fld.transform_fn(table_cols[ix]) if fld.transform_fn else table_cols[ix]
-                      for ix, fld in enumerate(self.value_fields)}
+        value_dict = {}
+        extra_headers = {}
+        for ix, fld in enumerate(self.value_fields):
+            val = table_cols[ix]
+            if fld.transform_fn:
+                val = fld.transform_fn(val)
+            # The '* 4' below is because that's the maximum possible byte length of a UTF-8 encoded character--just
+            # trying to optimize away the extra code inside the `if` block when possible:
+            if val and fld.truncate_after and len(val) * 4 > fld.truncate_after:
+                original_encoded_length = len(val.encode('utf-8'))
+                if original_encoded_length > fld.truncate_after:
+                    new_encoded_length, val = TrackedTable.cut_str_to_bytes(val, fld.truncate_after)
+                    extra_headers[f'cdc_to_kafka_truncated_field__{fld.name}'] = \
+                        f'{original_encoded_length},{new_encoded_length}'
+            value_dict[fld.name] = val
 
         if operation_id == constants.SNAPSHOT_OPERATION_ID:
             row_kind = constants.SNAPSHOT_ROWS_KIND
@@ -423,7 +465,7 @@ def _parse_db_row(self, db_row: pyodbc.Row) -> parsed_row.ParsedRow:
 
         return parsed_row.ParsedRow(self.fq_name, row_kind, operation_name, event_db_time, change_idx,
                                     tuple(ordered_key_field_values), self.topic_name, self.key_schema_id,
-                                    self.value_schema_id, key_dict, value_dict)
+                                    self.value_schema_id, key_dict, value_dict, extra_headers)
 
     def _updated_col_names_from_mask(self, cdc_update_mask: bytes) -> List[str]:
         arr = bitarray.bitarray()
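
For reviewers, a minimal standalone sketch of the byte-budget truncation idea that cut_str_to_bytes and the new cdc_to_kafka_truncated_field__<name> headers implement: slice the UTF-8 encoding at the byte limit, let the decoder drop any partial trailing character, and record the original and truncated encoded lengths. This is not part of the patch; the names truncate_utf8 and example_field are illustrative only, and it uses errors='ignore' instead of the patch's manual lead-byte scan.

from typing import Dict, Tuple

def truncate_utf8(s: str, max_bytes: int) -> Tuple[int, int, str]:
    # Returns (original encoded length, truncated encoded length, truncated string).
    encoded = s.encode('utf-8')
    if len(encoded) <= max_bytes:
        return len(encoded), len(encoded), s
    # errors='ignore' drops the incomplete trailing multi-byte sequence, so the
    # cut never lands in the middle of a character.
    cut = encoded[:max_bytes].decode('utf-8', errors='ignore')
    return len(encoded), len(cut.encode('utf-8')), cut

extra_headers: Dict[str, str] = {}
original_len, new_len, shortened = truncate_utf8('naïve café', 3)
if new_len < original_len:
    # Mirrors the "<original>,<truncated>" header value built in _parse_db_row.
    extra_headers['cdc_to_kafka_truncated_field__example_field'] = f'{original_len},{new_len}'
print(shortened, extra_headers)  # -> na {'cdc_to_kafka_truncated_field__example_field': '12,2'}

The patch itself also skips the work entirely when len(val) * 4 cannot exceed the limit, since a UTF-8 character encodes to at most four bytes.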