Add Kafka message header to indicate when fields have been truncated (#…
woodlee authored Jun 5, 2024
1 parent bcc15d7 commit f24149f
Showing 4 changed files with 69 additions and 23 deletions.
7 changes: 4 additions & 3 deletions cdc_kafka/kafka.py
@@ -176,7 +176,8 @@ def commit_transaction(self) -> None:
self._producer.commit_transaction()

def produce(self, topic: str, key: Optional[Dict[str, Any]], key_schema_id: int, value: Optional[Dict[str, Any]],
value_schema_id: int, message_type: str, copy_to_unified_topics: Optional[List[str]] = None) -> None:
value_schema_id: int, message_type: str, copy_to_unified_topics: Optional[List[str]] = None,
extra_headers: Optional[Dict[str, str | bytes]] = None) -> None:
if self._disable_writing:
return

@@ -199,7 +200,7 @@ def produce(self, topic: str, key: Optional[Dict[str, Any]], key_schema_id: int,
self._producer.produce(
topic=topic, value=value_ser, key=key_ser,
on_delivery=lambda err, msg: self._delivery_callback(message_type, err, msg, key, value),
headers={'cdc_to_kafka_message_type': message_type}
headers={'cdc_to_kafka_message_type': message_type, **(extra_headers or {})}
)
break
except BufferError:
@@ -223,7 +224,7 @@ def produce(self, topic: str, key: Optional[Dict[str, Any]], key_schema_id: int,
on_delivery=lambda err, msg: self._delivery_callback(
constants.UNIFIED_TOPIC_CHANGE_MESSAGE, err, msg, key, value),
headers={'cdc_to_kafka_message_type': constants.UNIFIED_TOPIC_CHANGE_MESSAGE,
'cdc_to_kafka_original_topic': topic}
'cdc_to_kafka_original_topic': topic, **(extra_headers or {})}
)
break
except BufferError:
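For context, here is a minimal standalone sketch (not part of the commit) of how the merged header dict built inside produce() resolves. build_headers is a hypothetical helper that simply mirrors the inline expression above, and the field name 'notes' in the truncation header is made up for illustration.

def build_headers(message_type, extra_headers=None):
    # Mirrors the expression used in produce(): the message-type header is always
    # present, and any extra headers (e.g. truncation markers) are merged on top.
    # The `or {}` guard makes the splat a no-op when extra_headers is None.
    return {'cdc_to_kafka_message_type': message_type, **(extra_headers or {})}

print(build_headers('change'))
# {'cdc_to_kafka_message_type': 'change'}
print(build_headers('change', {'cdc_to_kafka_truncated_field__notes': '2048,1024'}))
# {'cdc_to_kafka_message_type': 'change', 'cdc_to_kafka_truncated_field__notes': '2048,1024'}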
6 changes: 4 additions & 2 deletions cdc_kafka/main.py
@@ -226,7 +226,8 @@ def poll_periodic_tasks() -> bool:
kafka_client.produce(row.destination_topic, row.key_dict,
row.avro_key_schema_id, row.value_dict,
row.avro_value_schema_id,
constants.SINGLE_TABLE_SNAPSHOT_MESSAGE)
constants.SINGLE_TABLE_SNAPSHOT_MESSAGE,
extra_headers=row.extra_headers)
snapshot_progress_by_topic[row.destination_topic] = row.key_dict
if t.snapshot_complete:
progress_tracker.record_snapshot_progress(
@@ -332,7 +333,8 @@ def poll_periodic_tasks() -> bool:
kafka_client.produce(row.destination_topic, row.key_dict, row.avro_key_schema_id,
row.value_dict, row.avro_value_schema_id,
constants.SINGLE_TABLE_CHANGE_MESSAGE,
table_to_unified_topics_map.get(row.table_fq_name, []))
table_to_unified_topics_map.get(row.table_fq_name, []),
extra_headers=row.extra_headers)
last_topic_produces[row.destination_topic] = helpers.naive_utcnow()

if not opts.disable_deletion_tombstones and row.operation_name == \
7 changes: 4 additions & 3 deletions cdc_kafka/parsed_row.py
@@ -1,18 +1,18 @@
import datetime
from typing import Any, Dict, Sequence
from typing import Any, Dict, Sequence, Optional

from . import change_index


class ParsedRow(object):
__slots__ = 'table_fq_name', 'row_kind', 'operation_name', 'event_db_time', 'change_idx', \
'ordered_key_field_values', 'destination_topic', 'avro_key_schema_id', 'avro_value_schema_id', 'key_dict', \
'value_dict'
'value_dict', 'extra_headers'

def __init__(self, table_fq_name: str, row_kind: str, operation_name: str, event_db_time: datetime.datetime,
change_idx: change_index.ChangeIndex, ordered_key_field_values: Sequence[Any], destination_topic: str,
avro_key_schema_id: int, avro_value_schema_id: int, key_dict: Dict[str, Any],
value_dict: Dict[str, Any]) -> None:
value_dict: Dict[str, Any], extra_headers: Optional[Dict[str, str | bytes]] = None) -> None:
self.table_fq_name: str = table_fq_name
self.row_kind: str = row_kind
self.operation_name: str = operation_name
Expand All @@ -24,6 +24,7 @@ def __init__(self, table_fq_name: str, row_kind: str, operation_name: str, event
self.avro_value_schema_id: int = avro_value_schema_id
self.key_dict: Dict[str, Any] = key_dict
self.value_dict: Dict[str, Any] = value_dict
self.extra_headers: Optional[Dict[str, str | bytes]] = extra_headers

def __repr__(self) -> str:
return f'ParsedRow from {self.table_fq_name} of kind {self.row_kind}, change index {self.change_idx}'
72 changes: 57 additions & 15 deletions cdc_kafka/tracked_tables.py
@@ -20,7 +20,7 @@

class TrackedField(object):
__slots__ = 'name', 'sql_type_name', 'change_table_ordinal', 'primary_key_ordinal', 'decimal_precision', \
'decimal_scale', 'transform_fn'
'decimal_scale', 'transform_fn', 'truncate_after'

def __init__(self, name: str, sql_type_name: str, change_table_ordinal: int, primary_key_ordinal: int,
decimal_precision: int, decimal_scale: int, truncate_after: int = 0) -> None:
@@ -32,17 +32,10 @@ def __init__(self, name: str, sql_type_name: str, change_table_ordinal: int, pri
self.decimal_scale: int = decimal_scale
self.transform_fn: Optional[Callable[[Any], Any]] = avro_transform_fn_from_sql_type(sql_type_name)

if truncate_after:
if self.sql_type_name not in constants.SQL_STRING_TYPES:
raise Exception(f'A truncation length was specified for field {name} but it does not appear to be a '
f'string field (SQL type is {sql_type_name}).')

if self.transform_fn:
orig_transform = self.transform_fn
# TODO: this prevents orig_transform from ever receiving a None argument; is that okay??
self.transform_fn = lambda x: orig_transform(x)[:int(truncate_after)] if x is not None else x
else:
self.transform_fn = lambda x: x[:int(truncate_after)] if x is not None else x
if truncate_after and self.sql_type_name not in constants.SQL_STRING_TYPES:
raise Exception(f'A truncation length was specified for field {name} but it does not appear to be a '
f'string field (SQL type is {sql_type_name}).')
self.truncate_after: int = truncate_after


class TrackedTable(object):
@@ -383,12 +376,61 @@ def get_change_rows_per_second(self) -> int:
cursor.execute(q)
return cursor.fetchval() or 0

@staticmethod
def cut_str_to_bytes(s: str, max_bytes: int) -> Tuple[int, str]:
# Mostly copied from https://github.com/halloleo/unicut/blob/master/truncate.py
def safe_b_of_i(b, i):
try:
return b[i]
except IndexError:
return 0

if s == '' or max_bytes < 1:
return 0, ''

b = s[:max_bytes].encode('utf-8')[:max_bytes]

if b[-1] & 0b10000000:
last_11x_index = [
i
for i in range(-1, -5, -1)
if safe_b_of_i(b, i) & 0b11000000 == 0b11000000
][0]

last_11x = b[last_11x_index]
last_char_length = 1
if not last_11x & 0b00100000:
last_char_length = 2
elif not last_11x & 0b0010000:
last_char_length = 3
elif not last_11x & 0b0001000:
last_char_length = 4

if last_char_length > -last_11x_index:
# remove the incomplete character
b = b[:last_11x_index]

return len(b), b.decode('utf-8')

def _parse_db_row(self, db_row: pyodbc.Row) -> parsed_row.ParsedRow:
operation_id, event_db_time, lsn, seqval, update_mask, *table_cols = db_row
operation_name = constants.CDC_OPERATION_ID_TO_NAME[operation_id]

value_dict = {fld.name: fld.transform_fn(table_cols[ix]) if fld.transform_fn else table_cols[ix]
for ix, fld in enumerate(self.value_fields)}
value_dict = {}
extra_headers = {}
for ix, fld in enumerate(self.value_fields):
val = table_cols[ix]
if fld.transform_fn:
val = fld.transform_fn(val)
# The '* 4' below is because that's the maximum possible byte length of a UTF-8 encoded character--just
# trying to optimize away the extra code inside the `if` block when possible:
if val and fld.truncate_after and len(val) * 4 > fld.truncate_after:
original_encoded_length = len(val.encode('utf-8'))
if original_encoded_length > fld.truncate_after:
new_encoded_length, val = TrackedTable.cut_str_to_bytes(val, fld.truncate_after)
extra_headers[f'cdc_to_kafka_truncated_field__{fld.name}'] = \
f'{original_encoded_length},{new_encoded_length}'
value_dict[fld.name] = val

if operation_id == constants.SNAPSHOT_OPERATION_ID:
row_kind = constants.SNAPSHOT_ROWS_KIND
@@ -423,7 +465,7 @@ def _parse_db_row(self, db_row: pyodbc.Row) -> parsed_row.ParsedRow:

return parsed_row.ParsedRow(self.fq_name, row_kind, operation_name, event_db_time, change_idx,
tuple(ordered_key_field_values), self.topic_name, self.key_schema_id,
self.value_schema_id, key_dict, value_dict)
self.value_schema_id, key_dict, value_dict, extra_headers)

def _updated_col_names_from_mask(self, cdc_update_mask: bytes) -> List[str]:
arr = bitarray.bitarray()
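For reference, a few hand-worked examples of the byte-limited truncation helper added above. cut_str_to_bytes is a @staticmethod, so it can be called directly on the class; the import assumes the cdc_kafka package and its dependencies (such as pyodbc) are installed, otherwise the function can be copied out verbatim since it uses no instance state.

from cdc_kafka.tracked_tables import TrackedTable

# 'é' is two bytes in UTF-8, so a 2-byte budget cannot hold 'hé' without splitting
# the character; the dangling lead byte is dropped and only 'h' survives.
assert TrackedTable.cut_str_to_bytes('héllo', 2) == (1, 'h')

# With a 5-byte budget the two-byte 'é' fits whole, leaving room for 'll'.
assert TrackedTable.cut_str_to_bytes('héllo', 5) == (5, 'héll')

# ASCII-only strings are simply sliced to max_bytes.
assert TrackedTable.cut_str_to_bytes('abcdef', 3) == (3, 'abc')

# Degenerate inputs return an empty result rather than raising.
assert TrackedTable.cut_str_to_bytes('', 10) == (0, '')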
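On the consuming side, the new headers make truncation detectable per message without inspecting the payload. Below is a hedged sketch using confluent-kafka (the client library this project produces with); the bootstrap servers, group id, and topic name are placeholders.

from confluent_kafka import Consumer

consumer = Consumer({
    'bootstrap.servers': 'localhost:9092',   # placeholder
    'group.id': 'truncation-audit',          # placeholder
    'auto.offset.reset': 'earliest',
})
consumer.subscribe(['some_cdc_topic'])       # placeholder topic name

msg = consumer.poll(timeout=10.0)
if msg is not None and msg.error() is None:
    # Header values arrive as bytes; truncation headers are named
    # 'cdc_to_kafka_truncated_field__<field>' and hold '<original_bytes>,<truncated_bytes>'.
    for name, value in (msg.headers() or []):
        if name.startswith('cdc_to_kafka_truncated_field__'):
            field = name[len('cdc_to_kafka_truncated_field__'):]
            original_len, truncated_len = value.decode('utf-8').split(',')
            print(f'{field}: truncated from {original_len} to {truncated_len} UTF-8 bytes')
consumer.close()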
